背景
ChatGLM3-6B是10月底最新发布的智谱AI语言大模型。效果确实有明显的进步。但从文档上来看,仅有几个Demo以及B站官网视频 https://www.bilibili.com/video/BV1uC4y1J7yA 可供参考。但如果希望深入研究,关键的调用:
代码语言:txt复制model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
return_past_key_values=True,
max_length=max_length,
top_p=top_p,
temperature=temperature)
到底每个参数是什么含义?
由于Huggingface上、modelscope.cn上以及chatglm的github上,都没有详细的核心接口说明。全网检索很久,也没有找到答案。最后经过研究,可以通过源码文件来了解:https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py
本文通过给出相关接口注释,帮助大家了解相关接口的用法。
源码溯源
在huggingface的ChatGLM3-6B的主页中,点击Files标签页。
可以发现modeling_chatglm.py文件,接口代码即在其中。
接口注释
聊天函数
代码语言:python代码运行次数:0复制 @torch.inference_mode()
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
**kwargs):
"""
聊天函数,接受一段文本查询,返回模型的响应。
参数:
tokenizer: 用于处理输入和输出文本的tokenizer对象。
query (str): 用户的文本输入。
history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
max_length (int, 可选): 生成文本的最大长度。默认为8192。
num_beams (int, 可选): Beam搜索的宽度,如果值大于1,则使用Beam搜索。默认为1。
do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
**kwargs: 其他传递给模型生成函数的参数。
返回:
response (str): 模型的响应文本。
history (List[Dict]): 更新后的对话历史。
"""
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to(self.device)
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = self.process_response(response, history)
return response, history
流式聊天函数
代码语言:python代码运行次数:0复制 @torch.inference_mode()
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
logits_processor=None, return_past_key_values=False, **kwargs):
"""
流式聊天函数,接受一段文本查询,返回模型的响应。这个函数是一个生成器,可以在流式处理中使用。
参数:
tokenizer: 用于处理输入和输出文本的tokenizer对象。
query (str): 用户的文本输入。
history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
past_key_values (List[Tensor], 可选): 用于transformer模型的过去的键值对。默认为None。
max_length (int, 可选): 生成文本的最大长度。默认为8192。
do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
return_past_key_values (bool, 可选): 是否返回过去的键值对,用于下一步的生成。默认为False。
**kwargs: 其他传递给模型生成函数的参数。
返回:
response (str): 模型的响应文本。
history (List[Dict]): 更新后的对话历史。
past_key_values (List[Tensor], 可选): 如果return_past_key_values为True,返回用于下一步生成的过去的键值对。
"""
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
if past_key_values is None:
inputs = tokenizer.build_chat_input(query, history=history, role=role)
else:
inputs = tokenizer.build_chat_input(query, role=role)
inputs = inputs.to(self.device)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[0]
if self.transformer.pre_seq_len is not None:
past_length -= self.transformer.pre_seq_len
inputs.position_ids = past_length
attention_mask = inputs.attention_mask
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs['attention_mask'] = attention_mask
history.append({"role": role, "content": query})
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
**gen_kwargs):
if return_past_key_values:
outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
if response and response[-1] != "�":
response, new_history = self.process_response(response, history)
if return_past_key_values:
yield response, new_history, past_key_values
else:
yield response, new_history