PGL图学习之基于GNN模型新冠疫苗任务[系列九]

2022-11-29 10:29:43 浏览数 (3)

PGL图学习之基于GNN模型新冠疫苗任务系列九

项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1

代码语言:txt复制
# 加载一些需要用到的模块,设置随机数
import json
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import networkx as nx

from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser

seed = 123
np.random.seed(seed)
random.seed(seed)

数据EDA

代码语言:txt复制
# https://www.kaggle.com/c/stanford-covid-vaccine/data
# 加载训练用的数据
df = pd.read_json('../data/data179441/train.json', lines=True)
# 查看一下数据集的内容
sample = df.loc[0]
print(sample)

index                                                                400
id                                                          id_2a7a4496f
sequence               GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
structure              .....(((...)))((((((((((((((((((((.((((....)))...
predicted_loop_type    EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
signal_to_noise                                                        0
SN_filter                                                              0
seq_length                                                           107
seq_scored                                                            68
reactivity_error       [146151.225, 146151.225, 146151.225, 146151.22...
deg_error_Mg_pH10      [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10         [222620.9531, 222620.9531, 222620.9531, 222620...
deg_error_Mg_50C       [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_50C          [191738.0886, 191738.0886, 191738.0886, 191738...
reactivity             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_pH10               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_50C                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Name: 0, dtype: object

例如 deg_50C、deg_Mg_50C 这样的值全为0的行,就是我们需要预测的。

structure一行,数据中的括号是为了构成边用的。

本案例要预测RNA序列不同位置的降解速率,训练数据中提供了多个ground值,标签包括以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50

reactivity - (1x68 vector 训练集,1x91测试集) 一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定RNA样本可能的二级结构。

deg_Mg_pH10 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高pH (pH 10)下的降解可能性。

deg_Mg_50 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高温(50摄氏度)下的降解可能性。

代码语言:txt复制
# 利用GraphParser构造图结构的数据
args = prepare_config("./config.yaml", isCreate=False, isSave=False)
parser = GraphParser(args) # GraphParser类来自data_parser.py
gdata = parser.parse(sample) # GraphParser里最主要的函数就是parse(self, sample)
代码语言:txt复制
{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 'edges': array([[  0,   1],
        [  1,   0],
        [  1,   2],
        ...,
        [142, 105],
        [106, 142],
        [142, 106]]),
 'efeat': array([[ 0.,  0.,  0.,  1.,  1.],
        [ 0.,  0.,  0., -1.,  1.],
        [ 0.,  0.,  0.,  1.,  1.],
        ...,
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.]], dtype=float32),
 'labels': array([[ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        ...,
        [ 0.    ,  0.9213,  0.    ],
        [ 6.8894,  3.5097,  5.7754],
        [ 0.    ,  1.8426,  6.0642],
          ...,        
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]], dtype=float32),
 'mask': array([[ True],
        [ True],
     ......
       [False]])}

nfeat —— 节点特征

edges —— 边

efeat —— 边特征

labels —— 节点标签有三种,所以这可以看成是一个多分类任务

图数据可视化

代码语言:txt复制
# 图数据可视化
fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])

nx_G.add_edges_from(gdata['edges'])
node_color = ['g' for _ in range(sample['seq_length'])]   
['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {
    "node_color": node_color,
}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)

plt.show()
Snipaste_2022-11-25_20-53-23.jpgSnipaste_2022-11-25_20-53-23.jpg

模型训练&预测

代码语言:txt复制
# 我们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中加入了边的特征(edge_feat)
# 然后修改 model.py 里的 GNNModel
# 使用修改后的模型,运行 main.py。为节省时间,设置 epochs = 100

!python main.py --config config.yaml

结果返回的是 MCRMSE 和 loss

{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}

代码语言:txt复制
[DEBUG] 2022-11-25 17:50:42,468 [  trainer.py:   66]:	{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}
[DEBUG] 2022-11-25 17:50:42,468 [  trainer.py:   73]:	write to tensorboard ../checkpoints/covid19/eval_history/eval
[DEBUG] 2022-11-25 17:50:42,469 [  trainer.py:   73]:	write to tensorboard ../checkpoints/covid19/eval_history/eval
[INFO] 2022-11-25 17:50:42,469 [  trainer.py:   76]:	[Eval:eval]:MCRMSE:0.5496758818626404	loss:0.3025484172316889
[INFO] 2022-11-25 17:50:42,602 [monitored_executor.py:  606]:	********** Stop Loop ************
[DEBUG] 2022-11-25 17:50:42,607 [monitored_executor.py:  199]:	saving step 12500 to ../checkpoints/covid19/model_12500
代码语言:txt复制
!python main.py --mode infer

0 人点赞