alphaFold2 | Model Details: Evoformer

2022-11-22 15:14:33

<<AlphaFold2 Series>>

alphaFold2 | Problem Setting and Background (1)

alphaFold2 | Building the Model Framework (2)

alphaFold2 | Model Details: Feature Extraction (3)

  • This article is reposted from the WeChat public account: 机器学习炼丹术
  • Author: 陈亦新 (feedback and discussion are welcome)
  • Evoformer
    • 1.1 FeedForward
    • 1.3 MsaAttentionBlock
    • 1.4 gating
    • 1.5 PairwiseAttentionBlock

    The previous article covered how the MSA and pair representation features are built; now let's look at the model structure.

Evoformer

x, m = self.net(
            x,
            m,
            mask = x_mask,
            msa_mask = msa_mask
        )

Here x is our pair representation and m is our MSA feature. The net is defined as:

self.net = Evoformer(
            dim = dim,
            depth = depth,
            seq_len = max_seq_len,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

OK, so this net is exactly the Evoformer part we are looking for.


class Evoformer(nn.Module):
    def __init__(
        self,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([EvoformerBlock(**kwargs) for _ in range(depth)])

    def forward(
        self,
        x,
        m,
        mask = None,
        msa_mask = None
    ):
        inp = (x, m, mask, msa_mask)
        x, m, *_ = checkpoint_sequential(self.layers, 1, inp)
        return x, m

As you can see, the Evoformer class is very short, basically just a shell: it stacks depth EvoformerBlock layers and runs them through checkpoint_sequential, i.e. gradient checkpointing, which recomputes each block's activations during the backward pass instead of storing them, trading compute for memory.
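For reference, here is a minimal, self-contained sketch of what torch.utils.checkpoint.checkpoint_sequential does, using toy Linear layers instead of the Evoformer blocks:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# a toy stack standing in for the list of EvoformerBlocks
layers = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(4, 64, requires_grad=True)

# run the 8 layers as 2 checkpointed segments; intermediate activations are
# recomputed during backward instead of being kept in memory
out = checkpoint_sequential(layers, 2, x)
out.sum().backward()

The real core, though, is in the definition of the EvoformerBlock component: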

class EvoformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout,
        global_column_attn = False
    ):
        super().__init__()
        self.layer = nn.ModuleList([
            PairwiseAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, global_column_attn = global_column_attn),
            FeedForward(dim = dim, dropout = ff_dropout),
            MsaAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout),
            FeedForward(dim = dim, dropout = ff_dropout),
        ])

    def forward(self, inputs):
        x, m, mask, msa_mask = inputs
        attn, ff, msa_attn, msa_ff = self.layer

        # msa attention and transition

        m = msa_attn(m, mask = msa_mask, pairwise_repr = x)
        m = msa_ff(m) + m

        # pairwise attention and transition

        x = attn(x, mask = mask, msa_repr = m, msa_mask = msa_mask)
        x = ff(x) + x

        return x, m, mask, msa_mask

It breaks down into these three kinds of components:

  • PairwiseAttentionBlock
  • FeedForward
  • MsaAttentionBlock

These components correspond to the diagrams in the paper's supplementary material.

1.1 FeedForward

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )
        init_zero_(self.net[-1])

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.net(x)

This module is just a LayerNorm followed by fully connected layers. GEGLU, however, is a newer activation, so let's take a look:

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)
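
A quick shape check makes the design clearer: the first Linear doubles the hidden width so that GEGLU can split it into a value half and a gate half (dim = 256 and mult = 4 below are example values only):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, mult = 256, 4
x = torch.randn(1, 5, 128, dim)           # (batch, MSA rows, seq len, dim)

h = nn.Linear(dim, dim * mult * 2)(x)     # (..., 2048): twice the hidden width
val, gates = h.chunk(2, dim=-1)           # GEGLU splits it into a value half and a gate half
h = val * F.gelu(gates)                   # (..., 1024) = dim * mult
out = nn.Linear(dim * mult, dim)(h)       # project back to (..., 256)
print(out.shape)                          # torch.Size([1, 5, 128, 256])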

1.3 MsaAttentionBlock

class MsaAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.
    ):
        super().__init__()
        self.row_attn = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = True, col_attn = False, accept_edges = True)
        self.col_attn = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = False, col_attn = True)

    def forward(
        self,
        x,
        mask = None,
        pairwise_repr = None
    ):
        # note the residual connections here
        x = self.row_attn(x, mask = mask, edges = pairwise_repr) + x
        x = self.col_attn(x, mask = mask) + x
        return x

This component corresponds to the MSA stack shown in the paper, which is split into:

  • Row-wise gated self-attention with pair bias
  • Column-wise gated self-attention

Both variants are implemented by the same AxialAttention class:
class AxialAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        row_attn = True,
        col_attn = True,
        accept_edges = False,
        global_query_attn = False,
        **kwargs
    ):
        super().__init__()
        assert not (not row_attn and not col_attn), 'row or column attention must be turned on'

        self.row_attn = row_attn
        self.col_attn = col_attn
        self.global_query_attn = global_query_attn

        self.norm = nn.LayerNorm(dim)

        self.attn = Attention(dim = dim, heads = heads, **kwargs)

        self.edges_to_attn_bias = nn.Sequential(
            nn.Linear(dim, heads, bias = False),
            Rearrange('b i j h -> b h i j')
        ) if accept_edges else None

    def forward(self, x, edges = None, mask = None):
        assert self.row_attn ^ self.col_attn, 'has to be either row or column attention, but not both'
        # x is the MSA feature we extracted, edges is the pair representation feature
        # b is the batch size, h is the number of MSA rows (1 target plus e.g. 5 matches from the databases)
        # w is the sequence length (e.g. 128 residues), d is the per-token feature dimension (e.g. 256)
        b, h, w, d = x.shape
        
        x = self.norm(x)

        # axial attention

        if self.col_attn:
            axial_dim = w
            mask_fold_axial_eq = 'b h w -> (b w) h'
            input_fold_eq = 'b h w d -> (b w) h d'
            output_fold_eq = '(b w) h d -> b h w d'

        elif self.row_attn:
            axial_dim = h
            mask_fold_axial_eq = 'b h w -> (b h) w'
            input_fold_eq = 'b h w d -> (b h) w d'
            output_fold_eq = '(b h) w d -> b h w d'
        
        x = rearrange(x, input_fold_eq)

        if exists(mask):
            mask = rearrange(mask, mask_fold_axial_eq)

        attn_bias = None
        if exists(self.edges_to_attn_bias) and exists(edges):
            attn_bias = self.edges_to_attn_bias(edges)
            attn_bias = repeat(attn_bias, 'b h i j -> (b x) h i j', x = axial_dim)

        tie_dim = axial_dim if self.global_query_attn else None

        out = self.attn(x, mask = mask, attn_bias = attn_bias, tie_dim = tie_dim)
        out = rearrange(out, output_fold_eq, h = h, w = w)

        return out

As you can see, self-attention over rows and over columns is essentially the same operation (see the sketch after this list):

  • Column attention: x is first rearranged into the (b w) h d layout, so attention runs along the h axis. Since h is the number of MSA sequences, the resulting self-attention map is a similarity matrix between the aligned amino acid sequences.
  • Row attention: likewise, attention runs along the w axis, so the self-attention map captures the relationship between residue positions; with a sequence length of 128, that is a 128x128 attention map.
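
A minimal sketch of the two folds with einops (toy shapes: 1 batch, 5 MSA rows, sequence length 128, feature dimension 256):

import torch
from einops import rearrange

x = torch.randn(1, 5, 128, 256)               # (b, h, w, d): the MSA representation

# row attention: fold the MSA-row axis into the batch and attend along the sequence,
# so each attention map is 128 x 128 (residue vs. residue)
rows = rearrange(x, 'b h w d -> (b h) w d')   # (5, 128, 256)

# column attention: fold the sequence axis into the batch and attend across MSA rows,
# so each attention map is 5 x 5 (sequence vs. sequence)
cols = rearrange(x, 'b h w d -> (b w) h d')   # (128, 5, 256)

print(rows.shape, cols.shape)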

Row attention additionally introduces edges, which is how the pair representation gets injected. The edges input has shape (batch, 128, 128, 256); it is first processed by self.edges_to_attn_bias, which is simply a linear layer followed by a rearrange.

With heads = 8, the linear layer outputs a feature of shape (batch, 128, 128, 8), and the rearrange turns it into (batch, 8, 128, 128).

The repeat operation then folds the axial dimension into the batch, giving (batch × axial_dim, 8, 128, 128); with batch = 1 this is simply (axial_dim, 8, 128, 128).

Finally, the MSA feature and the corresponding attn_bias are passed into self.attn, an instance of the Attention class.
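
Here is a minimal sketch of that bias path with toy shapes (heads = 8, dim = 256, and 5 MSA rows are assumed example values):

import torch
import torch.nn as nn
from einops import rearrange, repeat

heads, dim, axial_dim = 8, 256, 5                        # axial_dim = number of MSA rows
edges = torch.randn(1, 128, 128, dim)                    # pair representation (b, i, j, d)

to_bias = nn.Linear(dim, heads, bias=False)              # one bias scalar per attention head
attn_bias = to_bias(edges)                               # (1, 128, 128, 8)
attn_bias = rearrange(attn_bias, 'b i j h -> b h i j')   # (1, 8, 128, 128)

# repeat the bias for every MSA row that was folded into the batch
attn_bias = repeat(attn_bias, 'b h i j -> (b x) h i j', x=axial_dim)
print(attn_bias.shape)                                   # torch.Size([5, 8, 128, 128])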

Let's see how this Attention class is implemented:

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len = None,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        gating = True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads= heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.gating = nn.Linear(dim, inner_dim)
        nn.init.constant_(self.gating.weight, 0.)
        nn.init.constant_(self.gating.bias, 1.)

        self.dropout = nn.Dropout(dropout)
        init_zero_(self.to_out)

    def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
        device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        i, j = q.shape[-2], k.shape[-2]

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale

        q = q * self.scale

        # query / key similarities

        if exists(tie_dim):
            # as in the paper, for the extra MSAs
            # they average the queries along the rows of the MSAs
            # they named this particular module MSAColumnGlobalAttention

            q, k = map(lambda t: rearrange(t, '(b r) ... -> b r ...', r = tie_dim), (q, k))
            q = q.mean(dim = 1)

            dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
            dots = rearrange(dots, 'b r ... -> (b r) ...')
        else:
            dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # add attention bias, if supplied (for pairwise to msa attention communication)

        if exists(attn_bias):
            dots = dots + attn_bias

        # masking

        if exists(mask):
            mask = default(mask, lambda: torch.ones(1, i, device = device).bool())
            context_mask = mask if not has_context else default(context_mask, lambda: torch.ones(1, k.shape[-2], device = device).bool())
            mask_value = -torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            dots = dots.masked_fill(~mask, mask_value)

        # attention

        dots = dots - dots.max(dim = -1, keepdims = True).values
        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # gating

        gates = self.gating(x)
        out = out * gates.sigmoid()

        # combine to out

        out = self.to_out(out)
        return out

In the forward function, q, k and v are produced first. heads is 8 and dim_head is 64, so inner_dim = 8 × 64 = 512, i.e. 8 heads of 64-dimensional features. The rearrange then splits the last dimension of q, k and v from 8×64 into two separate dimensions (head and per-head feature).

So how is the (128, 128, 8) pair representation feature added in? Quite simply: in row-wise attention, the dot product between q and k already produces a (128, 128) attention map per head, and the 8 channels of the pair bias line up exactly with the 8 heads, so the two can be fused by a plain element-wise addition to the attention logits (see the sketch below).
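
A sketch with toy shapes, showing the head split and the element-wise fusion (the bias here is random, standing in for the projected pair representation):

import torch
from einops import rearrange

b, n, heads, dim_head = 5, 128, 8, 64                     # b = MSA rows folded into the batch
q = torch.randn(b, n, heads * dim_head)                   # inner_dim = 512
k = torch.randn(b, n, heads * dim_head)

# split inner_dim = 512 into 8 heads of 64
q, k = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in (q, k))

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)  # (5, 8, 128, 128) logits per head
attn_bias = torch.randn(b, heads, n, n)                   # pair bias, already repeated per MSA row
attn = (dots + attn_bias).softmax(dim=-1)                 # pair info biases the MSA row attention
print(attn.shape)                                         # torch.Size([5, 8, 128, 128])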

1.4 gating

Everything above matches Figure 2 in the AlphaFold2 supplementary material: the MSA information is turned into an attention map, the pair representation features are added to it, the values are then multiplied with this attention map (which now fuses pair and MSA information) to produce a new value tensor, and after that we arrive at the gate operation.

The gate itself is very simple: a linear projection of the input is passed through a sigmoid and multiplied onto the attention output.
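
A minimal sketch of the gate, including the identity-style initialization used in the code above (weights zero, bias one, so at the start the gate is sigmoid(1) ≈ 0.73 for every channel and lets most of the signal through):

import torch
import torch.nn as nn

dim, inner_dim = 256, 512
gating = nn.Linear(dim, inner_dim)
nn.init.constant_(gating.weight, 0.)      # at initialization the gate is a constant ...
nn.init.constant_(gating.bias, 1.)        # ... sigmoid(1) ≈ 0.73 for every channel

x = torch.randn(5, 128, dim)              # the input to the attention layer
out = torch.randn(5, 128, inner_dim)      # attention output with the heads merged
gated = out * gating(x).sigmoid()         # per-channel gate computed from the input itself
print(gated.shape)                        # torch.Size([5, 128, 512])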

My guess is that the gate is there because AlphaFold runs the whole network four times (recycling), so it borrows the forget-gate idea from RNNs?

1.5 PairwiseAttentionBlock

So far, the remaining piece we still need is this part of the code:

class PairwiseAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.,
        global_column_attn = False
    ):
        super().__init__()
        self.outer_mean = OuterMean(dim)

        self.triangle_attention_outgoing = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = True, col_attn = False, accept_edges = True)
        self.triangle_attention_ingoing = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = False, col_attn = True, accept_edges = True, global_query_attn = global_column_attn)
        self.triangle_multiply_outgoing = TriangleMultiplicativeModule(dim = dim, mix = 'outgoing')
        self.triangle_multiply_ingoing = TriangleMultiplicativeModule(dim = dim, mix = 'ingoing')

    def forward(
        self,
        x,
        mask = None,
        msa_repr = None,
        msa_mask = None
    ):
        if exists(msa_repr):
            x = x + self.outer_mean(msa_repr, mask = msa_mask)

        x = self.triangle_multiply_outgoing(x, mask = mask) + x
        x = self.triangle_multiply_ingoing(x, mask = mask) + x
        x = self.triangle_attention_outgoing(x, edges = x, mask = mask) + x
        x = self.triangle_attention_ingoing(x, edges = x, mask = mask) + x
        return x
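
The OuterMean module itself is not shown in this article. Conceptually, and this is only a hedged sketch of the paper's outer product mean idea rather than the exact code in the repo, it turns the MSA representation into a pair-shaped update by taking an outer product between every pair of positions and averaging over the MSA rows:

import torch
import torch.nn as nn

b, s, n, d, d_hidden = 1, 5, 128, 256, 32            # s MSA rows, n residues; toy sizes
msa = torch.randn(b, s, n, d)

left = nn.Linear(d, d_hidden)(msa)                   # (b, s, n, d_hidden)
right = nn.Linear(d, d_hidden)(msa)                  # (b, s, n, d_hidden)

# outer product between positions i and j, averaged over the s MSA rows
outer = torch.einsum('b s i u, b s j v -> b i j u v', left, right) / s
pair_update = nn.Linear(d_hidden * d_hidden, d)(outer.flatten(-2))   # (b, 128, 128, 256)
print(pair_update.shape)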

The other new module class that appears here is TriangleMultiplicativeModule:

class TriangleMultiplicativeModule(nn.Module):
    def __init__(
        self,
        *,
        dim,
        hidden_dim = None,
        mix = 'ingoing'
    ):
        super().__init__()
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        hidden_dim = default(hidden_dim, dim)
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)

        self.left_gate = nn.Linear(dim, hidden_dim)
        self.right_gate = nn.Linear(dim, hidden_dim)
        self.out_gate = nn.Linear(dim, hidden_dim)

        # initialize all gating to be identity

        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'
        if exists(mask):
            mask = rearrange(mask, 'b i j -> b i j ()')

        x = self.norm(x)
        # linear projections of the pair features
        left = self.left_proj(x)
        right = self.right_proj(x)

        if exists(mask):
            left = left * mask
            right = right * mask
        # forget-gate style gating
        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()
        # apply the gates
        left = left * left_gate
        right = right * right_gate
        # multiply along rows or columns (the triangle update)
        out = einsum(self.mix_einsum_eq, left, right)
        # layernorm
        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
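
The two einsum equations are the heart of this module. A small sketch of what they compute: for the outgoing case, the update for edge (i, j) aggregates over the edges (i, k) and (j, k) leaving i and j; for the ingoing case it uses the edges (k, i) and (k, j) arriving at them:

import torch

b, n, d = 1, 128, 64
left = torch.randn(b, n, n, d)     # gated left projection of the pair features
right = torch.randn(b, n, n, d)    # gated right projection of the pair features

# outgoing: edge (i, j) is updated from edges (i, k) and (j, k)
outgoing = torch.einsum('b i k d, b j k d -> b i j d', left, right)
# ingoing: edge (i, j) is updated from edges (k, j) and (k, i)
ingoing = torch.einsum('b k j d, b k i d -> b i j d', left, right)
print(outgoing.shape, ingoing.shape)   # both (1, 128, 128, 64)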
