概述
Github官方地址:GLM-4
网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。
微调
微调格式:
代码语言:javascript复制[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
}
]
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
]
微调源码地址:finetune.py Loss计算代码:
代码语言:javascript复制def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_labels = []
# batched_conv 是一个数组
# conv 是数组内的单个 message
for conv in batched_conv:
input_ids = [151331, 151333]
loss_masks = [False, False]
# conv 是数组内的单个 message
# message 是 单个role json对象
for message in conv:
message = process_message(message)
# 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
# 获取 input 文本的数字表示(ids)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
# 计算整句的 mask
new_loss_masks = [loss_mask_val] * len(new_input_ids)
# 拼接message中的每段json
input_ids = new_input_ids
# 拼接message中每段json对应的mask
loss_masks = new_loss_masks
# 追加结尾的 token id
input_ids.append(tokenizer.eos_token_id)
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
# 添加到label,计算loss
labels.append(input_id)
else:
# -100 不处理,即ignore_index
labels.append(-100)
max_length = max_input_length max_output_length 1
# 截断
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}
注释在代码中已经写明。process_batch
方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)
Loss计算如下图所示: