[Paper Reproduction] Clustering of Single-Cell Multi-Omics Data with a Multimodal Deep Learning Method ([Bioinformatics] Experiment 2: Multi-Omics Data Fusion: scMDC)

2024-07-30 08:40:50

I. Experiment Introduction

1. Paper: Clustering of single-cell multi-omics data with a multimodal deep learning method

Clustering of single-cell multi-omics data with a multimodal deep learning method | Nature Communications

https://www.nature.com/articles/s41467-022-35031-9

Abstract

Single-cell multimodal sequencing technologies are developed to simultaneously profile different modalities of data in the same cell. It provides a unique opportunity to jointly analyze multimodal data at the single-cell level for the identification of distinct cell types. A correct clustering result is essential for the downstream complex biological functional studies. However, combining different data sources for clustering analysis of single-cell multimodal data remains a statistical and computational challenge. Here, we develop a novel multimodal deep learning method, scMDC, for single-cell multi-omics data clustering analysis. scMDC is an end-to-end deep model that explicitly characterizes different data sources and jointly learns latent features of deep embedding for clustering analysis. Extensive simulation and real-data experiments reveal that scMDC outperforms existing single-cell single-modal and multimodal clustering methods on different single-cell multimodal datasets. The linear scalability of running time makes scMDC a promising method for analyzing large multimodal datasets.

2. GitHub Link

GitHub - xianglin226/scMDC: Single-cell Multi-omics Deep Clustering

https://github.com/xianglin226/scMDC

II. Experiment Environment

0. Author's Requirements

Python 3.8.1

Pytorch 1.6.0

Scanpy 1.6.0

SKlearn 0.22.1

Numpy 1.18.1

h5py 2.9.0

  • All scMDC experiments in this study were run on an Nvidia Tesla P100 (16 GB) GPU.
  • We recommend creating a conda environment and installing scMDC inside it (conda create -n scMDC).
  • scMDC takes about 3 minutes to cluster a dataset of 5,000 cells.

1. Environment Reproduction

The environment was not reconfigured; the one from the earlier posts in this deep-learning series was reused, which proved to work in practice:

conda create -n DL python=3.7
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
conda install scikit-learn
Experiment 1
conda install pandas
conda install seaborn
conda install networkx
conda install statsmodels
pip install pyHSICLasso

Note: the libraries above were installed one at a time, in the listed order. Installing them all in one go has not been tested and may or may not cause dependency problems.

Experiment 2 (this experiment)
pip install scanpy

2. Library Versions

Versions in the self-configured Python 3.7 environment:

Package         This experiment    Latest (at time of writing)
matplotlib      3.5.3              3.8.0
numpy           1.21.6             1.26.0
python          3.7.16             -
scikit-learn    0.22.1             1.3.0
torch           1.8.1+cu102        2.0.1
torchaudio      0.8.1              2.0.2
torchvision     0.9.1+cu102        0.15.2

Experiment 1

Package         This experiment    Latest (at time of writing)
networkx        2.6.3              3.1
pandas          1.2.3              2.1.1
pyHSICLasso     1.4.2              1.4.2
seaborn         0.12.2             0.13.0
statsmodels     0.13.5             0.14.0

Experiment 2
Installing collected packages: stdlib_list, natsort, llvmlite, h5py, session-info, numba, pynndescent, anndata, umap-learn, scanpy
Successfully installed anndata-0.8.0 h5py-3.8.0 llvmlite-0.39.1 natsort-8.4.0 numba-0.56.4 pynndescent-0.5.10 scanpy-1.9.3 session-info-1.0.0 stdlib_list-0.9.0 umap-learn-0.5.4

3. IDE

PyCharm is recommended:

Win11: installing Anaconda (2022.10) and PyCharm (2022.3/2023.1.4) and configuring a virtual environment (QomolangmaH's CSDN blog): https://blog.csdn.net/m0_63834988/article/details/128693741

III. Experiment Content

1. Usage

  • Prepare the input data in h5 format (see the README in the "datasets" folder); a minimal sketch of such a file is given after this list.
  • Run scMDC following the run scripts in the "scripts" folder (mind the parameter settings if you work with mRNA+ATAC data, and use run_scMDC_batch.py for multi-batch data clustering).
  • Run the DE analysis with run_LRP.py based on a well-trained scMDC model (see the LRP run script in the "scripts" folder).
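For reference, here is a minimal sketch of how such an h5 input could be assembled with h5py. The keys X1, X2 and Y follow the parameter description in the next section; the toy arrays and the file name toy_input.h5 are made up for illustration, and the real datasets in the repository may carry extra fields such as Batch.

import h5py
import numpy as np

# purely illustrative toy counts: 100 cells, 500 genes (X1) and 50 ADT/ATAC features (X2)
x1 = np.random.poisson(1.0, size=(100, 500)).astype(np.float64)
x2 = np.random.poisson(1.0, size=(100, 50)).astype(np.float64)
y = np.random.randint(0, 5, size=100)  # cell-type labels, if available

with h5py.File("toy_input.h5", "w") as f:
    f.create_dataset("X1", data=x1)  # RNA counts
    f.create_dataset("X2", data=x2)  # ADT or ATAC counts
    f.create_dataset("Y", data=y)    # labels (optional)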

2. Output

  • scMDC outputs a latent representation of the data, which can be used for further downstream analyses and visualized with t-SNE or UMAP;
  • Multi-batch scMDC outputs a latent representation of the integrated datasets on which the batch effect has been corrected;
  • LRP outputs a gene rank indicating the importance of each gene for a given cluster, which can be used for pathway analysis.

3. Parameters

  • --n_clusters: number of clusters (K); if set to -1, scMDC will estimate K.
  • --cutoff: the fraction of epochs before which the model only trains the low-level autoencoders.
  • --batch_size: batch size.
  • --data_file: path to the input data.
    • Data format: h5.
    • Structure: X1 (RNA), X2 (ADT or ATAC), Y (labels, if they exist), Batch (batch indicator for multi-batch data clustering).
  • --maxiter: maximum number of training epochs. Default: 10000.
  • --pretrain_epochs: number of pretraining epochs. Default: 400.
  • --gamma: coefficient of the clustering loss. Default: 0.1.
  • --phi1 and --phi2: coefficients of the KL loss in the pretraining and clustering stages. Default: 0.001 for CITE-Seq; 0.005 for SMAGE-Seq*.
  • --update_interval: the interval at which performance is checked. Default: 1.
  • --tol: the stopping criterion, i.e. the percentage of cells whose labels change. Default: 0.001.
  • --ae_weights: path to the weight file.
  • --save_dir: directory in which the outputs are stored.
  • --ae_weight_file: file in which the autoencoder weights are stored.
  • --resolution: resolution parameter used to estimate K. Default: 0.2.
  • --n_neighbors: n_neighbors parameter used to estimate K. Default: 30.
  • --embedding_file: whether to save the embedding file. Default: No.
  • --prediction_file: whether to save the prediction file. Default: No.
  • --encodeLayer: layers of the low-level encoder for RNA. Default: [256,64,32,16] for CITE-Seq; [256,128,64] for SMAGE-seq.
  • --decodeLayer1: layers of the decoder for the first modality (mRNA). Default: [16,64,256] for CITE-Seq; [64,128,256] for SMAGE-seq.
  • --decodeLayer2: layers of the decoder for the second modality (ADT/ATAC). Default: [16,20] for CITE-Seq; [64,128,256] for SMAGE-seq.
  • --sigma1: noise added to the RNA data. Default: 2.5.
  • --sigma2: noise added to the ADT/ATAC data. Default: 1.5 for CITE-Seq; 2.5 for SMAGE-Seq.
  • --filter1: whether to perform feature selection on the genes (X1). Default: No.
  • --filter2: whether to perform feature selection on ATAC (X2). Default: No.
  • --f1: number of highly variable genes (in X1) used for clustering if feature selection is applied. Default: 2000.
  • --f2: number of highly variable features from ATAC (in X2) used for clustering if feature selection is applied. Default: 2000.
  • *For convenience, we denote the 10X single-cell Multiome ATAC + Gene Expression technology as SMAGE-seq.
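For illustration only, a hypothetical command line combining some of these parameters with the values actually used later in this post (see the Namespace printout in section 4) could look like the following; the run scripts in the "scripts" folder remain the authoritative reference.

python run_scMDC.py --data_file ../datasets/output500.h5 --n_clusters 12 \
    --encodeLayer 256 128 64 --decodeLayer1 64 128 256 --decodeLayer2 64 128 256 \
    --phi1 0.005 --phi2 0.005 --sigma1 2.5 --sigma2 2.5 --save_dir atac_pbmc10k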

4. run_scMDC

The scMultiCluster model jointly uses scRNA-seq data and a second modality for unsupervised clustering of cell types, and the clustering performance is then evaluated.

Setting hyperparameters

The data consist of two modalities: the scRNA-seq read-count matrix X1 and another genomic modality X2 (e.g. ADT/ATAC), plus the labels y.

    import argparse

    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--n_clusters', default=12, type=int)
    parser.add_argument('--cutoff', default=0.5, type=float,
                        help='Start to train combined layer after what ratio of epoch')
    parser.add_argument('--batch_size', default=256, type=int)
    # parser.add_argument('--data_file', default=f'../datasets/SMAGESeq_10X_pbmc_10k_granulocyte_plus.h5')
    parser.add_argument('--data_file', default=f'../datasets/output500.h5')
    parser.add_argument('--maxiter', default=5000, type=int)
    parser.add_argument('--pretrain_epochs', default=200, type=int)
    parser.add_argument('--gamma', default=.1, type=float,
                        help='coefficient of clustering loss')
    parser.add_argument('--tau', default=.1, type=float,
                        help='fuzziness of clustering loss')
    parser.add_argument('--phi1', default=0.005, type=float,
                        help='coefficient of KL loss in pretraining stage')
    parser.add_argument('--phi2', default=0.005, type=float,
                        help='coefficient of KL loss in clustering stage')
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--lr', default=1., type=float)
    parser.add_argument('--ae_weights', default=None)
    parser.add_argument('--save_dir', default='atac_pbmc10k')
    parser.add_argument('--ae_weight_file', default='AE_weights_pbmc10k.pth.tar')
    parser.add_argument('--resolution', default=0.2, type=float)
    parser.add_argument('--n_neighbors', default=30, type=int)
    parser.add_argument('--embedding_file', default=True)
    parser.add_argument('--prediction_file', default=True)
    parser.add_argument('-el', '--encodeLayer', nargs='+', default=[256, 128, 64])
    parser.add_argument('-dl1', '--decodeLayer1', nargs='+', default=[64, 128, 256])
    parser.add_argument('-dl2', '--decodeLayer2', nargs='+', default=[64, 128, 256])
    parser.add_argument('--sigma1', default=2.5, type=float)
    parser.add_argument('--sigma2', default=2.5, type=float)
    parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection')
    parser.add_argument('--f2', default=1000, type=float, help='Number of ADT/ATAC after feature selection')
    parser.add_argument('--filter1', default=True, help='Do mRNA selection')
    parser.add_argument('--filter2', default=True, help='Do ADT/ATAC selection')
    parser.add_argument('--run', default=1, type=int)
    parser.add_argument('--device', default='cpu')
    args = parser.parse_args()
    print(args)

Output:

Namespace(ae_weight_file='AE_weights_pbmc10k.pth.tar', ae_weights=None, batch_size=256, cutoff=0.5, data_file='../datasets/output500.h5', decodeLayer1=[64, 128, 256], decodeLayer2=[64, 128, 256], device='cpu', embedding_file=True, encodeLayer=[256, 128, 64], f1=1000, f2=1000, filter1=True, filter2=True, gamma=0.1, lr=1.0, maxiter=5000, n_clusters=12, n_neighbors=30, phi1=0.005, phi2=0.005, prediction_file=True, pretrain_epochs=200, resolution=0.2, run=1, save_dir='atac_pbmc10k', sigma1=2.5, sigma2=2.5, tau=0.1, tol=0.001, update_interval=1)
Preprocessing the scRNA-seq data

Size-factor normalization, log transformation, scaling, etc. (batch-effect correction is handled by the multi-batch variant).
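The read_dataset and normalize calls below are helper functions from the scMDC repository. As a rough, simplified sketch of the standard Scanpy steps they roughly correspond to (size factors, library-size normalization, log1p, scaling), assuming a raw count matrix x1; use the repository's own functions for the actual reproduction, since their details differ.

import numpy as np
import scanpy as sc

adata = sc.AnnData(x1)
adata.raw = adata.copy()                                  # keep raw counts for the ZINB loss
counts_per_cell = np.asarray(adata.X.sum(axis=1)).ravel()
adata.obs['size_factors'] = counts_per_cell / np.median(counts_per_cell)
sc.pp.normalize_total(adata)                              # library-size normalization
sc.pp.log1p(adata)                                        # log(1 + x) transform
sc.pp.scale(adata)                                        # zero mean, unit variance per gene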

    data_mat = h5py.File(args.data_file)
    x1 = np.array(data_mat['X1'])
    x2 = np.array(data_mat['X2'])
    y = np.array(data_mat['Y'])
    data_mat.close()

    # #Gene filter
    # if args.filter1:
    #     importantGenes = geneSelection(x1, n=args.f1, plot=False)
    #     x1 = x1[:, importantGenes]
    # if args.filter2:
    #     importantGenes = geneSelection(x2, n=args.f2, plot=False)
    #     x2 = x2[:, importantGenes]

    print('------------------------------------')
    print(x1.shape)
    print(x2.shape)

    # preprocessing scRNA-seq read counts matrix
    adata1 = sc.AnnData(x1)
    adata1.obs['Group'] = y

    adata1 = read_dataset(adata1,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata1 = normalize(adata1,
                       size_factors=True,
                       normalize_input=True,
                       filter_min_counts=False,
                       logtrans_input=True,
                       nor=0)

    adata2 = sc.AnnData(x2)
    adata2.obs['Group'] = y
    adata2 = read_dataset(adata2,
                          transpose=False,
                          test_split=False,
                          copy=True)

    adata2 = normalize(adata2,
                       size_factors=True,
                       normalize_input=True,
                       logtrans_input=True)

    # adata2 = clr_normalize_each_cell(adata2)

    # adata1.write_h5ad('adata1.h5ad')  # save adata1 as an h5ad file
    # adata2.write_h5ad('adata2.h5ad')  # save adata2 as an h5ad file
    # print('------------------------------------')
    # print(adata1.X.shape)
    # print(adata2.X.shape)

    input_size1 = adata1.n_vars
    input_size2 = adata2.n_vars

    print(args)

Output:

(11020, 500)
(11020, 500)
Namespace(ae_weight_file='AE_weights_pbmc10k.pth.tar', ae_weights=None, batch_size=256, cutoff=0.5, data_file='../datasets/output500.h5', decodeLayer1=[64, 128, 256], decodeLayer2=[64, 128, 256], device='cpu', embedding_file=True, encodeLayer=[256, 128, 64], f1=1000, f2=1000, filter1=True, filter2=True, gamma=0.1, lr=1.0, maxiter=5000, n_clusters=12, n_neighbors=30, phi1=0.005, phi2=0.005, prediction_file=True, pretrain_epochs=200, resolution=0.2, run=1, save_dir='atac_pbmc10k', sigma1=2.5, sigma2=2.5, tau=0.1, tol=0.001, update_interval=1)
Building the scMultiCluster model
  • The model consists of one shared encoder and two modality-specific decoders: the encoder connects the data from the different modalities, and the two decoders separately reconstruct the data of each modality.
  • The ZINB loss is used as the reconstruction loss because it characterizes highly sparse data with many dropouts well; in practice it has been found to fit scRNA-seq data effectively and to improve representation learning and clustering (a minimal sketch of a ZINB negative log-likelihood is given after this list).
  • The KL loss attracts similar cells and separates dissimilar cells, which helps to obtain better clusters.
  • Adding the clustering loss optimizes the clustering result directly, so as to achieve the best possible clustering.
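For intuition, here is a minimal sketch of a zero-inflated negative binomial (ZINB) negative log-likelihood with mean mean, dispersion disp and dropout probability pi. It is written from the standard ZINB formulation, not copied from the repository's ZINBLoss, so details (weighting, ridge terms, masking) may differ.

import torch

def zinb_nll(x, mean, disp, pi, eps=1e-10):
    # negative binomial NLL in the mean/dispersion parameterization
    t1 = torch.lgamma(disp + eps) + torch.lgamma(x + 1.0) - torch.lgamma(x + disp + eps)
    t2 = (disp + x) * torch.log(1.0 + mean / (disp + eps)) + x * (torch.log(disp + eps) - torch.log(mean + eps))
    nb_nll = t1 + t2
    # non-zero counts: -log[(1 - pi) * NB(x | mean, disp)]
    nonzero_case = nb_nll - torch.log(1.0 - pi + eps)
    # zero counts: -log[pi + (1 - pi) * NB(0)], where NB(0) = (disp / (disp + mean))^disp
    nb_zero = torch.pow(disp / (disp + mean + eps), disp)
    zero_case = -torch.log(pi + (1.0 - pi) * nb_zero + eps)
    return torch.where(x < 1e-8, zero_case, nonzero_case).mean()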
    encodeLayer = list(map(int, args.encodeLayer))
    decodeLayer1 = list(map(int, args.decodeLayer1))
    decodeLayer2 = list(map(int, args.decodeLayer2))
    
    model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
                        encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
                        activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma, 
                        cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)
    
    print(str(model))

Output:

scMultiCluster(
  (encoder): Sequential(
    (0): Linear(in_features=1000, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ELU(alpha=1.0)
  )
  (decoder1): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
  )
  (decoder2): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
  )
  (dec_mean1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): MeanAct()
  )
  (dec_disp1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): DispAct()
  )
  (dec_mean2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): MeanAct()
  )
  (dec_disp2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): DispAct()
  )
  (dec_pi1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): Sigmoid()
  )
  (dec_pi2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): Sigmoid()
  )
  (zinb_loss): ZINBLoss()
)
Pretraining

At this stage only the ZINB losses and the KL loss are used.

    if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
            
    t0 = time()
    if args.ae_weights is None:
        model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 
                X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, batch_size=args.batch_size, 
                epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file)
    else:
        if os.path.isfile(args.ae_weights):
            print("==> loading checkpoint '{}'".format(args.ae_weights))
            checkpoint = torch.load(args.ae_weights)
            model.load_state_dict(checkpoint['ae_state_dict'])
        else:
            print("==> no checkpoint found at '{}'".format(args.ae_weights))
            raise ValueError
    
    print('Pretraining time: %d seconds.' % int(time() - t0))
Determining the number of clusters k
    #get k
    latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
    latent = latent.cpu().numpy()
    if args.n_clusters == -1:
       n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
    else:
       print("n_cluster is defined as "   str(args.n_clusters))
       n_clusters = args.n_clusters
n_cluster is defined as 12
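When --n_clusters is set to -1, the GetCluster helper estimates K from the latent space instead. Below is a minimal sketch of one common way to do this with Scanpy (graph-based community detection on the latent embedding); the repository's GetCluster may differ in details, and sc.tl.louvain additionally requires the python-louvain package.

import scanpy as sc

# 'latent' is the matrix of latent features returned by model.encodeBatch above
adata_latent = sc.AnnData(latent)
sc.pp.neighbors(adata_latent, n_neighbors=30, use_rep='X')   # args.n_neighbors
sc.tl.louvain(adata_latent, resolution=0.2)                  # args.resolution
n_clusters_est = adata_latent.obs['louvain'].nunique()
print("Estimated number of clusters:", n_clusters_est)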
Fine-tuning the model and training the clustering layer

This is the formal training stage: the clustering loss is now added to the training so that a better clustering result is obtained. A minimal sketch of such a clustering loss is given below, followed by the actual call to fit.
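The clustering loss follows the deep-embedded-clustering (DEC) idea: soft assignments of cells to cluster centers are computed with a Student's t kernel and pulled toward a sharpened target distribution through a KL divergence. The sketch below only illustrates that idea, assuming latent vectors z and cluster centers mu; scMultiCluster.fit in the repository implements its own version, so treat this as an illustration rather than the author's code.

import torch
import torch.nn.functional as F

def soft_assign(z, mu, alpha=1.0):
    # q_ij: Student's t similarity between cell i and cluster center j
    dist = torch.sum((z.unsqueeze(1) - mu) ** 2, dim=2)
    q = (1.0 + dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)

def target_distribution(q):
    # p_ij: sharpened targets that emphasize high-confidence assignments
    weight = q ** 2 / q.sum(dim=0)
    return (weight.t() / weight.sum(dim=1)).t()

def clustering_loss(z, mu):
    q = soft_assign(z, mu)
    p = target_distribution(q).detach()
    return F.kl_div(q.log(), p, reduction='batchmean')  # KL(P || Q)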

    y_pred, _, _, _, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                                   X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=y,
                                   n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
                                   update_interval=args.update_interval, tol=args.tol, lr=args.lr,
                                   save_dir=args.save_dir)
    print('Total time: %d seconds.' % int(time() - t0))
Saving the predictions and the embedding to files
    if args.prediction_file:
       y_pred_ = best_map(y, y_pred) - 1
       np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred_, delimiter=",")
    
    if args.embedding_file:
       final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
       final_latent = final_latent.cpu().numpy()
       np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",")
Outputting the predictions y_pred and computing the AMI, NMI and ARI metrics
    y_pred_ = best_map(y, y_pred)
    ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
    nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
    ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
    print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))

Experiment result:

Final: AMI= 0.6565, NMI= 0.6575, ARI= 0.6813

5. Visualization of the Clustering Results

import umap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

umap_model = umap.UMAP()
# read the CSV files
encoded_data = pd.read_csv('./atac_pbmc10k/1_embedding.csv', header=None)
cluster_labels = pd.read_csv('./atac_pbmc10k/1_pred.csv', header=None)
unique_labels = np.unique(cluster_labels.astype(int))

# apply UMAP to the encoded representation
umap_data = umap_model.fit_transform(encoded_data)
fig, ax = plt.subplots()

# draw a scatter plot of the UMAP result with Matplotlib
for label in unique_labels:
    # indices of the data points belonging to the current cluster
    indices = np.where(cluster_labels == label)[0]

    # UMAP coordinates of the data points in the current cluster
    cluster_points = umap_data[indices, :]

    # plot the current cluster's points in a distinct color
    ax.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {label}')
ax.set_xlabel('UMAP Dimension 1')
ax.set_ylabel('UMAP Dimension 2')
ax.set_title('UMAP Visualization with Cluster Colors')

# show the legend, anchored just outside the axes on the right
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left')
x_min, x_max = np.min(umap_data[:, 0]), np.max(umap_data[:, 0])
y_min, y_max = np.min(umap_data[:, 1]), np.max(umap_data[:, 1])
x_margin = (x_max - x_min) * 0.1
y_margin = (y_max - y_min) * 0.1
plt.xlim(x_min - x_margin, x_max + x_margin)
plt.ylim(y_min - y_margin, y_max + y_margin)
# adjust the subplot layout so the UMAP plot fits the figure
plt.tight_layout()
# save as an image file
plt.savefig('umap_visualization.png')
plt.show()
plt.close()
Visualization result (UMAP scatter plot colored by cluster):

Appendix: run_scMDC training log

### Autoencoder: Successfully preprocessed 500 genes and 11020 cells.
F:ProgrammingPycharmProjectsDeepLearningbioclass2scMDC-masterscMDC-mastersrcrun_scMDC.py:106: FutureWarning: X.dtype being converted to np.float32 from float64. In the next version of anndata (0.9) conversion will not be automatic. Pass dtype explicitly to avoid this warning. Pass `AnnData(X, dtype=X.dtype, ...)` to get the future behavour.
  adata2 = sc.AnnData(x2)
### Autoencoder: Successfully preprocessed 500 genes and 11020 cells.
Namespace(ae_weight_file='AE_weights_pbmc10k.pth.tar', ae_weights=None, batch_size=256, cutoff=0.5, data_file='../datasets/output500.h5', decodeLayer1=[64, 128, 256], decodeLayer2=[64, 128, 256], device='cpu', embedding_file=True, encodeLayer=[256, 128, 64], f1=1000, f2=1000, filter1=True, filter2=True, gamma=0.1, lr=1.0, maxiter=5000, n_clusters=12, n_neighbors=30, phi1=0.005, phi2=0.005, prediction_file=True, pretrain_epochs=200, resolution=0.2, run=1, save_dir='atac_pbmc10k', sigma1=2.5, sigma2=2.5, tau=0.1, tol=0.001, update_interval=1)
scMultiCluster(
  (encoder): Sequential(
    (0): Linear(in_features=1000, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ELU(alpha=1.0)
  )
  (decoder1): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
  )
  (decoder2): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
  )
  (dec_mean1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): MeanAct()
  )
  (dec_disp1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): DispAct()
  )
  (dec_mean2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): MeanAct()
  )
  (dec_disp2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): DispAct()
  )
  (dec_pi1): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): Sigmoid()
  )
  (dec_pi2): Sequential(
    (0): Linear(in_features=256, out_features=500, bias=True)
    (1): Sigmoid()
  )
  (zinb_loss): ZINBLoss()
)
Pretraining stage
Pretrain epoch 1, Total loss:1.782213, ZINB loss1:0.803113, ZINB loss2:0.979099, KL loss:0.000000
Pretrain epoch 11, Total loss:1.526047, ZINB loss1:0.686576, ZINB loss2:0.839471, KL loss:0.000000
Pretrain epoch 21, Total loss:1.516047, ZINB loss1:0.680581, ZINB loss2:0.835466, KL loss:0.000000
Pretrain epoch 31, Total loss:1.513802, ZINB loss1:0.679482, ZINB loss2:0.834320, KL loss:0.000000
Pretrain epoch 41, Total loss:1.510133, ZINB loss1:0.677538, ZINB loss2:0.832595, KL loss:0.000000
Pretrain epoch 51, Total loss:1.509518, ZINB loss1:0.676798, ZINB loss2:0.832721, KL loss:0.000000
Pretrain epoch 61, Total loss:1.505354, ZINB loss1:0.675105, ZINB loss2:0.830248, KL loss:0.000000
Pretrain epoch 71, Total loss:1.503640, ZINB loss1:0.673856, ZINB loss2:0.829784, KL loss:0.000000
Pretrain epoch 81, Total loss:1.501737, ZINB loss1:0.672963, ZINB loss2:0.828775, KL loss:0.000000
Pretrain epoch 91, Total loss:1.500950, ZINB loss1:0.672635, ZINB loss2:0.828315, KL loss:0.000000
Pretrain epoch 101, Total loss:1.501092, ZINB loss1:0.671503, ZINB loss2:0.827678, KL loss:0.382204
Pretrain epoch 111, Total loss:1.500101, ZINB loss1:0.671465, ZINB loss2:0.826861, KL loss:0.355161
Pretrain epoch 121, Total loss:1.500591, ZINB loss1:0.671512, ZINB loss2:0.827387, KL loss:0.338605
Pretrain epoch 131, Total loss:1.500021, ZINB loss1:0.671240, ZINB loss2:0.827154, KL loss:0.325242
Pretrain epoch 141, Total loss:1.497220, ZINB loss1:0.669683, ZINB loss2:0.825961, KL loss:0.315069
Pretrain epoch 151, Total loss:1.496627, ZINB loss1:0.669643, ZINB loss2:0.825458, KL loss:0.305160
Pretrain epoch 161, Total loss:1.495756, ZINB loss1:0.669362, ZINB loss2:0.824900, KL loss:0.298926
Pretrain epoch 171, Total loss:1.495618, ZINB loss1:0.669167, ZINB loss2:0.824976, KL loss:0.295155
Pretrain epoch 181, Total loss:1.495205, ZINB loss1:0.669032, ZINB loss2:0.824735, KL loss:0.287623
Pretrain epoch 191, Total loss:1.495886, ZINB loss1:0.669239, ZINB loss2:0.825226, KL loss:0.284036
Pretraining time: 1238 seconds.
n_cluster is defined as 12
Clustering stage
Initializing cluster centers with kmeans.
Initializing k-means: AMI= 0.6138, NMI= 0.6150, ARI= 0.5139
Clustering   1: AMI= 0.6138, NMI= 0.6150, ARI= 0.5139 ,acc= 0.6455
#Epoch 1: Total: 2.262678 Clustering Loss: 0.513242 ZINB Loss1: 0.666067 ZINB Loss2: 0.822290 KL Loss: 0.261079
Clustering   2: AMI= 0.6151, NMI= 0.6163, ARI= 0.5216 ,acc= 0.6529
Clustering   3: AMI= 0.6148, NMI= 0.6160, ARI= 0.5194 ,acc= 0.6439
Clustering   4: AMI= 0.6161, NMI= 0.6173, ARI= 0.5207 ,acc= 0.6378
Clustering   5: AMI= 0.6162, NMI= 0.6174, ARI= 0.5217 ,acc= 0.6398
Clustering   6: AMI= 0.6169, NMI= 0.6181, ARI= 0.5223 ,acc= 0.6431
Clustering   7: AMI= 0.6165, NMI= 0.6177, ARI= 0.5244 ,acc= 0.6524
Clustering   8: AMI= 0.6174, NMI= 0.6186, ARI= 0.5276 ,acc= 0.6597
Clustering   9: AMI= 0.6172, NMI= 0.6184, ARI= 0.5238 ,acc= 0.6427
Clustering   10: AMI= 0.6177, NMI= 0.6188, ARI= 0.5240 ,acc= 0.6410
Clustering   11: AMI= 0.6174, NMI= 0.6186, ARI= 0.5234 ,acc= 0.6428
#Epoch 11: Total: 1.948029 Clustering Loss: 0.232234 ZINB Loss1: 0.667498 ZINB Loss2: 0.823401 KL Loss: 0.224896
Clustering   12: AMI= 0.6161, NMI= 0.6173, ARI= 0.5236 ,acc= 0.6477
Clustering   13: AMI= 0.6163, NMI= 0.6175, ARI= 0.5221 ,acc= 0.6398
Clustering   14: AMI= 0.6162, NMI= 0.6174, ARI= 0.5246 ,acc= 0.6507
Clustering   15: AMI= 0.6162, NMI= 0.6174, ARI= 0.5252 ,acc= 0.6567
Clustering   16: AMI= 0.6160, NMI= 0.6172, ARI= 0.5253 ,acc= 0.6532
Clustering   17: AMI= 0.6171, NMI= 0.6183, ARI= 0.5364 ,acc= 0.6773
Clustering   18: AMI= 0.6163, NMI= 0.6176, ARI= 0.5235 ,acc= 0.6440
Clustering   19: AMI= 0.6179, NMI= 0.6191, ARI= 0.5318 ,acc= 0.6679
Clustering   20: AMI= 0.6161, NMI= 0.6173, ARI= 0.5235 ,acc= 0.6448
Clustering   21: AMI= 0.6174, NMI= 0.6186, ARI= 0.5292 ,acc= 0.6605
#Epoch 21: Total: 1.904330 Clustering Loss: 0.193283 ZINB Loss1: 0.667779 ZINB Loss2: 0.823706 KL Loss: 0.219561
Clustering   22: AMI= 0.6159, NMI= 0.6171, ARI= 0.5233 ,acc= 0.6420
Clustering   23: AMI= 0.6151, NMI= 0.6163, ARI= 0.5231 ,acc= 0.6491
Clustering   24: AMI= 0.6157, NMI= 0.6169, ARI= 0.5229 ,acc= 0.6378
Clustering   25: AMI= 0.6172, NMI= 0.6184, ARI= 0.5334 ,acc= 0.6711
Clustering   26: AMI= 0.6158, NMI= 0.6170, ARI= 0.5265 ,acc= 0.6700
Clustering   27: AMI= 0.6181, NMI= 0.6193, ARI= 0.5386 ,acc= 0.6788
Clustering   28: AMI= 0.6158, NMI= 0.6170, ARI= 0.5239 ,acc= 0.6397
Clustering   29: AMI= 0.6188, NMI= 0.6200, ARI= 0.5383 ,acc= 0.6785
Clustering   30: AMI= 0.6171, NMI= 0.6183, ARI= 0.5255 ,acc= 0.6495
Clustering   31: AMI= 0.6174, NMI= 0.6186, ARI= 0.5264 ,acc= 0.6532
#Epoch 31: Total: 1.881909 Clustering Loss: 0.174528 ZINB Loss1: 0.667321 ZINB Loss2: 0.823386 KL Loss: 0.216674
Clustering   32: AMI= 0.6159, NMI= 0.6171, ARI= 0.5226 ,acc= 0.6438
Clustering   33: AMI= 0.6182, NMI= 0.6194, ARI= 0.5309 ,acc= 0.6648
Clustering   34: AMI= 0.6159, NMI= 0.6171, ARI= 0.5229 ,acc= 0.6410
Clustering   35: AMI= 0.6166, NMI= 0.6178, ARI= 0.5244 ,acc= 0.6443
Clustering   36: AMI= 0.6157, NMI= 0.6169, ARI= 0.5232 ,acc= 0.6404
Clustering   37: AMI= 0.6179, NMI= 0.6190, ARI= 0.5260 ,acc= 0.6488
Clustering   38: AMI= 0.6183, NMI= 0.6195, ARI= 0.5290 ,acc= 0.6600
Clustering   39: AMI= 0.6179, NMI= 0.6191, ARI= 0.5312 ,acc= 0.6658
Clustering   40: AMI= 0.6181, NMI= 0.6193, ARI= 0.5320 ,acc= 0.6681
Clustering   41: AMI= 0.6186, NMI= 0.6198, ARI= 0.5366 ,acc= 0.6754
#Epoch 41: Total: 1.867074 Clustering Loss: 0.161755 ZINB Loss1: 0.667553 ZINB Loss2: 0.823085 KL Loss: 0.214682
Clustering   42: AMI= 0.6179, NMI= 0.6191, ARI= 0.5297 ,acc= 0.6630
Clustering   43: AMI= 0.6173, NMI= 0.6185, ARI= 0.5256 ,acc= 0.6487
Clustering   44: AMI= 0.6150, NMI= 0.6162, ARI= 0.5231 ,acc= 0.6488
Clustering   45: AMI= 0.6175, NMI= 0.6187, ARI= 0.5279 ,acc= 0.6546
Clustering   46: AMI= 0.6186, NMI= 0.6198, ARI= 0.5315 ,acc= 0.6661
Clustering   47: AMI= 0.6180, NMI= 0.6192, ARI= 0.5317 ,acc= 0.6664
Clustering   48: AMI= 0.6171, NMI= 0.6183, ARI= 0.5326 ,acc= 0.6680
Clustering   49: AMI= 0.6186, NMI= 0.6198, ARI= 0.5321 ,acc= 0.6668
Clustering   50: AMI= 0.6171, NMI= 0.6183, ARI= 0.5286 ,acc= 0.6609
Clustering   51: AMI= 0.6186, NMI= 0.6198, ARI= 0.5317 ,acc= 0.6655
#Epoch 51: Total: 1.854746 Clustering Loss: 0.151936 ZINB Loss1: 0.667286 ZINB Loss2: 0.822858 KL Loss: 0.212667
Clustering   52: AMI= 0.6165, NMI= 0.6177, ARI= 0.5271 ,acc= 0.6545
Clustering   53: AMI= 0.6151, NMI= 0.6163, ARI= 0.5232 ,acc= 0.6374
Clustering   54: AMI= 0.6173, NMI= 0.6185, ARI= 0.5273 ,acc= 0.6571
Clustering   55: AMI= 0.6160, NMI= 0.6172, ARI= 0.5227 ,acc= 0.6375
Clustering   56: AMI= 0.6151, NMI= 0.6163, ARI= 0.5230 ,acc= 0.6504
Clustering   57: AMI= 0.6190, NMI= 0.6202, ARI= 0.5405 ,acc= 0.6825
Clustering   58: AMI= 0.6180, NMI= 0.6192, ARI= 0.5296 ,acc= 0.6625
Clustering   59: AMI= 0.6159, NMI= 0.6171, ARI= 0.5231 ,acc= 0.6391
Clustering   60: AMI= 0.6174, NMI= 0.6186, ARI= 0.5297 ,acc= 0.6619
Clustering   61: AMI= 0.6176, NMI= 0.6188, ARI= 0.5277 ,acc= 0.6547
#Epoch 61: Total: 1.844312 Clustering Loss: 0.143666 ZINB Loss1: 0.667285 ZINB Loss2: 0.822651 KL Loss: 0.210710
Clustering   62: AMI= 0.6178, NMI= 0.6190, ARI= 0.5306 ,acc= 0.6644
Clustering   63: AMI= 0.6197, NMI= 0.6209, ARI= 0.5392 ,acc= 0.6800
Clustering   64: AMI= 0.6153, NMI= 0.6165, ARI= 0.5229 ,acc= 0.6436
Clustering   65: AMI= 0.6148, NMI= 0.6160, ARI= 0.5224 ,acc= 0.6383
Clustering   66: AMI= 0.6150, NMI= 0.6162, ARI= 0.5228 ,acc= 0.6380
Clustering   67: AMI= 0.6173, NMI= 0.6186, ARI= 0.5262 ,acc= 0.6549
Clustering   68: AMI= 0.6163, NMI= 0.6175, ARI= 0.5235 ,acc= 0.6450
Clustering   69: AMI= 0.6155, NMI= 0.6167, ARI= 0.5227 ,acc= 0.6414
Clustering   70: AMI= 0.6182, NMI= 0.6194, ARI= 0.5315 ,acc= 0.6672
Clustering   71: AMI= 0.6177, NMI= 0.6189, ARI= 0.5292 ,acc= 0.6623
#Epoch 71: Total: 1.837503 Clustering Loss: 0.137466 ZINB Loss1: 0.667154 ZINB Loss2: 0.823006 KL Loss: 0.209877
Clustering   72: AMI= 0.6193, NMI= 0.6205, ARI= 0.5354 ,acc= 0.6741
Clustering   73: AMI= 0.6178, NMI= 0.6190, ARI= 0.5271 ,acc= 0.6565
Clustering   74: AMI= 0.6164, NMI= 0.6176, ARI= 0.5283 ,acc= 0.6595
Clustering   75: AMI= 0.6205, NMI= 0.6217, ARI= 0.5467 ,acc= 0.6907
Clustering   76: AMI= 0.6184, NMI= 0.6196, ARI= 0.5365 ,acc= 0.6762
Clustering   77: AMI= 0.6162, NMI= 0.6175, ARI= 0.5233 ,acc= 0.6411
Clustering   78: AMI= 0.6165, NMI= 0.6177, ARI= 0.5250 ,acc= 0.6468
Clustering   79: AMI= 0.6182, NMI= 0.6194, ARI= 0.5325 ,acc= 0.6694
Clustering   80: AMI= 0.6148, NMI= 0.6160, ARI= 0.5219 ,acc= 0.6413
Clustering   81: AMI= 0.6155, NMI= 0.6167, ARI= 0.5226 ,acc= 0.6381
#Epoch 81: Total: 1.829174 Clustering Loss: 0.131217 ZINB Loss1: 0.666992 ZINB Loss2: 0.822744 KL Loss: 0.208221
Clustering   82: AMI= 0.6169, NMI= 0.6181, ARI= 0.5292 ,acc= 0.6621
Clustering   83: AMI= 0.6161, NMI= 0.6173, ARI= 0.5231 ,acc= 0.6398
Clustering   84: AMI= 0.6174, NMI= 0.6186, ARI= 0.5304 ,acc= 0.6645
Clustering   85: AMI= 0.6143, NMI= 0.6155, ARI= 0.5223 ,acc= 0.6382
Clustering   86: AMI= 0.6182, NMI= 0.6194, ARI= 0.5368 ,acc= 0.6767
Clustering   87: AMI= 0.6189, NMI= 0.6201, ARI= 0.5410 ,acc= 0.6833
Clustering   88: AMI= 0.6162, NMI= 0.6174, ARI= 0.5242 ,acc= 0.6465
Clustering   89: AMI= 0.6157, NMI= 0.6169, ARI= 0.5253 ,acc= 0.6517
Clustering   90: AMI= 0.6170, NMI= 0.6182, ARI= 0.5278 ,acc= 0.6595
Clustering   91: AMI= 0.6166, NMI= 0.6178, ARI= 0.5306 ,acc= 0.6647
#Epoch 91: Total: 1.822761 Clustering Loss: 0.126242 ZINB Loss1: 0.666825 ZINB Loss2: 0.822552 KL Loss: 0.207142
Clustering   92: AMI= 0.6182, NMI= 0.6194, ARI= 0.5399 ,acc= 0.6824
Clustering   93: AMI= 0.6136, NMI= 0.6148, ARI= 0.5246 ,acc= 0.6499
Clustering   94: AMI= 0.6169, NMI= 0.6182, ARI= 0.5400 ,acc= 0.6822
Clustering   95: AMI= 0.6180, NMI= 0.6192, ARI= 0.5478 ,acc= 0.6946
Clustering   96: AMI= 0.6080, NMI= 0.6093, ARI= 0.5229 ,acc= 0.6466
Clustering   97: AMI= 0.6151, NMI= 0.6163, ARI= 0.5506 ,acc= 0.6985
Clustering   98: AMI= 0.6151, NMI= 0.6163, ARI= 0.5461 ,acc= 0.6935
Clustering   99: AMI= 0.6134, NMI= 0.6146, ARI= 0.5449 ,acc= 0.6916
Clustering   100: AMI= 0.6119, NMI= 0.6131, ARI= 0.5476 ,acc= 0.6966
Clustering   101: AMI= 0.6103, NMI= 0.6115, ARI= 0.5355 ,acc= 0.6808
#Epoch 101: Total: 1.817155 Clustering Loss: 0.122029 ZINB Loss1: 0.666804 ZINB Loss2: 0.822372 KL Loss: 0.205949
Clustering   102: AMI= 0.6162, NMI= 0.6174, ARI= 0.5599 ,acc= 0.7123
Clustering   103: AMI= 0.6121, NMI= 0.6134, ARI= 0.5448 ,acc= 0.6955
Clustering   104: AMI= 0.6194, NMI= 0.6207, ARI= 0.5718 ,acc= 0.7265
Clustering   105: AMI= 0.6128, NMI= 0.6140, ARI= 0.5524 ,acc= 0.7064
Clustering   106: AMI= 0.6159, NMI= 0.6171, ARI= 0.5718 ,acc= 0.7266
Clustering   107: AMI= 0.6322, NMI= 0.6334, ARI= 0.6308 ,acc= 0.7708
Clustering   108: AMI= 0.6300, NMI= 0.6312, ARI= 0.6223 ,acc= 0.7665
Clustering   109: AMI= 0.6290, NMI= 0.6302, ARI= 0.6208 ,acc= 0.7643
Clustering   110: AMI= 0.6233, NMI= 0.6245, ARI= 0.6049 ,acc= 0.7538
Clustering   111: AMI= 0.6494, NMI= 0.6506, ARI= 0.6700 ,acc= 0.7955
#Epoch 111: Total: 1.812602 Clustering Loss: 0.118209 ZINB Loss1: 0.666840 ZINB Loss2: 0.822511 KL Loss: 0.205043
Clustering   112: AMI= 0.6432, NMI= 0.6444, ARI= 0.6577 ,acc= 0.7883
Clustering   113: AMI= 0.6419, NMI= 0.6431, ARI= 0.6555 ,acc= 0.7860
Clustering   114: AMI= 0.6212, NMI= 0.6224, ARI= 0.5924 ,acc= 0.7449
Clustering   115: AMI= 0.6342, NMI= 0.6354, ARI= 0.6385 ,acc= 0.7747
Clustering   116: AMI= 0.6132, NMI= 0.6144, ARI= 0.5697 ,acc= 0.7258
Clustering   117: AMI= 0.6266, NMI= 0.6278, ARI= 0.6163 ,acc= 0.7622
Clustering   118: AMI= 0.6493, NMI= 0.6505, ARI= 0.6698 ,acc= 0.7952
Clustering   119: AMI= 0.6267, NMI= 0.6279, ARI= 0.6201 ,acc= 0.7635
Clustering   120: AMI= 0.6365, NMI= 0.6378, ARI= 0.6473 ,acc= 0.7811
Clustering   121: AMI= 0.6457, NMI= 0.6469, ARI= 0.6624 ,acc= 0.7910
#Epoch 121: Total: 1.808842 Clustering Loss: 0.114941 ZINB Loss1: 0.666752 ZINB Loss2: 0.822904 KL Loss: 0.204244
Clustering   122: AMI= 0.6381, NMI= 0.6392, ARI= 0.6429 ,acc= 0.7790
Clustering   123: AMI= 0.6523, NMI= 0.6534, ARI= 0.6751 ,acc= 0.7985
Clustering   124: AMI= 0.6448, NMI= 0.6460, ARI= 0.6590 ,acc= 0.7888
Clustering   125: AMI= 0.6446, NMI= 0.6458, ARI= 0.6620 ,acc= 0.7908
Clustering   126: AMI= 0.6444, NMI= 0.6456, ARI= 0.6567 ,acc= 0.7877
Clustering   127: AMI= 0.6531, NMI= 0.6542, ARI= 0.6764 ,acc= 0.7994
Clustering   128: AMI= 0.6526, NMI= 0.6538, ARI= 0.6750 ,acc= 0.7986
Clustering   129: AMI= 0.6348, NMI= 0.6360, ARI= 0.6342 ,acc= 0.7741
Clustering   130: AMI= 0.6513, NMI= 0.6524, ARI= 0.6723 ,acc= 0.7970
Clustering   131: AMI= 0.6371, NMI= 0.6383, ARI= 0.6385 ,acc= 0.7770
#Epoch 131: Total: 1.803789 Clustering Loss: 0.111295 ZINB Loss1: 0.666789 ZINB Loss2: 0.822329 KL Loss: 0.203376
Clustering   132: AMI= 0.6534, NMI= 0.6545, ARI= 0.6764 ,acc= 0.7994
Clustering   133: AMI= 0.6560, NMI= 0.6570, ARI= 0.6818 ,acc= 0.8019
Clustering   134: AMI= 0.6565, NMI= 0.6576, ARI= 0.6813 ,acc= 0.8022
Clustering   135: AMI= 0.6565, NMI= 0.6575, ARI= 0.6813 ,acc= 0.8021
delta_label  0.0006352087114337568 < tol  0.001
Reach tolerance threshold. Stopping training.
Total time: 1696 seconds.
Final: AMI= 0.6565, NMI= 0.6575, ARI= 0.6813
