Jamba: 混合Transformer和Mamba的语言模型.

0. TL;DR

本文介绍了 Jamba,一种基于混合 Transformer-Mamba 架构的新型语言模型。Jamba 通过结合 TransformerMamba 层,以及引入混合专家(MoE)模块,实现了高性能和高吞吐量,同时支持超长上下文处理。实验结果表明,Jamba 在标准语言模型基准测试和长上下文评估中均表现出色,特别是在处理长达 256K tokens 的上下文时。Jamba 的混合架构不仅提高了模型的性能,还显著降低了内存使用和计算需求。

1. 背景介绍

近年来,Transformer 模型在自然语言处理领域取得了巨大成功,但其高内存和计算需求限制了对长序列的处理能力。另一方面,状态空间模型(SSMs)如 Mamba 在处理长序列方面表现出更高的效率,但在性能上仍落后于同等规模的 Transformer 模型。为了结合两者的优点,本文提出了 Jamba,一种混合 Transformer-Mamba 架构,通过引入 MoE 模块,进一步提高了模型的容量和性能。

2. Jamba 模型

Jamba 是一种混合解码器架构,它结合了 Transformer 层、Mamba 层 和 混合专家 (MoE) 模块。每个 Jamba 块包含以下组件:

Jamba 通过混合 TransformerMamba 层,平衡了内存使用、吞吐量和性能。具体来说,每个 Jamba 块中的层以一定的比例混合 TransformerMamba 层。在论文中每个块包含 $8$ 层,比例为 $1:7$($1$ 层 Transformer 和 $7$ 层 Mamba)。

MoE 模块允许增加模型的容量,同时保持活动参数的数量可控。在 Jamba 中,MoE 模块被应用于某些 MLP 层。具体来说,每隔两层使用一次 MoE,每个 MoE 层包含 $16$ 个专家,每个 token 使用前 $2$ 个专家。

Jamba与最近公开的模型进行了比较,显示了其即使在256K的标记上下文中也能保持小型KV缓存的优势。

3. 实验分析

Jamba 在多个学术基准测试中表现出色,与同规模或更大规模的公开可用模型(如 Llama-2 70BMixtral)相当。具体来说,Jamba 在常识推理、阅读理解和其他任务的数据集(如 HellaSwag、WinoGrande、ARC-E、ARC-Challenge、PIQA、BoolQ、QuAC、GSM8K、HumanEval、MMLU 和 BBH)上均取得了较高的性能。

Jamba 在长上下文评估中表现出色,特别是在处理长达 256K tokens 的上下文时。具体来说,Jambaneedle-in-a-haystack 评估(在长文本中检索一个状态)中取得了优异的成绩,这表明其能够有效地从长上下文中检索信息。此外,Jamba 在分类任务数据集(Trec-Fine、NLU Intent、Banking77、CLINC150)以及问答数据集(LongFQA、CUAD、NarrativeQA、NQ、SFiction)中也表现出色,特别是在处理长输入时。

Jamba 在吞吐量方面表现出显著优势,特别是在处理长上下文时。具体来说,Jamba 的吞吐量是 Mixtral 的 $3$ 倍,这使其在实际应用中更具优势。