ChatGLM3-6B的Transformers.Model的核心接口说明

2023-11-13 12:40:55 浏览数 (1)

背景

ChatGLM3-6B是10月底最新发布的智谱AI语言大模型。效果确实有明显的进步。但从文档上来看,仅有几个Demo以及B站官网视频 https://www.bilibili.com/video/BV1uC4y1J7yA 可供参考。但如果希望深入研究,关键的调用:

代码语言:txt复制
model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
                            return_past_key_values=True,
                            max_length=max_length, 
                            top_p=top_p,
                            temperature=temperature)

到底每个参数是什么含义?

由于Huggingface上、modelscope.cn上以及chatglm的github上,都没有详细的核心接口说明。全网检索很久,也没有找到答案。最后经过研究,可以通过源码文件来了解:https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py

本文通过给出相关接口注释,帮助大家了解相关接口的用法。

源码溯源

在huggingface的ChatGLM3-6B的主页中,点击Files标签页。

可以发现modeling_chatglm.py文件,接口代码即在其中。

接口注释

聊天函数

代码语言:python代码运行次数:0复制
    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        """
        聊天函数,接受一段文本查询,返回模型的响应。

        参数:
            tokenizer: 用于处理输入和输出文本的tokenizer对象。
            query (str): 用户的文本输入。
            history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
            role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
            max_length (int, 可选): 生成文本的最大长度。默认为8192。
            num_beams (int, 可选): Beam搜索的宽度,如果值大于1,则使用Beam搜索。默认为1。
            do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
            top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
            temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
            logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
            **kwargs: 其他传递给模型生成函数的参数。

        返回:
            response (str): 模型的响应文本。
            history (List[Dict]): 更新后的对话历史。
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

流式聊天函数

代码语言:python代码运行次数:0复制
    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        """
        流式聊天函数,接受一段文本查询,返回模型的响应。这个函数是一个生成器,可以在流式处理中使用。

        参数:
            tokenizer: 用于处理输入和输出文本的tokenizer对象。
            query (str): 用户的文本输入。
            history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
            role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
            past_key_values (List[Tensor], 可选): 用于transformer模型的过去的键值对。默认为None。
            max_length (int, 可选): 生成文本的最大长度。默认为8192。
            do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
            top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
            temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
            logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
            return_past_key_values (bool, 可选): 是否返回过去的键值对,用于下一步的生成。默认为False。
            **kwargs: 其他传递给模型生成函数的参数。

        返回:
            response (str): 模型的响应文本。
            history (List[Dict]): 更新后的对话历史。
            past_key_values (List[Tensor], 可选): 如果return_past_key_values为True,返回用于下一步生成的过去的键值对。
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids  = past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

0 人点赞