大语言模型 MOE 简明实现指南

2024-06-22 08:49:13 浏览数 (2)

这篇文章中,我简要实现一下大语言模型的 MOE 模块。MOE 模块位于每个GPT层中,位于注意力模块的后面,每个MOE模块包含若干个MLP模块作为专家。这些专家是稀疏的,也就是每次选择部分来调用,并不会调用全部,从而节省宝贵的算力。

首先定义一些常量,通常应该在模型配置文件里面。

代码语言:javascript复制
bs = 5 # 批量大小
seql = 32 # 序列长度
hid = 128 # 隐藏向量维度
nexp = 5 # 专家总数
topk # 所选的专家数量

模块的输入应该是句子中单词的隐藏向量。为了便于测试我直接取了随机数,正常情况下应该是有意义的值。首先需要转换成二维的,便于计算。

代码语言:javascript复制
x = torch.randn([bs, seql, hid])
x = x.reshape([-1, hid])
x.shape
# torch.Size([160, 128])

然后我们需要一个门(定义在__init__里面,将每个隐藏向量转换成专家得分,进一步经过 softmax 转换成归一化的得分,表示每个专家对这个向量的结果有多大贡献。注意这里我们为每个向量单独分配专家,可能向量#1分配到了专家#1和#2,而向量#2分配到了专家#3和#4,总之可能是不一样的。

代码语言:javascript复制
gate = torch.nn.Linear(hid, nexp)
exp_logits = gate(x)
exp_probs = torch.softmax(exp_logits, -1)
exp_probs.shape
# torch.Size([160, 5])

每个专家应该是 MLP(定义在__init__里面),但是为了演示我就直接省略了,大家可以从各个大语言模型的源码里面复制粘贴:

代码语言:javascript复制
experts = [lambda x: x for _ in range(nexp)]

对每个向量分配到的专家按照贡献度排序,得到每个向量地专家排名exp_topk及其得分sc_topk

exp_topk[i, j]表示第i个词的第j个专家的序号,sc_topk[i, j]表示它的得分。

代码语言:javascript复制
sc_topk, exp_topk = torch.topk(exp_probs, topk, -1)
sc_topk.shape
# torch.Size([160, 2])
exp_topk.shape
# torch.Size([160, 2])

将专家的得分归一化,因为我们选了两个,总和又不是一了,会对结果的大小有影响:

代码语言:javascript复制
sc_topk /= sc_topk.sum(-1, keepdim=True)

下面我们创建该层的结果数组,累加每个专家的输出,大小和输入一样:

代码语言:javascript复制
final_hidden_state = torch.zeros_like(x)

然后我们获取每个专家对应的单词序号,和对应的单词排名。exp_topk == exp_i把等于专家exp_i的位置标注为True其它的为False,然后where获取下标。

hid_idcs是调用专家exp_i的向量序号,hid_ranks是该专家对于对应向量的排名

代码语言:javascript复制
for exp_i in range(nexp):
    hid_idcs, hid_ranks = torch.where(exp_topk == exp_i)

注意每个专家被调用的次数都可能不一样:

代码语言:javascript复制
[torch.where(exp_topk == exp_i) for exp_i in range(nexp)]
'''
[tensor([  0,   1,   2,   3,  14,  16,  18,  21,  22,  30,  32,  39,  43,  44,
          45,  52,  55,  58,  66,  67,  72,  77,  78,  80,  83,  87,  89,  90,
          91,  93, 102, 103, 105, 107, 108, 115, 116, 117, 126, 131, 133, 134,
         135, 136, 137, 146, 147, 148, 149, 151, 157, 158]),
 tensor([  6,   8,   9,  11,  18,  19,  20,  23,  26,  27,  28,  31,  34,  35,
          37,  41,  47,  50,  51,  53,  54,  56,  57,  59,  60,  62,  63,  71,
          74,  75,  77,  78,  79,  82,  83,  84,  86,  93,  97,  98, 100, 107,
         109, 110, 111, 113, 114, 118, 120, 123, 124, 126, 127, 128, 129, 130,
         139, 140, 143, 144, 145, 150, 155, 159]),
 tensor([  0,   4,   7,   8,  10,  12,  13,  14,  16,  17,  24,  25,  26,  29,
          32,  33,  34,  36,  40,  41,  46,  47,  49,  50,  53,  58,  64,  65,
          68,  70,  72,  73,  76,  81,  82,  85,  88,  89,  92,  94, 101, 103,
         108, 109, 112, 114, 115, 119, 120, 121, 123, 125, 132, 133, 135, 138,
         139, 140, 141, 142, 145, 146, 147, 150, 152, 153, 155, 156, 158]),
 tensor([  1,   5,   6,   7,   9,  11,  12,  13,  15,  20,  22,  23,  28,  29,
          30,  31,  35,  37,  38,  40,  42,  46,  48,  54,  55,  56,  57,  60,
          61,  62,  64,  65,  67,  69,  70,  71,  73,  74,  79,  80,  81,  84,
          86,  95,  96,  98,  99, 102, 104, 106, 110, 111, 113, 116, 118, 119,
         122, 125, 128, 129, 132, 134, 138, 144, 153, 154, 157, 159]),
 tensor([  2,   3,   4,   5,  10,  15,  17,  19,  21,  24,  25,  27,  33,  36,
          38,  39,  42,  43,  44,  45,  48,  49,  51,  52,  59,  61,  63,  66,
          68,  69,  75,  76,  85,  87,  88,  90,  91,  92,  94,  95,  96,  97,
          99, 100, 101, 104, 105, 106, 112, 117, 121, 122, 124, 127, 130, 131,
         136, 137, 141, 142, 143, 148, 149, 151, 152, 154, 156])]
'''

然后我们把每个专家的向量获取到(x[hid_idcs]),传入该专家experts[exp_i](...)

代码语言:javascript复制
# for ...
    hidden_state = experts[exp_i](x[hid_idcs])
    hidden_state.shape
    # torch.Size([52, 128])

然后需要乘上专家权重,最后加一维以便权重和上面的向量对齐:

代码语言:javascript复制
# for ...
    weights = sc_topk[hid_idcs, hid_ranks].unsqueeze(-1)
    weights.shape
    # torch.Size([52, 1])
    hidden_state *= weights

然后将当前专家的输出填回到结果数组中:

代码语言:javascript复制
# for ...
    final_hidden_state[hid_idcs]  = hidden_state

每个专家都计算完之后,将结果数组变形成原始的形状,然后作为整个模块的输出:

代码语言:javascript复制
final_hidden_state = final_hidden_state.reshape([bs, seql, hid])

0 人点赞