编译 | Zhihong Deng
和这篇文章的作者一样,有时想到一个很棒的点子,辛苦写好代码,终于运行正常了,但是效果就是不咋地,不免懊恼地产生一种“难道我的点子不行?”的想法。真的是点子不行吗?未必,NN不work的原因有很多种,作者在这篇博客中根据自己的实践经验分享了很多宝贵的建议。
目录
0. 如何使用这份指引?
I. 与数据集相关的问题
II. 与数据规范化/增广相关的问题
III. 与实现相关的问题
IV. 与训练相关的问题
如何使用这份指引
出错的原因千千万,但其中某些因素是更容易发现和修改的,所以作者给出了一个短短的列表,列出出错时他最先用来自检的一些方法:
1. 从简单的能解决目标任务的模型(比如VGG之于图像分类)开始,并尽量使用标准的损失函数;
2. 去掉各种非必要的小trick,比如正则化和数据增广;
3. 如果是进行模型微调(finetuning),注意检查预处理是否和原模型的训练一致;
4. 检查输入数据是否是正确的;
5. 从一个小数据集(2~20个样本)开始,让模型能够过拟合它,逐步增加数据,观察结果;
6. 逐步修改,比如重新引入正则化和数据增广,使用自定义的损失函数,使用更复杂的模型...
如果以上方法都试了,还是不work,那工作量可能就得比较大了:
与数据集相关的问题
1. 检查输入数据 Check your input data
确认输入网络的数据是合理的。作者举了几个例子,有时候处理图像数据时会宽和高这两个维度混淆,有时候会误把全0输入给网络,或者一直使用同一个batch来训练一个网络。要避免这种错误只要把输入打印出来看一看就好了。
2. 尝试随机输入 Try random input
如果使用随机输入也能产生相同/接近的效果,那么很明显你的模型在某一步把输入数据变成了(不带任何信息的)垃圾。尝试逐层/逐运算来debug,判断是哪一步出错。
3. 检查数据加载器 Check the data loader
有时候你的输入数据没问题,但是传递数据的代码有问题,可以把输入网络的数据(在进行处理之前)打印出来检查一下。
4. 确保输入和标记是关联的 Make sure input is connected to output
要检查每条输入数据是否正确地和对应的标记构成一条训练样本。如果每个epoch有对训练样本打乱顺序,要确保打乱顺序后这种对应关系仍然是正确的。
5. 输入和标记之间的关系是否太过随机 Is the relationship between input and output too random?
输入和标记之间的关系太过随机,或者说不随机的部分太少,输入不足以和标记产生关联关系(模型学不到有用的信息),也是一个模型不work的原因,这是数据的本质决定的,没办法改变。
6. 是否有太多噪声数据 Is there too much noise in the dataset?
有时候并不是所有样本都是有用的,有些样本被标注了错误的类别,它们会对网络的训练造成很坏的影响,这只能通过手动检查样本来发现了。文中特别提到了 ICLR-15 的这篇论文,即使对 MNIST 数据集 50% 的标签进行了污染,仍然能取得 50% 以上的准确率。所以说什么量级才是 “太多” 是存在争议的,毕竟现实世界的数据就是有很多噪声了,模型应当在某种程度上具备handle噪声的能力。
7. 打乱数据集 Shuffle the dataset
如果你的数据集在训练时没有打乱,甚至说是按一个特定的方式排序的(比如按类标的大小),那么很可能会对模型的学习造成负面影响。打乱数据集可以避免这一点,同时也要注意第4点,打乱数据集要确保输入和标记的对应关系不变。
8. 减少类别不平衡 Reduce class imbalance
如果你的数据存在类别不平衡问题,比如A类的样本是B类的1000倍,那么你可能就需要重新设计你的损失函数去改善这个问题,也可以尝试别的一些方法。
9. 是否有足够的训练样本 Do you have enough training examples?
如果你要从头开始训练一个模型而不是做微调(finetuning),那么你很可能需要大量的数据才能让模型达到期望的效果。比如对于图像分类任务来说,一般认为每个类别至少需要1000张图片。
10. 确保 batch 中的样本不同属一个类别 Make sure your batches don’t contain a single label
这种情况对于有序的数据集很常见(比如前一万个样本都是同一类别的),最简单的解决方法就是在训练前打乱数据集的顺序。
11. 减少 batch size Reduce batch size
这篇论文指出使用太大的 batch size 会降低模型的泛化能力。
补充1. 使用标准数据集 Use standard dataset (e.g. mnist, cifar10)
这一条来自于网友 @hengcherkeng
在测试新的网络结构或者新代码时,先使用标准数据集而不是自己的数据来实验,因为这些标准数据集已被证明是可解决的,而且有可以对比的参考结果(之前的模型能做到什么地步)。标准数据集上不会有那么多的噪声数据,也不会有训练集/测试集分布相差太远的问题,数据集本身的难度也不过太高(现有工作能做得比较好)。
与数据规范化/增广相关的问题
12. 对特征进行标准化 Standardize the features
对输入特征进行标准化(0均值和单位方差)。
13. 是否进行了太多数据增广 Do you have too much data augmentation?
数据增广有一定的正则化作用,然而过多的数据增广再结合到其他形式的正则化手段(比如:L2正则,dropout 等等)会使得网络欠拟合。
14. 检查是否和预训练模型一致 Check the preprocessing of your pretrained model
如果你使用了预训练模型,那就要确保使用时要和预训练模型训练时的设置相同,比如预训练模型训练时,输入特征的数值范围是 [0, 1],那么使用它时,也要保证输入特征处于相同的数值范围(不要变成了 [-1, 1] 或者 [0, 255] 之类的)。
15. 检查对训练/验证/测试集的预处理是否正确 Check the preprocessing for train/validation/test set
CS231n 中指出了一个常见的错误:
“所有预处理用到的统计数据(比如数据的均值)都应该只在训练集上进行计算,然后再应用到验证集和测试集上。举个例子,一个CV新手在做预处理时,很可能会犯这样的错误:在整个数据集上计算图像数据的均值,然后让每张图片减去该均值之后再划分训练/验证/测试集。”
除此之外,还要检查对每个样本或者batch进行的多个预处理步骤是否都是正确的。
与实现相关的问题
16. 尝试解决简化版的问题 Try solving a simpler version of the problem
比方说要做目标检测,网络要同时输出目标的类别和坐标,那么可以先试试解决一个简化的问题——只预测目标的类别,这样做有助于找出问题所在。
17. 检查损失的数值是否正确 Look for correct loss “at chance”
这一个做法也是 CS231n 里面提到的:先用很小的值初始化参数,不采用任何正则化。假如这是一个有10个类的分类任务,那么初始化之后训练之前,每个样本预测正确的可能性为10%,如果用 softmax 损失(概率取负对数)的话就应该是 -ln(0.1),也即 2.302 左右。如果相差很大,说明网络本身的计算上可能有问题。在这一步之后,可以尝试加入正则化,如果网络正确无误的话应该能观察到损失变大了。
18. 检查损失函数 Check your loss function
如果你的损失函数是自己写的,那就要检查一下有没有bug,最好可以自己写个单元测试来检查。作者表示自己写的损失函数出错从而导致模型表现不佳是很常见的。
19. 检查损失函数的输入 Verify loss input
如果你的损失函数是由框架提供的,那就检查一下模型传递给损失函数的输入是否是正确的。比方说在 PyTorch 中,很容易混淆 NLLLoss 和 CrossEntropyLoss,前者要求输入是经过 softmax 计算出的概率分布,后者则不需要(内含softmax)。
20. 调节损失的权重 Adjust loss weights
如果你的损失函数是由多个损失函数组成的,那就要检查一下它们的权重是符合你的期望的,可以尝试一下不同的权重。
21. 观察其他指标 Monitor other metrics
损失有时候不是用来检查模型是否正确运行的最好指标,如果可以,不妨观察一下其他指标(比如准确率)是否正常。
22. 对自定义的层进行测试 Test any custom layers
如果模型中某些层是你自己实现的,那就需要着重检查这些层是否真的像你所期望的那样工作。
23. 检查冻结的层或者变量 Check for “frozen” layers or variables
如果使用预训练模型,有些层或者变量是不希望更新的,就会设置为冻结。但有时候会误把一些希望更新的层/变量也设置为冻结。需要检查一下。
24. 增加网络的容量 Increase network size
有时候效果不好也有可能是网络的容量不足以捕捉到需要的信息,不妨加入更多层或者使用更多隐层神经元试试。
25. 检查隐层的维度是否有错 Check for hidden dimension errors
如果你输入数据的维度是类似于 (k, H, W) = (64, 64, 64) 的样子的话,你很可能会忽略掉因为维度出错而引起的问题。可以使用古怪一点的数字作为输入的维度(比如使用几个素数),检查在前馈的过程中每一层的输入输出的维度是否都是正确的。
26. 梯度检查 Explore Gradient checking
如果你的梯度下降是自己写的,那梯度检查可以帮助你确定反向传播是否有正常工作。更多信息可以查看 1, 2, 3。
与训练相关的问题
27. 在一个极小的数据集上实验 Solve for a really small dataset
取数据集的一个非常小的子集来做实验,过拟合这个子数据集,保证模型在这个数据集上是能work的(如果连这么小的数据集都过拟合不了,那代码肯定出了问题)。比方说做分类,可以先构造一个只有2个样本的数据集,观察模型是否能够学会区分这2个样本,可以的话再每个类加入更多的样本。
28. 检查权重初始化 Check weights initialization
如果不是很确定怎样初始化最好,那么一般用 Xavier 或者 He initialization 就可以了。糟糕的初始化可能会使得模型的学习容易陷入局部最小,所以不妨试试使用不同的初始化方式是否能有帮助。
29. 改变超参数 Change your hyperparameters
效果不好也有可能是超参数导致的,如果条件允许,可以尝试对超参数进行 grid search。
30. 减少正则化 Reduce regularization
过多的正则化会让模型欠拟合。减少一些正则化手段,比如 dropout,batch norm 还有权重和偏置的 L2 正则化项等等。在 “Practical Deep Learning for coders” 这门课程中, Jeremy Howard 建议先解决欠拟合的问题,当你能充分地过拟合训练数据的时候再考虑如何解决过拟合。
31. 等 Give it time
不得不承认,有时候“等”也是一个办法。有时候你的模型就是需要更多的训练时间才能做出准确的预测。如果你的损失还在稳定地下降,那就让它再多训练一会儿吧~
32. 训练模式和测试模式之间的切换 Switch from Train to Test mode
有些层,比如 Batch Norm,Dropout 等等在训练和测试时进行的操作是不同的,要确保它们在训练的时候以训练模式工作,在测试的时候以测试模式工作。
33. 可视化训练过程 Visualize the training
- 检查激活值、权重和每一层的更新,确保它们的数值处于正常范围。比方说,参数(权重和偏置)的更新应该处于 1e-3 的量级。
- 使用 Tensorboard 或 Crayon 这样的可视化工具,必要的时候还可以打印出权重/偏置/激活值查看。
- 留意是否某些层的激活值要远大于0,尝试使用 Batch Norm 或者 ELUs。
- Deeplearning4j 指出了应该怎么去看权重和偏置的直方图:
“对于权重,一段时间后,直方图应该接近高斯(正态)分布;对于偏置,直方图应该从0开始,并最终接近高斯分布(LSTM除外)。留意那些发散到正无穷或者负无穷的参数,留意那些变得非常大的偏置,在类别不平衡的分类问题中常常会在输出层观察到这些现象”
- 检查每一层的更新,它们同样应该接近高斯分布。
34. 尝试不同的优化器 Try a different optimizer
优化器的选择本来不应该对模型效果有很大的影响,除非你选择的超参太过糟糕。但是,一个合适的优化器能让模型在更短时间内得到更好的训练。写论文时一般也会指出使用了什么优化器,如果没有的话,就用 Adam 或者带动量的SGD。
推荐阅读 Sebastian Ruder 写的这篇关于梯度下降优化器的超赞的博客。
35. 梯度爆炸/消失 Exploding / Vanishing gradients
- 检查每一层的更新,出现很大的数值通常意味着存在梯度爆炸问题,进行梯度裁剪或许有帮助。
- 检查每一层的激活值,Deeplearning4j 指出,激活值合理的标准差大概是0.5~2.0之间。如果超出该范围很多,那就意味着可能存在梯度爆炸或者梯度消失的问题了。
36. 增大/降低学习率 Increase/Decrease Learning Rate
学习率太小会令模型收敛非常慢,学习率太高可以让损失在一开始就下降得很快,但之后却难以找到一个好的解。不妨试试把你当前的学习率乘上10或者除以10,观察有什么变化。
37. 克服 NaNs Overcoming NaNs
在训练 RNNs 时,结果可能会变成 NaN(Non-a-Number)。有几种方法可以解决这个问题:
- 降低学习率,特别是在前100次迭代就得到了 NaNs 的时候;
- NaNs 也可能是因为除0操作/对0或负数取对数造成的,检查一下是否有这些问题;
- Russell Stewart 有很好的见解:如何处理 NaNs(这个网站好像没了?)。
- 逐层检查模型,看看是那个地方出现了 NaNs。
以上就是 NN 不 work 的时候可以尝试的37种做法,出错的原因有很多种,当然没办法指望这37种做法就能完全 cover,但按照我的经验,尝试从这个列表里查错还是挺有用的。之后有新的问题再作补充(当然,代码一遍就work,没有问题就最好啦)。
原文地址:blog.slavv.com/37-reasons-why-your-neural-network-is-not-working-4020854bd607