<<AlphaFold2 Series>>
alphaFold2 | Problem Setting and Background (1)
alphaFold2 | Building the Model Framework (2)
alphaFold2 | Model Details: Feature Extraction (3)
- Reposted from the WeChat public account: 机器学习炼丹术
- Author: 陈亦新 (discussion and exchange welcome)
- Evoformer
- 1.1 FeedForward
- 1.3 MsaAttentionBlock
- 1.4 gating
- 1.5 PairwiseAttentionBlock
The previous article covered how the MSA and pair representation features are constructed; now let's look at the model structure.
Evoformer
x, m = self.net(
x,
m,
mask = x_mask,
msa_mask = msa_mask
)
Here x is our pair feature and m is our MSA feature. The net is defined as:
self.net = Evoformer(
dim = dim,
depth = depth,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
)
OK. You can see that this net is exactly the Evoformer part we are looking for.
class Evoformer(nn.Module):
def __init__(
self,
*,
depth,
**kwargs
):
super().__init__()
self.layers = nn.ModuleList([EvoformerBlock(**kwargs) for _ in range(depth)])
def forward(
self,
x,
m,
mask = None,
msa_mask = None
):
inp = (x, m, mask, msa_mask)
x, m, *_ = checkpoint_sequential(self.layers, 1, inp)
return x, m
As you can see, this Evoformer class is very short. It is really just a shell; the core logic lives in the definition of the EvoformerBlock component.
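Ignoring the memory savings from gradient checkpointing, what this wrapper computes is just a plain loop over the blocks, with the 4-tuple (x, m, mask, msa_mask) threaded through. A minimal sketch of that equivalent loop (the function name below is made up for illustration):
def evoformer_forward_plain(layers, x, m, mask=None, msa_mask=None):
    # equivalent to the forward above, minus checkpointing:
    # each EvoformerBlock consumes and returns the same 4-tuple
    inp = (x, m, mask, msa_mask)
    for block in layers:          # layers = the nn.ModuleList of EvoformerBlock
        inp = block(inp)
    x, m, *_ = inp
    return x, m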
class EvoformerBlock(nn.Module):
def __init__(
self,
*,
dim,
seq_len,
heads,
dim_head,
attn_dropout,
ff_dropout,
global_column_attn = False
):
super().__init__()
self.layer = nn.ModuleList([
PairwiseAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, global_column_attn = global_column_attn),
FeedForward(dim = dim, dropout = ff_dropout),
MsaAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim = dim, dropout = ff_dropout),
])
def forward(self, inputs):
x, m, mask, msa_mask = inputs
attn, ff, msa_attn, msa_ff = self.layer
# msa attention and transition
m = msa_attn(m, mask = msa_mask, pairwise_repr = x)
        m = msa_ff(m) + m
# pairwise attention and transition
x = attn(x, mask = mask, msa_repr = m, msa_mask = msa_mask)
        x = ff(x) + x
return x, m, mask, msa_mask
分成这三个组件:
PairwiseAttentionBlock
FeedForward
MsaAttentionBlock
These correspond to the Evoformer block diagram in the paper's supplementary material.
1.1 FeedForward
class FeedForward(nn.Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
init_zero_(self.net[-1])
def forward(self, x, **kwargs):
x = self.norm(x)
return self.net(x)
This module is just a LayerNorm followed by fully connected layers. GEGLU, however, is a newer activation function; let's take a look:
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
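So the first Linear expands dim to dim * mult * 2, GEGLU splits that in half into a value part and a gate part and multiplies the value by GELU(gate), and the remaining dim * mult features are projected back to dim by the last Linear. A quick shape check (the dim and mult values here are purely illustrative):
import torch
import torch.nn.functional as F

dim, mult = 256, 4                        # illustrative values
h = torch.randn(1, 128, dim * mult * 2)   # output of the first Linear: (b, n, dim*mult*2)

value, gates = h.chunk(2, dim=-1)         # two halves of size dim*mult each
out = value * F.gelu(gates)               # GEGLU
print(out.shape)                          # torch.Size([1, 128, 1024]) == (b, n, dim*mult)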
1.3 MsaAttentionBlock
class MsaAttentionBlock(nn.Module):
def __init__(
self,
dim,
seq_len,
heads,
dim_head,
dropout = 0.
):
super().__init__()
self.row_attn = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = True, col_attn = False, accept_edges = True)
self.col_attn = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = False, col_attn = True)
def forward(
self,
x,
mask = None,
pairwise_repr = None
):
        # residual connections are added as well
        x = self.row_attn(x, mask = mask, edges = pairwise_repr) + x
        x = self.col_attn(x, mask = mask) + x
return x
This component corresponds to the MSA attention part of the paper, which splits into:
- Row-wise gated self-attention with pair bias
- Column-wise gated self-attention
class AxialAttention(nn.Module):
def __init__(
self,
dim,
heads,
row_attn = True,
col_attn = True,
accept_edges = False,
global_query_attn = False,
**kwargs
):
super().__init__()
assert not (not row_attn and not col_attn), 'row or column attention must be turned on'
self.row_attn = row_attn
self.col_attn = col_attn
self.global_query_attn = global_query_attn
self.norm = nn.LayerNorm(dim)
self.attn = Attention(dim = dim, heads = heads, **kwargs)
self.edges_to_attn_bias = nn.Sequential(
nn.Linear(dim, heads, bias = False),
Rearrange('b i j h -> b h i j')
) if accept_edges else None
def forward(self, x, edges = None, mask = None):
assert self.row_attn ^ self.col_attn, 'has to be either row or column attention, but not both'
        # x is the MSA feature we extracted; edges is the pair representation feature
        # b is the batch size, h is the number of MSA sequences (1 target plus 5 matched from the databases)
        # w is 128, i.e. the length of the amino-acid sequence, and d is the token feature dimension, 256
b, h, w, d = x.shape
x = self.norm(x)
# axial attention
if self.col_attn:
axial_dim = w
mask_fold_axial_eq = 'b h w -> (b w) h'
input_fold_eq = 'b h w d -> (b w) h d'
output_fold_eq = '(b w) h d -> b h w d'
elif self.row_attn:
axial_dim = h
mask_fold_axial_eq = 'b h w -> (b h) w'
input_fold_eq = 'b h w d -> (b h) w d'
output_fold_eq = '(b h) w d -> b h w d'
x = rearrange(x, input_fold_eq)
if exists(mask):
mask = rearrange(mask, mask_fold_axial_eq)
attn_bias = None
if exists(self.edges_to_attn_bias) and exists(edges):
attn_bias = self.edges_to_attn_bias(edges)
attn_bias = repeat(attn_bias, 'b h i j -> (b x) h i j', x = axial_dim)
tie_dim = axial_dim if self.global_query_attn else None
out = self.attn(x, mask = mask, attn_bias = attn_bias, tie_dim = tie_dim)
out = rearrange(out, output_fold_eq, h = h, w = w)
return out
As you can see, self-attention over rows and over columns is essentially the same operation (a small shape sketch follows this list):
- For columns: x is first rearranged into the shape (b w) h d, so the self-attention runs along the h dimension. h is the number of MSA sequences, so the resulting self-attention map is the correlation matrix between the aligned sequences.
- For rows: likewise, the self-attention runs along the w dimension, so the attention map measures the correlation between individual amino acids; with a sequence length of 128 this is a 128x128 attention map.
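This folding of one axis into the batch is what lets a single Attention module serve both directions. A minimal einops sketch with illustrative shapes (b = 1, h = 6 MSA sequences, w = 128 residues, d = 256):
import torch
from einops import rearrange

b, h, w, d = 1, 6, 128, 256          # illustrative shapes
x = torch.randn(b, h, w, d)          # MSA representation

# row attention: fold the MSA axis h into the batch and attend over the w residues
rows = rearrange(x, 'b h w d -> (b h) w d')   # (6, 128, 256) -> attention map is 128 x 128

# column attention: fold the residue axis w into the batch and attend over the h sequences
cols = rearrange(x, 'b h w d -> (b w) h d')   # (128, 6, 256) -> attention map is 6 x 6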
Row attention additionally brings in edges, i.e. the pair representation feature. The incoming edges tensor has shape (batch, 128, 128, 256) and first goes through self.edges_to_attn_bias, which is simply a fully connected layer: a Linear from dim down to heads, followed by a Rearrange.
With heads = 8, the output therefore has shape (batch, 128, 128, 8), and the Rearrange turns it into (batch, 8, 128, 128).
After the repeat operation the shape becomes (axial_dim, 8, 128, 128); the batch dimension of 1 is absorbed into the first axis.
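Putting those shapes together in a small sketch (the batch, sequence length, dim, heads and axial_dim values are illustrative; for row attention axial_dim is the number of MSA rows, so every folded row shares the same bias):
import torch
import torch.nn as nn
from einops import rearrange, repeat

b, n, dim, heads, axial_dim = 1, 128, 256, 8, 6   # illustrative values

edges = torch.randn(b, n, n, dim)                 # pair representation: (1, 128, 128, 256)
to_bias = nn.Linear(dim, heads, bias=False)       # the Linear inside edges_to_attn_bias

bias = to_bias(edges)                             # (1, 128, 128, 8)
bias = rearrange(bias, 'b i j h -> b h i j')      # (1, 8, 128, 128)
bias = repeat(bias, 'b h i j -> (b x) h i j', x=axial_dim)   # (6, 8, 128, 128)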
Finally, the MSA features, together with the corresponding attn_bias, are passed into the attn module. Let's see how this Attention class is implemented:
class Attention(nn.Module):
def __init__(
self,
dim,
seq_len = None,
heads = 8,
dim_head = 64,
dropout = 0.,
gating = True
):
super().__init__()
inner_dim = dim_head * heads
self.seq_len = seq_len
self.heads= heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.gating = nn.Linear(dim, inner_dim)
nn.init.constant_(self.gating.weight, 0.)
nn.init.constant_(self.gating.bias, 1.)
self.dropout = nn.Dropout(dropout)
init_zero_(self.to_out)
def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)
context = default(context, x)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
i, j = q.shape[-2], k.shape[-2]
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# query / key similarities
if exists(tie_dim):
# as in the paper, for the extra MSAs
# they average the queries along the rows of the MSAs
# they named this particular module MSAColumnGlobalAttention
q, k = map(lambda t: rearrange(t, '(b r) ... -> b r ...', r = tie_dim), (q, k))
q = q.mean(dim = 1)
dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
dots = rearrange(dots, 'b r ... -> (b r) ...')
else:
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# add attention bias, if supplied (for pairwise to msa attention communication)
if exists(attn_bias):
            dots = dots + attn_bias
# masking
if exists(mask):
mask = default(mask, lambda: torch.ones(1, i, device = device).bool())
context_mask = mask if not has_context else default(context_mask, lambda: torch.ones(1, k.shape[-2], device = device).bool())
mask_value = -torch.finfo(dots.dtype).max
mask = mask[:, None, :, None] * context_mask[:, None, None, :]
dots = dots.masked_fill(~mask, mask_value)
# attention
dots = dots - dots.max(dim = -1, keepdims = True).values
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# gating
gates = self.gating(x)
out = out * gates.sigmoid()
# combine to out
out = self.to_out(out)
return out
In the forward function, q, k and v are produced first. With heads = 8 and dim_head = 64, each projection yields 8 heads of 64 features (inner_dim = 512). The rearrange then splits the last dimension of q, k and v from 8x64 into two separate dimensions, heads and dim_head.
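A small sketch of that head split and of the q·k logits it feeds into (the sizes are illustrative; the leading dimension is the batch with the folded axis already merged in):
import torch
from einops import rearrange

bh, n, heads, dim_head = 6, 128, 8, 64        # illustrative sizes
q = torch.randn(bh, n, heads * dim_head)      # output of to_q: (b, n, inner_dim = 512)
k = torch.randn(bh, n, heads * dim_head)

q, k = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in (q, k))   # (6, 8, 128, 64)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)                 # (6, 8, 128, 128)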
So how is the (128, 128, 8) pair representation feature added in? It is simple: in row-wise attention, the dot product between q and k already produces an attention map of size (128, 128), and the 8 matches the 8 heads exactly, so the two can be fused by this operation:
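A minimal sketch of that fusion step (shapes are illustrative; this is just the bias addition from the forward above written out):
import torch

bh, heads, n = 6, 8, 128                   # illustrative: 6 folded MSA rows, 8 heads, 128 residues
dots = torch.randn(bh, heads, n, n)        # q·k attention logits from row-wise attention
attn_bias = torch.randn(bh, heads, n, n)   # pair bias produced by edges_to_attn_bias

dots = dots + attn_bias                    # the pair representation biases the MSA attention
attn = dots.softmax(dim=-1)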
1.4 gating
The content above matches Figure 2 in the supplementary material of the AlphaFold2 paper: the MSA information is turned into an attention map, the pair representation feature is added to it, and the values are then multiplied with this attention map, which now fuses pair and MSA information, to produce new values. After that comes the gating operation.
It is very simple: just the sigmoid gating shown in the Attention class above.
My feeling is that the gate is introduced because AlphaFold2 loops four times (recycling), so something like the forget gate in RNNs is brought in?
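One detail worth noting: the gate Linear in Attention is initialized with zero weights and a bias of 1, so at initialization the gate is a constant sigmoid(1) ≈ 0.73 for every position, i.e. it starts out almost fully open and only learns to close selectively during training. A quick check (sizes are illustrative):
import torch
import torch.nn as nn

dim, inner_dim = 256, 512                  # illustrative sizes
gating = nn.Linear(dim, inner_dim)
nn.init.constant_(gating.weight, 0.)       # same initialization as in the Attention class above
nn.init.constant_(gating.bias, 1.)

x = torch.randn(2, 128, dim)
gates = gating(x).sigmoid()
print(gates.min().item(), gates.max().item())   # both ~0.731, independent of x at init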
1.5 PairwiseAttentionBlock
So far, we are still missing the code for this part:
class PairwiseAttentionBlock(nn.Module):
def __init__(
self,
dim,
seq_len,
heads,
dim_head,
dropout = 0.,
global_column_attn = False
):
super().__init__()
self.outer_mean = OuterMean(dim)
self.triangle_attention_outgoing = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = True, col_attn = False, accept_edges = True)
self.triangle_attention_ingoing = AxialAttention(dim = dim, heads = heads, dim_head = dim_head, row_attn = False, col_attn = True, accept_edges = True, global_query_attn = global_column_attn)
self.triangle_multiply_outgoing = TriangleMultiplicativeModule(dim = dim, mix = 'outgoing')
self.triangle_multiply_ingoing = TriangleMultiplicativeModule(dim = dim, mix = 'ingoing')
def forward(
self,
x,
mask = None,
msa_repr = None,
msa_mask = None
):
if exists(msa_repr):
            x = x + self.outer_mean(msa_repr, mask = msa_mask)
        x = self.triangle_multiply_outgoing(x, mask = mask) + x
        x = self.triangle_multiply_ingoing(x, mask = mask) + x
        x = self.triangle_attention_outgoing(x, edges = x, mask = mask) + x
        x = self.triangle_attention_ingoing(x, edges = x, mask = mask) + x
return x
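The forward above also calls self.outer_mean, an OuterMean module whose code is not shown in this article. In the AlphaFold2 paper this step is the "outer product mean": the MSA representation is projected twice, the outer product of the two projections is averaged over the MSA sequences, and the result becomes an additive update for the pair representation. A rough sketch of that idea only (this is not the repo's actual OuterMean implementation):
import torch
import torch.nn as nn
from einops import rearrange

class OuterMeanSketch(nn.Module):
    # illustrative sketch of the outer-product-mean idea; the real OuterMean may differ
    def __init__(self, dim, hidden_dim=32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)
        self.to_pair = nn.Linear(hidden_dim * hidden_dim, dim)

    def forward(self, msa):                           # msa: (b, s, n, d)
        msa = self.norm(msa)
        left, right = self.left_proj(msa), self.right_proj(msa)    # (b, s, n, h)
        # outer product between positions i and j, averaged over the s MSA sequences
        outer = torch.einsum('b s i u, b s j v -> b i j u v', left, right) / msa.shape[1]
        return self.to_pair(rearrange(outer, 'b i j u v -> b i j (u v)'))  # (b, n, n, dim)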
A new module class appears here: TriangleMultiplicativeModule.
class TriangleMultiplicativeModule(nn.Module):
def __init__(
self,
*,
dim,
hidden_dim = None,
mix = 'ingoing'
):
super().__init__()
assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'
hidden_dim = default(hidden_dim, dim)
self.norm = nn.LayerNorm(dim)
self.left_proj = nn.Linear(dim, hidden_dim)
self.right_proj = nn.Linear(dim, hidden_dim)
self.left_gate = nn.Linear(dim, hidden_dim)
self.right_gate = nn.Linear(dim, hidden_dim)
self.out_gate = nn.Linear(dim, hidden_dim)
# initialize all gating to be identity
for gate in (self.left_gate, self.right_gate, self.out_gate):
nn.init.constant_(gate.weight, 0.)
nn.init.constant_(gate.bias, 1.)
if mix == 'outgoing':
self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
elif mix == 'ingoing':
self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'
self.to_out_norm = nn.LayerNorm(hidden_dim)
self.to_out = nn.Linear(hidden_dim, dim)
def forward(self, x, mask = None):
assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'
if exists(mask):
mask = rearrange(mask, 'b i j -> b i j ()')
x = self.norm(x)
        # a linear projection
left = self.left_proj(x)
right = self.right_proj(x)
if exists(mask):
left = left * mask
right = right * mask
        # forget-gate style gates
left_gate = self.left_gate(x).sigmoid()
right_gate = self.right_gate(x).sigmoid()
out_gate = self.out_gate(x).sigmoid()
        # apply the gates
left = left * left_gate
right = right * right_gate
        # multiply along the rows or the columns (the triangle update)
out = einsum(self.mix_einsum_eq, left, right)
# layernorm
out = self.to_out_norm(out)
out = out * out_gate
return self.to_out(out)
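To close, a quick shape check of the triangle multiplication. For mix = 'outgoing' the einsum '... i k d, ... j k d -> ... i j d' combines, for every pair (i, j), the edges i→k and j→k summed over k; for 'ingoing' it combines k→i and k→j. A usage sketch, assuming the class above (and the repo's exists/default helpers) is in scope; the sizes are illustrative:
import torch

n, dim = 128, 256                          # illustrative pair-map size
pair = torch.randn(1, n, n, dim)           # pair representation: (b, i, j, d)
mask = torch.ones(1, n, n).bool()

block = TriangleMultiplicativeModule(dim=dim, mix='outgoing')
out = block(pair, mask=mask)
print(out.shape)                           # torch.Size([1, 128, 128, 256])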