迁移学习

2022-09-28 20:03:58 浏览数 (2)

迁移学习指的是在相同的模型下,我们在某一份数据上学习的知识可以应用到另外一份数据上去。也就是在某一个场景下学习的知识应用到另外一个场景,这两个场景间不同但是相关。

上式中的

称为源域(Source Domain),域可以理解成一组数据以及这组数据的特征组成的一种结构。

是数据,

是训练出来的模型,它表示当送入的数据X的时候能够预测对应的标签。这个过程我们可以称为学习任务

上式中的

称为目标域(Target Domain),

是目标域的数据,

是目标域训练出来的模型,它表示当送入目标域的数据X的时候能够预测对应的标签。这个过程称为学习任务

迁移学习就是在训练目标域的模型的时候能够使用源域的知识。

在上图中,我们在一份包含了圆和三角形的数据集上进行训练来识别圆和三角形,然后将模型给存储下来。将存储的模型给迁移到对四边形的数据样本进行学习的任务上来,这就是迁移学习。

为什么使用迁移学习

平时我们使用的都是监督学习,那为什么要使用迁移学习呢?原因就在于数据成本。不是任何一个行业都有充足的数据,都有大量的数据做标记。

假设我们需要对三种光 —— 可见光、红外光、紫外光下的图像进行处理。我们在可见光下做数据物体识别或者分类,在红外图下测温,在紫外图下检测放电。它们使用的相机的成本不同,可见光的相机比较便宜,获取数据的成本很低;红外相机的成本较高,获取数据的成本也较高;而紫外相机的成本特别高,使用的次数也有限,故而获取数据的成本特别高,训练数据集的数据量有限。这个时候就可以使用迁移学习。

迁移学习的分类

迁移学习有一个广泛认可的前提,就是

,源域的数据要远远多于目标域。

迁移学习可以分为三类,第一种叫做归纳式迁移学习 (Inductive Transfer Learning),是我们平时使用特别多的;第二种叫直推式迁移学习 (Transductive Transfer Learning);第三种叫无监督迁移学习 (Unsupevised Transfer Learning)。它们都有相关的领域 (Relate Areas),归纳式迁移学习有两种相关领域,一个是多任务学习 (Multi-task Learning),一个是自主学习 (Self-taught Learning),它们的区别在于多任务学习的源域的数据是带标签的,而自主学习的源域数据是不带标签的,它们的目标域数据都是带标签的。它们的相关任务可以是回归和分类都可以。一般目标域的数据量很小,标注的成本并不高,算法工程师完全可以自己标注。

直推式迁移学习的相关领域有领域自适应 (Domain Adaptation)、样本选择偏差 (Sample Selection Bias)、共变量偏移 (Co-variate Shift)。它的源域的数据是有标签的,而目标域的数据是没有标签的。无监督迁移学习没有特定的相关领域,它的源域和目标域的数据集都是没有标签的,而且它的任务也不是回归 (Regression) 和分类 (Classification) 而是聚类 (Clustering) 和降维 (Dimensionality Reduction)。

  • 多任务迁移学习

上表中源域是一个猫狗分类,目标域有两种,一个熊猫、考拉的分类,另一个是袋鼠、麋鹿的分类。

在上图中,我们将猫狗分类的模型给训练完成之后,将其网络的各个层 (Layer) 给挪过来,这些层可能是一些特征抽取层,也可能是一些卷积层,或者包含了 RNN 层,这些我们统称为 Shared layers。对于这些 Shared layers,我们可以将其参数固定或者不固定都是可以的。再拼上其他几个目标域的任务,比如 Task A 是熊猫、考拉的分类,Task B 是袋鼠、麋鹿的分类。这些目标域任务的网络结构,我们也需要去进行设计,但通常是比较简单的分类器。

  • 自主迁移学习

我们在了解自主迁移学习之前先要了解一下自编码器 (AutoEncoder)。

在自编码器中,我们希望送入图像 X 和输出图像

能够无限接近。X 会先送进编码器 (ENCODER) 变成一个中间向量 Z,Z 再经过解码器 (DECODER) 输出

。这里为什么叫自编码器呢?这是因为

的标签就是 X 本身,在训练的过程中

会向着 X 无限接近。

自编码器的作用是在高分辨率的音频或者图像的传输过程中 (比如说 4K、8K),通常网络带宽是有限的,我们可以将高分辨率的 X 给压缩到 Z,然后将 Z 给送入到网络中进行传输,这个 Z 通常比较小,在接收端接收 Z 再运行解码器将 Z 解码还原成 X。当然这是自编码器的一个应用。

还有一种就是传入的是一张马赛克的图片,经过自编码器之后,输出的是一个没有打马赛克的图片,这也是 AutoEncoder 的一个用途。这里通常叫做去噪自编码器。

自主迁移学习的源域的数据集是不带标签的,但是目标域的数据集是带标签的。

首先,我们将源域的数据集 (猫狗) 送入自编码器,训练出一个完备的自编码器模型。我们把编码器 (Encoder) 部分给取出来,我们认为编码器部分具备一种能力,就是能够对输入的图像数据进行特征的抽取。

将编码器部分取出来之后,在其上层接一个简单的分类层 (Classification layers),组成一个新的网络,来训练目标域的熊猫、考拉分类。这个整个过程就叫做自主迁移学习。这种方式的迁移学习的成本比多任务迁移学习的成本更低,它不需要对源域数据进行标注。

迁移学习的实施方法

这里我们针对的是多任务迁移学习,现在我们假设的场景是做一个鱼的识别分类器,但是没有相关的数据集。

现在我们要找到一个源域,任务是做图片分类器,获取源域知识,在少量鱼标记数据上迁移学习 (几十张或者上百张)。

方案设计:

:ImageNet 数据集,

:ImageNet 上的分类任务,如 VGGNet;

:少量鱼的标记数据,

:鱼的识别分类。

这样可以有四种方法来进行实现。

方法一:pretrained VGGNet (预训练模型) 特征抽取器 SVM / 贝叶斯分类器训练,这种方法基本上已经过时了,属于传统的机器学习。

一般来说,我们会把前面的卷积层当成特征抽取器,后面的全连接层当成分类器。

我们可以将其分类器给删除掉,只保留卷积层。然后将所有鱼的图片送入 CNN 中进行特征抽取,将抽取完的特征存储下来。然后再来训练一个简单的 SVM 或者贝叶斯分类器来对鱼进行分类。

方法二:pretrained VGGNet 特征抽取器 修改 FC layer。

在这里,我们不去删除分类器,而是去修改分类器,因为 ImageNet 原本的分类器可能是成千上万种类别,我们假设鱼的种类是一百种,那么我们就需要将 output 的分类数量修改成 100。甚至我们可以在分类器中增加一些卷积层,再通过 FC 层来输出都是可以的。

方法三:固定 pretrained VGGNet 特征抽取器 修改 FC layer。

在上一种方法的训练过程中,整个的层,无论是特征抽取层还是分类器层都在进行权重的更新,模型参数都在变化。而在本方法中,CNN 层的参数是不参与反向传播的,其中的参数不会发生变化。我们认为 CNN 层已经训练完备具备特征抽取的能力了,我们只需要修改后面的分类器网络就可以了。这样的话训练时间会缩短,因为真正训练的只有分类器的部分。这个方法在深度学习中有一个专门的名词叫微调 (Fine tuning),它就是迁移学习的一种方法。

方法四:多个不同的 pretrained 特征抽取器 分类器 FC layer。

这里不同的预训练模型可以是不同的,比如第二个预训练模型是 MobileNet,把需要训练的鱼的图片分别送入不同的预训练模型,得到不同的模型抽取出来的特征,再将这些不同的模型的特征拼接 (concat) 起来,最后设计一个分类器,将拼接的特征送入分类器中。由于方法四比较麻烦,平时用的不是很多。通常方法三和方法二是用的比较多的。

基于 ResNet 的迁移学习

迁移学习的源域的预训练模型的要求其实是比较苛刻的,它必须是在一个大规模的数据上做预训练,否则会导致该预训练模型的特征抽取的能力是有限的。

我们这一次做的是一个图片动作识别的动作分类。设计思路如下

:ImageNet 数据集,

:ImageNet 上的分类任务,ResNet34;

:少量动作标记数据,

:动作分类。

我们先来看一个网络参数的概念,这里以一个最简单的全连接层来说明

代码语言:javascript复制
import torch.nn as nn

if __name__ == '__main__':

    net = nn.Linear(10, 100)
    for k, v in net.named_parameters():
        print(k)
        print(v)

运行结果

代码语言:javascript复制
weight
Parameter containing:
tensor([[-0.1512,  0.0094, -0.1952, -0.2911, -0.0713, -0.0749, -0.2792,  0.0223,
          0.2715, -0.2801],
        [ 0.1736,  0.0219, -0.0839, -0.3110, -0.2997, -0.0009, -0.0440,  0.1641,
         -0.1014,  0.0820],
        [-0.0070, -0.1996, -0.3114,  0.1465,  0.1702,  0.0227, -0.3075, -0.0964,
         -0.2725, -0.2684],
        [ 0.0267,  0.1512, -0.1803,  0.0338, -0.2362, -0.2396,  0.2440, -0.2711,
         -0.3000, -0.2354],
        [ 0.1166, -0.2163, -0.0593, -0.0437, -0.2011, -0.1675, -0.2814, -0.2531,
          0.2135,  0.3139],
        [ 0.0946,  0.2476,  0.0103,  0.1359,  0.2417, -0.2008,  0.2895,  0.1928,
         -0.2859, -0.1904],
        [ 0.1790,  0.2405,  0.1249,  0.0475,  0.0666, -0.0709, -0.0544,  0.2075,
         -0.2277, -0.1843],
        [ 0.2473,  0.0936, -0.2557, -0.2785,  0.0655,  0.0986,  0.2453, -0.1914,
          0.2629, -0.0262],
        [ 0.1863,  0.1443, -0.1161, -0.2848, -0.1022, -0.1912, -0.1587, -0.3046,
         -0.0015, -0.3112],
        [-0.2585,  0.1969, -0.1671, -0.1214,  0.1781, -0.2503, -0.0686, -0.2183,
          0.1106, -0.2624],
        [ 0.2152,  0.2842, -0.2983, -0.2079,  0.0909, -0.1147, -0.0644,  0.0014,
          0.1307,  0.2361],
        [-0.2991, -0.1975, -0.2118,  0.2784, -0.0173, -0.0733, -0.2660,  0.2911,
         -0.1690,  0.0910],
        [ 0.0577,  0.2946,  0.1430, -0.1700, -0.2800,  0.2422, -0.2416,  0.0487,
         -0.0304,  0.0530],
        [ 0.1164, -0.1942, -0.0533, -0.3036, -0.3057,  0.1948, -0.2119,  0.2421,
          0.1072,  0.2266],
        [-0.2987,  0.1223, -0.2962, -0.0942, -0.0894,  0.0562, -0.2319, -0.2158,
         -0.0385,  0.1021],
        [ 0.1150,  0.2645,  0.1348,  0.1642,  0.2264, -0.1805,  0.1146, -0.1962,
         -0.0057,  0.2338],
        [ 0.1584, -0.2839,  0.3120, -0.0224, -0.3065, -0.0568, -0.1833, -0.0924,
          0.1861, -0.1080],
        [ 0.1608,  0.1035, -0.1884,  0.2555,  0.0958, -0.2116,  0.0503,  0.1194,
         -0.3109, -0.0706],
        [ 0.0793,  0.0597, -0.2295, -0.1508, -0.1926, -0.2338,  0.2232,  0.1940,
          0.1564,  0.2526],
        [-0.2352,  0.1136, -0.1300,  0.0453,  0.2190, -0.2185,  0.1502, -0.0299,
         -0.0095,  0.2036],
        [ 0.1303,  0.2006,  0.1289,  0.2555,  0.2188, -0.2637, -0.1108, -0.1133,
         -0.0492,  0.0312],
        [-0.2892, -0.0011,  0.0701,  0.0173, -0.1867,  0.2510, -0.1801,  0.1558,
          0.2144, -0.2392],
        [ 0.2987,  0.2862, -0.2375,  0.2574,  0.3083,  0.1118,  0.0774,  0.2802,
          0.0599,  0.2141],
        [-0.1782, -0.1406,  0.2775, -0.0514, -0.1559,  0.0087,  0.1747,  0.2079,
         -0.1944,  0.2855],
        [ 0.1666,  0.0111, -0.2879, -0.0590,  0.2913, -0.1520,  0.2336, -0.1688,
         -0.1175,  0.2756],
        [-0.1604, -0.1052, -0.2323, -0.0987, -0.0772, -0.3009, -0.0594, -0.0157,
          0.0788,  0.1557],
        [-0.2793, -0.0865, -0.0839, -0.0633, -0.2992, -0.0298,  0.2993, -0.2796,
          0.0404,  0.2271],
        [ 0.2273,  0.1380,  0.0820, -0.0906,  0.2885,  0.1595,  0.2183, -0.0962,
          0.1451,  0.1738],
        [-0.1709, -0.1991, -0.1436,  0.2222, -0.2246, -0.2698,  0.1426,  0.0178,
         -0.1507,  0.1962],
        [-0.1052, -0.2710,  0.0081,  0.2011,  0.2629,  0.1788, -0.2595,  0.1949,
          0.0919,  0.2111],
        [ 0.1863,  0.2820, -0.0596,  0.1023, -0.1286, -0.2530,  0.2020,  0.2202,
          0.0120,  0.3002],
        [-0.0092,  0.0400,  0.0814,  0.2381, -0.1310,  0.1829, -0.2359,  0.1275,
          0.1898,  0.0451],
        [ 0.0569, -0.0740,  0.0449,  0.2888, -0.0633, -0.3003,  0.0307, -0.0767,
         -0.1918,  0.1996],
        [-0.3038, -0.0761, -0.0517, -0.0505,  0.0042,  0.3071,  0.2066, -0.1765,
         -0.2446,  0.2819],
        [ 0.1288, -0.2963, -0.0198,  0.2153, -0.2704,  0.3154,  0.2304,  0.0614,
          0.2988,  0.0271],
        [-0.0692,  0.1584, -0.2810, -0.2819,  0.2259,  0.1131,  0.0133, -0.1142,
         -0.2470,  0.1652],
        [-0.1919,  0.0719,  0.1883,  0.1932, -0.0420, -0.1415,  0.1281,  0.0193,
          0.2798, -0.0996],
        [-0.0504,  0.1131, -0.2465,  0.3084, -0.2502,  0.1388,  0.0035, -0.1159,
         -0.1872,  0.0745],
        [-0.0417, -0.2954,  0.0080,  0.0548,  0.2130,  0.2834, -0.2383,  0.2274,
         -0.2498, -0.0611],
        [ 0.1023, -0.0903, -0.1154,  0.0349,  0.0098,  0.1594,  0.0831,  0.2525,
         -0.3160, -0.2358],
        [-0.2916, -0.2148, -0.2694,  0.2258, -0.1018,  0.2207, -0.1694, -0.2033,
          0.2837, -0.0907],
        [ 0.0056,  0.0630, -0.0601,  0.1789,  0.1926,  0.1798,  0.1020,  0.0217,
         -0.1055,  0.1556],
        [-0.1102, -0.0975, -0.0067,  0.1905,  0.0035,  0.2388,  0.0065,  0.0132,
         -0.1996,  0.0172],
        [-0.2918,  0.2914,  0.0641,  0.1553,  0.3099, -0.2043,  0.1944, -0.2247,
         -0.1089, -0.2171],
        [-0.1193,  0.1034, -0.1453,  0.2687, -0.0415, -0.2016, -0.0827, -0.0345,
         -0.2025,  0.0339],
        [-0.1210, -0.1386,  0.0783,  0.2673,  0.1191, -0.0228, -0.2321, -0.2860,
          0.2036, -0.2946],
        [-0.1985, -0.0979,  0.0378, -0.2486,  0.2781, -0.2019,  0.0919, -0.0009,
         -0.0861, -0.1081],
        [ 0.2206,  0.1363, -0.1479, -0.1972, -0.0796, -0.2008, -0.2314,  0.1732,
         -0.2197,  0.0140],
        [-0.0204, -0.3081, -0.1041, -0.3068, -0.0709,  0.1541, -0.0008,  0.2925,
         -0.1648, -0.1411],
        [-0.0933,  0.0322, -0.0915, -0.1996,  0.0417, -0.1067,  0.1419,  0.0852,
          0.2511, -0.0614],
        [-0.1479,  0.2027, -0.2771, -0.2433, -0.2435,  0.1445, -0.0952,  0.1242,
          0.0900, -0.3120],
        [ 0.1153,  0.1787,  0.0085, -0.0630,  0.2055, -0.0236, -0.0052, -0.2624,
         -0.0829,  0.0406],
        [ 0.1813,  0.1925,  0.0043, -0.0118, -0.2471, -0.3103, -0.1167, -0.1799,
          0.2266,  0.0270],
        [-0.1398, -0.0731, -0.0338,  0.2116,  0.1625,  0.2184, -0.3046, -0.2461,
         -0.0441, -0.1424],
        [-0.1943, -0.3038,  0.0419,  0.2189, -0.2121,  0.0968,  0.2624,  0.2590,
         -0.0535,  0.0378],
        [ 0.2851,  0.0171,  0.2299,  0.1639, -0.1652,  0.1702, -0.1550, -0.0092,
         -0.2946,  0.1806],
        [ 0.1249,  0.0773,  0.0954,  0.2616, -0.0767,  0.0233, -0.0781,  0.2788,
          0.1121,  0.3141],
        [-0.2453,  0.0101,  0.2659,  0.2034, -0.1296, -0.0657,  0.1958,  0.2650,
         -0.2125,  0.0671],
        [-0.1074,  0.2632, -0.1679, -0.2281,  0.1627,  0.2332,  0.0733, -0.1283,
         -0.0760,  0.0250],
        [ 0.1210,  0.2531,  0.2307, -0.1291,  0.0360, -0.1463,  0.2424, -0.1343,
         -0.2658, -0.2920],
        [ 0.1538, -0.2144,  0.2966, -0.1866, -0.0836, -0.0634,  0.0228, -0.2590,
         -0.2466, -0.3013],
        [ 0.2660,  0.3010,  0.1324, -0.2321,  0.1065, -0.0549, -0.2605, -0.2198,
          0.3006,  0.0401],
        [-0.0929,  0.1887, -0.2695, -0.1645, -0.2606, -0.2075,  0.0252, -0.0248,
         -0.0513,  0.3071],
        [-0.1121, -0.2630,  0.1714, -0.0614,  0.1244, -0.2698,  0.1364,  0.1229,
         -0.2427,  0.2772],
        [-0.2655,  0.2321, -0.2822, -0.2833,  0.0768, -0.2331,  0.0535, -0.3106,
         -0.0816,  0.2979],
        [ 0.0299,  0.3092,  0.1933, -0.2018, -0.0641, -0.0592, -0.3037, -0.1947,
         -0.2685,  0.1553],
        [-0.1637, -0.0149,  0.1954, -0.0853,  0.2154,  0.2400,  0.0392, -0.2939,
          0.1578,  0.0464],
        [-0.2362,  0.3022, -0.0435,  0.2462, -0.2842, -0.2340, -0.0492, -0.1501,
          0.2912,  0.0400],
        [-0.0588,  0.2687,  0.2557, -0.0823,  0.2021,  0.0462, -0.2523, -0.0573,
         -0.1844, -0.0162],
        [-0.2115, -0.0668,  0.1865, -0.3012, -0.1430, -0.1858,  0.0901, -0.0147,
          0.1845, -0.1413],
        [-0.1926,  0.1849, -0.0391, -0.2165,  0.2262,  0.0282, -0.2720,  0.1138,
         -0.2184, -0.0629],
        [ 0.0335,  0.1333, -0.1070,  0.0370, -0.2160,  0.0710,  0.0726,  0.0023,
          0.1974,  0.2498],
        [ 0.2640,  0.2644,  0.2584,  0.0208,  0.0063,  0.3010,  0.0234, -0.0271,
         -0.2447,  0.1254],
        [-0.1791,  0.2479, -0.1122, -0.0476,  0.1108,  0.2430, -0.2596, -0.3142,
         -0.2188,  0.2107],
        [ 0.2919,  0.0231,  0.0352,  0.2631,  0.2298, -0.3115,  0.1193,  0.2997,
          0.2370, -0.0700],
        [-0.2155,  0.2136,  0.1528, -0.0981,  0.1575, -0.1495,  0.1370,  0.2769,
         -0.2264,  0.0624],
        [-0.1599, -0.3073, -0.1789, -0.0792, -0.2543, -0.3146, -0.3094, -0.2935,
          0.1901, -0.1090],
        [-0.1883, -0.0371,  0.2633, -0.2636, -0.2709,  0.2084,  0.2148, -0.2893,
         -0.2604,  0.1425],
        [-0.1513,  0.3066, -0.0617,  0.2080, -0.2217,  0.1986, -0.2722, -0.2945,
         -0.0990, -0.2623],
        [ 0.0778,  0.0230,  0.2430,  0.0546, -0.2561, -0.1747,  0.0279, -0.2440,
         -0.2042,  0.2642],
        [-0.3036, -0.0077,  0.2350,  0.0612,  0.1382, -0.2675, -0.2863, -0.1425,
         -0.3115,  0.2145],
        [ 0.2844,  0.2778,  0.2577, -0.2005,  0.1673,  0.2410, -0.2546, -0.2883,
         -0.1838, -0.0043],
        [-0.0037,  0.0763, -0.1020, -0.1803, -0.2030,  0.0698,  0.2741,  0.2451,
          0.0505,  0.0694],
        [ 0.1746,  0.0044,  0.2061,  0.3072,  0.1811, -0.2480, -0.0533,  0.1855,
         -0.0852, -0.0088],
        [-0.2080,  0.1583,  0.1491,  0.0605, -0.1563, -0.0535,  0.0160, -0.1581,
         -0.1604, -0.0974],
        [-0.2601,  0.0592, -0.1958, -0.1943,  0.1719, -0.0783,  0.1113,  0.2861,
          0.0428,  0.2575],
        [ 0.0406, -0.2026,  0.2198, -0.0473, -0.1599,  0.2828,  0.2457, -0.2226,
          0.0226, -0.2781],
        [ 0.1705, -0.2185, -0.1890, -0.0477, -0.2320, -0.3094, -0.2659,  0.1758,
          0.2335, -0.0022],
        [ 0.2599, -0.1549,  0.0968, -0.2937,  0.0167, -0.1488, -0.1077, -0.0021,
         -0.1929,  0.3076],
        [-0.0359,  0.0718,  0.0344, -0.2514,  0.0576,  0.2859,  0.0963, -0.2750,
         -0.0993,  0.0720],
        [ 0.2135,  0.1362,  0.2285,  0.0136, -0.1709, -0.0345, -0.1640, -0.0519,
         -0.0951,  0.0840],
        [-0.0771,  0.2575, -0.0879,  0.3008, -0.2960,  0.0905,  0.0713,  0.1362,
          0.2647,  0.1353],
        [ 0.1529,  0.2251, -0.2066, -0.1648, -0.3029, -0.1679, -0.2598,  0.2772,
         -0.1182, -0.2175],
        [-0.1030, -0.0383,  0.0182,  0.0668, -0.1874, -0.0833,  0.0136, -0.0979,
         -0.2099,  0.1666],
        [ 0.1873, -0.2892,  0.0895,  0.1929,  0.1416,  0.2782,  0.1888, -0.2935,
          0.0140,  0.3150],
        [ 0.0238,  0.0388, -0.0680,  0.0389,  0.0300, -0.1332, -0.1168, -0.0442,
          0.2117,  0.1490],
        [-0.2340, -0.1430,  0.1278, -0.0571, -0.0926,  0.3115, -0.0142, -0.2753,
         -0.3159,  0.2065],
        [-0.1393,  0.2226, -0.2273, -0.2404, -0.1420,  0.1291,  0.2447,  0.2914,
         -0.1132,  0.0521],
        [-0.1739, -0.0598,  0.1113,  0.1940,  0.2713, -0.2998, -0.1691, -0.0870,
          0.0477, -0.2552],
        [-0.1625,  0.1031, -0.2073, -0.3013, -0.2973, -0.1255, -0.2814,  0.0016,
          0.1938,  0.1943]], requires_grad=True)
bias
Parameter containing:
tensor([ 0.1419, -0.2517, -0.1293, -0.1948, -0.0094, -0.2624,  0.0806, -0.0085,
         0.2750, -0.2095,  0.0500, -0.1621,  0.1437, -0.0636,  0.1156,  0.0411,
         0.1624,  0.0533, -0.2659,  0.2860,  0.1105,  0.2902,  0.0260,  0.0968,
         0.0155, -0.1847,  0.2143, -0.1323,  0.1891,  0.0966,  0.0019, -0.2211,
         0.2195,  0.0728, -0.3100, -0.0098, -0.0672,  0.0225, -0.3005,  0.3063,
        -0.1433, -0.1355,  0.0631, -0.1178,  0.2228, -0.2700,  0.1666,  0.1650,
        -0.0480, -0.0855,  0.2002,  0.3024,  0.0503,  0.1469, -0.2147,  0.0215,
         0.1430,  0.1396,  0.3069, -0.3065,  0.2825,  0.2960, -0.0314,  0.2117,
        -0.1285,  0.1147, -0.1420,  0.0471,  0.2387,  0.1604,  0.3016, -0.3138,
         0.0422, -0.1260, -0.1738, -0.0683, -0.2231, -0.0146, -0.1972,  0.2797,
         0.3107,  0.0211, -0.2589,  0.2272,  0.2147, -0.3160,  0.0110,  0.1659,
        -0.1931, -0.1776,  0.1570,  0.0556,  0.0142,  0.0637, -0.2795,  0.1971,
        -0.0682, -0.2067,  0.0456,  0.2263], requires_grad=True)

这里的 weight、bias 就是打印出来的 k,它代表权重和偏置 (偏置其实也是权重的一种,代表 0 次方位的权值)。其他的就是 v 的值。我们在 v 的值中可以看到有 requires_grad=True,它代表这些值可以被反向传播进行更新。之前我们说我们可以在迁移学习的实施方法三中可以固定这些值不被反向传播更新,方法如下

代码语言:javascript复制
import torch.nn as nn

if __name__ == '__main__':

    net = nn.Linear(10, 100)
    # 这里其实就是net.named_parameters()的v
    # k只是每个参数项的名字
    for p in net.parameters():
        p.requires_grad = False
    for k, v in net.named_parameters():
        # 如果可以梯度更新
        if v.requires_grad:
            print(k)
            print(v)

此时我们运行,将不会打印任务内容。这代表我们将整个网络的参数给固定住了,在以后的训练过程中将不会进行更新。

这里我们使用的数据集下载地址是 http://vision.stanford.edu/Datasets/40actions.html,它包含了人类的 40 个动作的静态图片,每个动作类别有 180-300 张图像。但我们不会用这么多图像,每个动作就挑选出 5 张图片来进行试验。

在该数据集下有一个 imageSplits 文件夹中有一个 actions.txt,它包含了所有动作的名称。

代码语言:javascript复制
action_name			number_of_images
applauding			284
blowing_bubbles			259
brushing_teeth			200
cleaning_the_floor		212
climbing			295
cooking				288
cutting_trees			203
cutting_vegetables		189
drinking			256
feeding_a_horse			287
fishing				273
fixing_a_bike			228
fixing_a_car			251
gardening			199
holding_an_umbrella		292
jumping				295
looking_through_a_microscope	191	
looking_through_a_telescope	203
playing_guitar			289
playing_violin			260
pouring_liquid			200
pushing_a_cart			235
reading				245
phoning				259
riding_a_bike			293
riding_a_horse			296
rowing_a_boat			185
running				251
shooting_an_arrow		214
smoking				241
taking_photos			197
texting_message			193
throwing_frisby			202
using_a_computer		230
walking_the_dog			293
washing_dishes			182
watching_TV			223
waving_hands			210
writing_on_a_board		183
writing_on_a_book		246

我们将根据这份名称文件来生成一个 id 与名称互相映射的 json 文件

代码语言:javascript复制
import json

if __name__ == '__main__':

    file = "/Users/admin/Downloads/Stanford40/ImageSplits/actions.txt"
    i = 0
    data1 = {}
    data2 = {}
    data = {}
    with open(file, 'r') as f:
        for line in f:
            if i != 0:
                cls = line.split('t')[0]
                data1[str(i - 1)] = cls
                data2[cls] = str(i - 1)
            i  = 1
        data['id2cls'] = data1
        data['cls2id'] = data2
    json_str = json.dumps(data)
    print(json_str)
    with open('cls_mapper.json', 'w') as f_json:
        json.dump(data, f_json)

运行结果

代码语言:javascript复制
{"id2cls": {"0": "applauding", "1": "blowing_bubbles", "2": "brushing_teeth", "3": "cleaning_the_floor", "4": "climbing", "5": "cooking", "6": "cutting_trees", "7": "cutting_vegetables", "8": "drinking", "9": "feeding_a_horse", "10": "fishing", "11": "fixing_a_bike", "12": "fixing_a_car", "13": "gardening", "14": "holding_an_umbrella", "15": "jumping", "16": "looking_through_a_microscope", "17": "looking_through_a_telescope", "18": "playing_guitar", "19": "playing_violin", "20": "pouring_liquid", "21": "pushing_a_cart", "22": "reading", "23": "phoning", "24": "riding_a_bike", "25": "riding_a_horse", "26": "rowing_a_boat", "27": "running", "28": "shooting_an_arrow", "29": "smoking", "30": "taking_photos", "31": "texting_message", "32": "throwing_frisby", "33": "using_a_computer", "34": "walking_the_dog", "35": "washing_dishes", "36": "watching_TV", "37": "waving_hands", "38": "writing_on_a_board", "39": "writing_on_a_book"}, "cls2id": {"applauding": "0", "blowing_bubbles": "1", "brushing_teeth": "2", "cleaning_the_floor": "3", "climbing": "4", "cooking": "5", "cutting_trees": "6", "cutting_vegetables": "7", "drinking": "8", "feeding_a_horse": "9", "fishing": "10", "fixing_a_bike": "11", "fixing_a_car": "12", "gardening": "13", "holding_an_umbrella": "14", "jumping": "15", "looking_through_a_microscope": "16", "looking_through_a_telescope": "17", "playing_guitar": "18", "playing_violin": "19", "pouring_liquid": "20", "pushing_a_cart": "21", "reading": "22", "phoning": "23", "riding_a_bike": "24", "riding_a_horse": "25", "rowing_a_boat": "26", "running": "27", "shooting_an_arrow": "28", "smoking": "29", "taking_photos": "30", "texting_message": "31", "throwing_frisby": "32", "using_a_computer": "33", "walking_the_dog": "34", "washing_dishes": "35", "watching_TV": "36", "waving_hands": "37", "writing_on_a_board": "38", "writing_on_a_book": "39"}}

并且生成了一份 cls_mapper.json 的文件,内容与上面打印的相同。现在我们来生成训练数据集、验证数据集和测试数据集,并生成相应的标签文件。

代码语言:javascript复制
import shutil
import json

if __name__ == '__main__':

    dir = "/Users/admin/Downloads/Stanford40/JPEGImages"
    json_file = "./data/cls_mapper.json"
    with open(json_file, 'r') as f:
        cls = json.load(f)

    train = ""
    dev = ""
    test = ""
    for k, v in cls['cls2id'].items():
        for i in range(1, 6):
            shutil.move(dir   "/"   k   "_00"   str(i)   ".jpg", "./data/train/")
            train  = v   "|./data/train/"   k   "_00"   str(i)   ".jpgn"
            if i   5 > 9:
                shutil.move(dir   "/"   k   "_0"   str(i   5)   ".jpg", "./data/dev/")
                dev  = v   "|./data/dev/"   k   "_0"   str(i   5)   ".jpgn"
            else:
                shutil.move(dir   "/"   k   "_00"   str(i   5)   ".jpg", "./data/dev/")
                dev  = v   "|./data/dev/"   k   "_00"   str(i   5)   ".jpgn"
            shutil.move(dir   "/"   k   "_0"   str(i   10)   ".jpg", "./data/test/")
            test  = v   "|./data/test/"   k   "_0"   str(i   10)   ".jpgn"
    with open("./data/meta_train.txt", 'w') as f_train:
        f_train.write(train)
    with open("./data/meta_dev.txt", 'w') as f_dev:
        f_dev.write(dev)
    with open("./data/meta_test.txt", 'w') as f_test:
        f_test.write(test)
    f_train.close()
    f_dev.close()
    f_test.close()

运行结果

超参数配置

代码语言:javascript复制
import torch
# ################################################################
#                             HyperParameters
# ################################################################

class Hyperparameters:
    # ################################################################
    #                             Data
    # ################################################################
    device = 'cuda' if torch.cuda.is_available() else 'cpu' # cuda
    data_root = './data'

    cls_mapper_path = './data/cls_mapper.json'
    train_data_root = './data/train/'
    dev_data_root = './data/dev/'
    test_data_root = './data/test/'

    metadata_train_path = './data/meta_train.txt'
    metadata_dev_path = './data/meta_dev.txt'
    metadata_test_path = './data/meta_test.txt'

    classes_num = 40
    seed = 1234

    # ################################################################
    #                             Model
    # ################################################################
    if_conv_frozen = True

    # ################################################################
    #                             Exp
    # ################################################################
    batch_size = 4
    init_lr = 5e-4
    epochs = 100
    verbose_step = 30
    save_step = 30


HP = Hyperparameters()

工具方法

代码语言:javascript复制
import os
from PIL import Image


# 获取某个文件夹下面所有后缀为suffix的文件,返回path的list
def recursive_fetching(root, suffix=['jpg', 'png']):
    all_file_path = []

    def get_all_files(path):
        all_file_list = os.listdir(path)
        # 遍历该文件夹下的所有目录或者文件
        for file in all_file_list:
            filepath = os.path.join(path, file)
            # 如果是文件夹,递归调用函数
            if os.path.isdir(filepath):
                get_all_files(filepath)
            # 如果不是文件夹,保存文件路径及文件名
            elif os.path.isfile(filepath):
                all_file_path.append(filepath)

    get_all_files(root)

    file_paths = [it for it in all_file_path if os.path.split(it)[-1].split('.')[-1].lower() in suffix]

    return file_paths


def load_meta(meta_path):
    with open(meta_path, 'r') as fr:
        # 返回的是数字标签、图片路径组成的列表
        return [line.strip().split('|') for line in fr.readlines()]


def load_image(image_path):
    # 加载图片
    return Image.open(image_path)

数据集

代码语言:javascript复制
import torch
from torch.utils.data import DataLoader
from config import HP
from utils import load_meta, load_image
from torchvision import transforms as T
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import os

# 数据预处理
ac_transform = T.Compose([
    T.Resize((112, 112)),                   # 保证同样输入的shape
    T.RandomRotation(degrees=45),           # 减小倾斜图片影响
    T.GaussianBlur(kernel_size=(3, 3)),     # 抑制模糊图片的影响
    T.RandomHorizontalFlip(),               # 左右
    T.ToTensor(),                           # 归一化 & float32 tensor
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))   # 标准化
])


class ActionDataset(torch.utils.data.Dataset):
    def __init__(self, metadata_path):
        # 由数字标签和图片路径组成的数据集
        self.dataset = load_meta(metadata_path) # [(0, image_path), () ,...]

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # 获取数字标签和图片路径
        cls_id, path = int(item[0]), item[1]
        # 加载图片
        image = load_image(path)
        # 对图像进行数据预处理
        return ac_transform(image).to(HP.device), cls_id

    def __len__(self):
        return len(self.dataset)


if __name__ == '__main__':

    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # openKMP cause unexpected error

    HP.device = 'cpu'
    ad = ActionDataset(HP.metadata_test_path)
    ac_loader = DataLoader(ad, batch_size=9, shuffle=True)
    for b in ac_loader:
        images = b[0]
        print(images.size())
        grid = vutils.make_grid(images, nrow=3)
        plt.imshow(grid.permute(1, 2, 0))
        plt.show()
        break

运行结果

迁移学习网络模型

代码语言:javascript复制
import torch
from torch import nn
from config import HP
import torchvision

# 1. Pretrained ResNet 34
# 2. delete and modify fc


class TransferLNet(nn.Module):
    '''
    迁移学习网络模型
    '''

    def __init__(self):
        super(TransferLNet, self).__init__()
        # 获取ImageNet的ResNet34预训练模型
        self.model = torchvision.models.resnet34(pretrained=True)
        # 是否固定预训练模型参数,此处是固定
        if HP.if_conv_frozen:
            for k, v in self.model.named_parameters():
                v. requires_grad = False
        # 获取预训练模型全连接层的输入通道数
        resnet_fc_dim = self.model.fc.in_features
        # 修改分类器
        new_fc_layer = nn.Linear(resnet_fc_dim, out_features=HP.classes_num) # new fc layer
        self.model.fc = new_fc_layer

    def forward(self, x):
        return self.model(x)


if __name__ == '__main__':
    x = torch.randn(size=(7, 3, 112, 112))
    model = TransferLNet()
    output = model(x)
    print(output.size())
    for k, v in model.named_parameters():
        if v.requires_grad:
            print(k)

运行结果

代码语言:javascript复制
torch.Size([7, 40])
model.fc.weight
model.fc.bias

由结果可以看出,模型最终输出的是 40 个分类。可以参加训练的参数也仅仅是全连接层的权重和偏置。

模型训练

代码语言:javascript复制
import os
from argparse import ArgumentParser
import torch.optim as optim
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from model import TransferLNet
from config import HP
from dataset_action import ActionDataset


logger = SummaryWriter('./log')

# seed init: Ensure Reproducible Result
torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)
random.seed(HP.seed)
np.random.seed(HP.seed)


def evaluate(model_, devloader, crit):
    model_.eval() # set evaluation flag
    sum_loss = 0.
    total = 0
    with torch.no_grad():
        for batch in devloader:
            x, y = batch
            pred = model_(x)
            loss = crit(pred, y.to(HP.device))
            sum_loss  = loss.item()
            total  = (torch.argmax(pred, 1) == y).sum()

    model_.train() # back to training mode
    return sum_loss / len(devloader), total / (len(devloader) * HP.batch_size)


def save_checkpoint(model_, epoch_, optm, checkpoint_path):
    save_dict = {
        'epoch': epoch_,
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict()
    }
    torch.save(save_dict, checkpoint_path)


def train():
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # 创建迁移学习模型
    model = TransferLNet()
    model = model.to(HP.device)

    # 交叉熵损失函数
    criterion = nn.CrossEntropyLoss()

    # 梯度下降优化器
    opt = optim.Adam(model.parameters(), lr=HP.init_lr)
    # opt = optim.SGD(model.parameters(), lr=HP.init_lr)

    # 训练数据集
    trainset = ActionDataset(HP.metadata_train_path)
    train_loader = DataLoader(trainset, batch_size=HP.batch_size, shuffle=True, drop_last=True)

    # 验证数据集
    devset = ActionDataset(HP.metadata_dev_path)
    dev_loader = DataLoader(devset, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    start_epoch, step = 0, 0
    # 是否加载已训练模型
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print('Resume From %s.' % args.c)
    else:
        print('Training From scratch!')

    model.train()   # set training flag

    # 开始训练
    train_total = 0
    dev_acc_init = 0
    for epoch in range(start_epoch, HP.epochs):
        print('Start Epoch: %d, Steps: %d' % (epoch, len(train_loader)))
        for batch in train_loader:
            x, y = batch    # load data
            opt.zero_grad() # gradient clean
            pred = model(x) # forward process
            loss = criterion(pred, y.to(HP.device))   # loss calc
            train_total  = (torch.argmax(pred, 1) == y).sum()
            loss.backward() # backward process
            opt.step()
            train_acc = train_total / (len(train_loader) * HP.batch_size)
            logger.add_scalar('Loss/Train', loss, step)
            # 评估模型精度
            if not step % HP.verbose_step:  # evaluate log print
                eval_loss, evaL_acc = evaluate(model, dev_loader, criterion)
                logger.add_scalar('Loss/Dev', eval_loss, step)
            # 保存模型
            model_path = 'model.pth'
            if evaL_acc > dev_acc_init:
                save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))
                dev_acc_init = evaL_acc

            step  = 1
            logger.flush()
            print('Epoch: [%d/%d], step: %d Train Loss: %.5f, Train Acc: %.5f Dev Loss: %.5f, Eval Acc: %.5f'
                  % (epoch, HP.epochs, step, loss.item(), train_acc, eval_loss, evaL_acc))
        train_total = 0
    logger.close()


if __name__ == '__main__':
    train()

模型推理

代码语言:javascript复制
import torch
from torch.utils.data import DataLoader
from dataset_action import ActionDataset
from model import TransferLNet
from config import HP
import time

# change to cpu
HP.device = 'cpu'
# 创建一个迁移学习模型
model = TransferLNet()
# tenor -> device='cuda'/'cpu'
# 加载已训练的模型
checkpoint = torch.load('./model_save/model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# 测试数据集
testset = ActionDataset(HP.metadata_test_path)
test_loader = DataLoader(testset, batch_size=HP.batch_size, shuffle=False, drop_last=False)
model.eval()

total_cnt = 0
correct_cnt = 0

start_st = time.time()
with torch.no_grad():
    for batch in test_loader:
        x, y = batch
        pred = model(x)
        total_cnt  = pred.size(0)
        correct_cnt  = (torch.argmax(pred, 1) == y).sum()
print(time.time()-start_st)
print('Acc: %.3f' % (correct_cnt/total_cnt))

0 人点赞