这篇文章中,我简要实现一下大语言模型的 MOE 模块。MOE 模块位于每个GPT层中,位于注意力模块的后面,每个MOE模块包含若干个MLP模块作为专家。这些专家是稀疏的,也就是每次选择部分来调用,并不会调用全部,从而节省宝贵的算力。
首先定义一些常量,通常应该在模型配置文件里面。
代码语言:javascript复制bs = 5 # 批量大小
seql = 32 # 序列长度
hid = 128 # 隐藏向量维度
nexp = 5 # 专家总数
topk # 所选的专家数量
模块的输入应该是句子中单词的隐藏向量。为了便于测试我直接取了随机数,正常情况下应该是有意义的值。首先需要转换成二维的,便于计算。
代码语言:javascript复制x = torch.randn([bs, seql, hid])
x = x.reshape([-1, hid])
x.shape
# torch.Size([160, 128])
然后我们需要一个门(定义在__init__
里面,将每个隐藏向量转换成专家得分,进一步经过 softmax 转换成归一化的得分,表示每个专家对这个向量的结果有多大贡献。注意这里我们为每个向量单独分配专家,可能向量#1分配到了专家#1和#2,而向量#2分配到了专家#3和#4,总之可能是不一样的。
gate = torch.nn.Linear(hid, nexp)
exp_logits = gate(x)
exp_probs = torch.softmax(exp_logits, -1)
exp_probs.shape
# torch.Size([160, 5])
每个专家应该是 MLP(定义在__init__
里面),但是为了演示我就直接省略了,大家可以从各个大语言模型的源码里面复制粘贴:
experts = [lambda x: x for _ in range(nexp)]
对每个向量分配到的专家按照贡献度排序,得到每个向量地专家排名exp_topk
及其得分sc_topk
。
exp_topk[i, j]
表示第i
个词的第j
个专家的序号,sc_topk[i, j]
表示它的得分。
sc_topk, exp_topk = torch.topk(exp_probs, topk, -1)
sc_topk.shape
# torch.Size([160, 2])
exp_topk.shape
# torch.Size([160, 2])
将专家的得分归一化,因为我们选了两个,总和又不是一了,会对结果的大小有影响:
代码语言:javascript复制sc_topk /= sc_topk.sum(-1, keepdim=True)
下面我们创建该层的结果数组,累加每个专家的输出,大小和输入一样:
代码语言:javascript复制final_hidden_state = torch.zeros_like(x)
然后我们获取每个专家对应的单词序号,和对应的单词排名。exp_topk == exp_i
把等于专家exp_i
的位置标注为True
其它的为False
,然后where
获取下标。
hid_idcs
是调用专家exp_i
的向量序号,hid_ranks
是该专家对于对应向量的排名
for exp_i in range(nexp):
hid_idcs, hid_ranks = torch.where(exp_topk == exp_i)
注意每个专家被调用的次数都可能不一样:
代码语言:javascript复制[torch.where(exp_topk == exp_i) for exp_i in range(nexp)]
'''
[tensor([ 0, 1, 2, 3, 14, 16, 18, 21, 22, 30, 32, 39, 43, 44,
45, 52, 55, 58, 66, 67, 72, 77, 78, 80, 83, 87, 89, 90,
91, 93, 102, 103, 105, 107, 108, 115, 116, 117, 126, 131, 133, 134,
135, 136, 137, 146, 147, 148, 149, 151, 157, 158]),
tensor([ 6, 8, 9, 11, 18, 19, 20, 23, 26, 27, 28, 31, 34, 35,
37, 41, 47, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 71,
74, 75, 77, 78, 79, 82, 83, 84, 86, 93, 97, 98, 100, 107,
109, 110, 111, 113, 114, 118, 120, 123, 124, 126, 127, 128, 129, 130,
139, 140, 143, 144, 145, 150, 155, 159]),
tensor([ 0, 4, 7, 8, 10, 12, 13, 14, 16, 17, 24, 25, 26, 29,
32, 33, 34, 36, 40, 41, 46, 47, 49, 50, 53, 58, 64, 65,
68, 70, 72, 73, 76, 81, 82, 85, 88, 89, 92, 94, 101, 103,
108, 109, 112, 114, 115, 119, 120, 121, 123, 125, 132, 133, 135, 138,
139, 140, 141, 142, 145, 146, 147, 150, 152, 153, 155, 156, 158]),
tensor([ 1, 5, 6, 7, 9, 11, 12, 13, 15, 20, 22, 23, 28, 29,
30, 31, 35, 37, 38, 40, 42, 46, 48, 54, 55, 56, 57, 60,
61, 62, 64, 65, 67, 69, 70, 71, 73, 74, 79, 80, 81, 84,
86, 95, 96, 98, 99, 102, 104, 106, 110, 111, 113, 116, 118, 119,
122, 125, 128, 129, 132, 134, 138, 144, 153, 154, 157, 159]),
tensor([ 2, 3, 4, 5, 10, 15, 17, 19, 21, 24, 25, 27, 33, 36,
38, 39, 42, 43, 44, 45, 48, 49, 51, 52, 59, 61, 63, 66,
68, 69, 75, 76, 85, 87, 88, 90, 91, 92, 94, 95, 96, 97,
99, 100, 101, 104, 105, 106, 112, 117, 121, 122, 124, 127, 130, 131,
136, 137, 141, 142, 143, 148, 149, 151, 152, 154, 156])]
'''
然后我们把每个专家的向量获取到(x[hid_idcs]
),传入该专家experts[exp_i](...)
:
# for ...
hidden_state = experts[exp_i](x[hid_idcs])
hidden_state.shape
# torch.Size([52, 128])
然后需要乘上专家权重,最后加一维以便权重和上面的向量对齐:
代码语言:javascript复制# for ...
weights = sc_topk[hid_idcs, hid_ranks].unsqueeze(-1)
weights.shape
# torch.Size([52, 1])
hidden_state *= weights
然后将当前专家的输出填回到结果数组中:
代码语言:javascript复制# for ...
final_hidden_state[hid_idcs] = hidden_state
每个专家都计算完之后,将结果数组变形成原始的形状,然后作为整个模块的输出:
代码语言:javascript复制final_hidden_state = final_hidden_state.reshape([bs, seql, hid])