我们在前面的文章介绍了研究人员推出了一种挑战Transformer的新架构Mamba
他们的研究表明,Mamba是一种状态空间模型(SSM),在不同的模式(如语言、音频和时间序列)中表现出卓越的性能。为了说明这一点,研究人员使用Mamba-3B模型进行了语言建模实验。该模型超越了基于相同大小的Transformer的其他模型,并且在预训练和下游评估期间,它的表现与大小为其两倍的Transformer模型一样好。
Mamba的独特之处在于它的快速处理能力,选择性SSM层,以及受FlashAttention启发的硬件友好设计。这些特点使Mamba超越Transformer(Transformer没有了传统的注意力和MLP块)。
有很多人希望自己测试Mamba的效果,所以本文整理了一个能够在Colab上完整运行Mamba代码,代码中还使用了Mamba官方的3B模型来进行实际运行测试。
首先我们安装依赖,这是官网介绍的:
代码语言:javascript复制 !pip install causal-conv1d==1.0.0
!pip install mamba-ssm==1.0.1
然后直接使用transformers库读取预训练的Mamba-3B
代码语言:javascript复制 import torch
import os
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)
可以看到,3b的模型有11G
然后就是测试生成内容
代码语言:javascript复制 tokens = tokenizer("What is the meaning of life", return_tensors="pt")
input_ids = tokens.input_ids.to(device="cuda")
max_length = input_ids.shape[1] 80
fn = lambda: model.generate(
input_ids=input_ids, max_length=max_length, cg=True,
return_dict_in_generate=True, output_scores=True,
enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)
out = fn()
print(tokenizer.decode(out[0][0]))
这里还有一个chat的示例
代码语言:javascript复制 import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)
messages = []
user_message = """
What is the date for announcement
On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore.
"""
messages.append(dict(role="user",content=user_message))
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(out)
messages.append(dict(role="assistant",content=decoded[0].split("<|assistant|>n")[-1]))
print("Model:", decoded[0].split("<|assistant|>n")[-1])
这里我将所有代码整理成了Colab Notebook,有兴趣的可以直接使用:
https://colab.research.google.com/drive/1JyZpvncfSvtFZNOr3TU17Ff0BW5Nd_my?usp=sharing