该工作由忆生科技创始团队成员及旗下研究人员深度参与,由多所学校与机构的研究者共同完成,包括加州大学伯克利分校、宾夕法尼亚大学、密歇根大学、清华大学、忆生科技、香港大学、约翰·霍普金斯大学等。
忆生科技创始人马毅教授已受邀在今年四月的ICLR大会上就和此项成果相关的一系列白盒神经网络相关工作,进行为时一小时的主题报告(Keynote)。
论文标题:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
论文地址: https://arxiv.org/abs/2412.17810
项目主页: https://robinwu218.github.io/ToST/
项目已开源:https://github.com/RobinWu218/ToST
顶会收录:ICLR2025
媒体详尽报道:首个基于统计学的线性注意力机制ToST,高分拿下ICLR Spotlight
下图中Transcengram 即忆生科技
核心信息:
在论文《Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction》中,作者提出了一种新的Transformer注意力机制,称为Token Statistics Self-Attention(TSSA),其计算复杂度随token数量线性增长。该方法通过引入变分形式的最大编码率减少(MCR²)目标,推导出一种新的注意力模块,旨在解决传统自注意力机制计算复杂度高的问题。
主要内容概述:
研究背景:
问题描述:Transformer架构在各类任务中表现出色,但其自注意力机制的计算和内存复杂度随着token数量呈二次方增长,限制了其在长序列任务中的应用。
现有方法:为降低自注意力的计算复杂度,已有方法包括分块处理、滑动窗口、低秩近似和Nystrom扩展等,但这些方法仍需计算token间的相似度。
变分速率减少(Variational Rate Reduction):
目标:通过最大化特征表示的编码率,提升模型对数据结构的捕捉能力。
方法:引入变分形式的MCR²目标,将其分解为易于优化的部分,避免直接计算高维度的协方差矩阵行列式。
Token Statistics Self-Attention(TSSA):
推导过程:通过对变分MCR²目标进行梯度下降,推导出TSSA模块。
核心思想:TSSA不再计算token间的两两相似度,而是基于token的统计信息(如均值和方差)进行注意力计算,从而将计算复杂度从二次方降低到线性级别。
Token Statistics Transformer(ToST):
架构设计:将TSSA模块替代传统自注意力机制,构建新的Transformer架构ToST。
优势:ToST在保持或提升性能的同时,显著降低了计算和内存开销。
实验验证:
任务类型:在视觉、语言和长序列任务上进行测试。
性能表现:实验结果显示,ToST在多个基准数据集上与传统Transformer性能相当,且在计算效率上具有明显优势。
结论:
研究贡献:提出了一种新的注意力机制TSSA,突破了传统自注意力计算复杂度高的瓶颈,为处理长序列任务提供了新的思路。
未来工作:计划进一步优化TSSA的实现,并探索其在更多应用场景中的潜力。
参考文献:
Wu, Z., Ding, T., Lu, Y., Pai, D., Zhang, J., Wang, W., Yu, Y., Ma, Y., & Haeffele, B. D. (2024). Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction. arXiv preprint arXiv:2412.17810. https://arxiv.org/abs/2412.17810
附注:
该论文的代码和更多信息已在项目主页发布:https://github.com/RobinWu218/ToST