1、概述
solver算是caffe中比较核心的一个概念,在我们训练train我们的网络时,就必须要带上这个参数,
如下例是我要对Lenet进行训练的时候要调用的程序,现在不知道什么意思没关系,只需要知道这个solver.prototxt是个必不可少的东西就ok
./build/tools/caffe train--solver=examples/mnist/lenet_solver.prototxt
Solver通过协调Net的前向推断计算和反向梯度计算对参数进行更新,从而达到减小loss的目的
Solver的主要功能:
○ 设计好需要优化的对象,创建train网络和test网络。(通过调用另外一个配置文件prototxt来进行)
○ 通过forward和backward迭代的进行优化来更新新参数。
○ 定期的评价测试网络。 (可设定多少次训练后,进行一次测试)。
○ 在优化过程中记录模型和solver的状态的快照。
在每一次的迭代过程中,solver做了这几步工作:
1、调用forward算法来计算最终的输出值,以及对应的loss
2、调用backward算法来计算每层的梯度
3、根据选用的slover方法,利用梯度进行参数更新
4、根据学习率、历史数据、求解方法更新solver状态,使得权重从初始化状态逐步更新到最终的状态。
2、caffe.proto关于solver的描述
虽然内容很多,但是基本上是注释占的篇幅多,而且我也基本上都翻译成了中文注释,建议仔细阅读,这是一切solver的模版
代码语言:javascript复制// NOTE
// Update the next available ID when you add a new SolverParameter field.
// ## 注意,如果你要增加一个新的sovler参数,需要给它更新ID,就是下面内容中的数字
// SolverParameter next available ID: 41 (last added: type) ##下一个可用的ID是41,上一次caffe增加的是type参数,就是下文中的40
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
// ##指定训练和测试网络
// Exactly one train net must be specified using one of the following fields:
// train_net_param, train_net, net_param, net
// One or more test nets may be specified using any of the following fields:
// test_net_param, test_net, net_param, net
// If more than one test net field is specified (e.g., both net and
// test_net are specified), they will be evaluated in the field order given
// above: (1) test_net_param, (2) test_net, (3) net_param/net.
// A test_iter must be specified for each test_net.
// A test_level and/or a test_stage may also be specified for each test_net.
//////////////////////////////////////////////////////////////////////////////
// Proto filename for the train net, possibly combined with one or more
// test nets. ##这个训练网络的Proto文件名,可能结合一个或多个测试网络。
optional string net = 24;
// Inline train net param, possibly combined with one or more test nets. ## 对应的训练网络的参数,可能结合一个或多个测试网络
optional NetParameter net_param = 25;
optional string train_net = 1; // Proto filename for the train net. ## train net的proto文件名
repeated string test_net = 2; // Proto filenames for the test nets. ## test nets的proto文件名
optional NetParameter train_net_param = 21; // Inline train net params. ## 与上面train网络一致对应的参数
repeated NetParameter test_net_param = 22; // Inline test net params. ## 与上面test网络一致对应的参数
// The states for the train/test nets. Must be unspecified or
// specified once per net.
// ## train/test网络的状态。 必须是未指定或每个网络指定一次
// By default, all states will have solver = true; ##默认情况下,所有状态都将有solver = true;
// train_state will have phase = TRAIN, ##train_state会有phase = TRAIN,
// and all test_state's will have phase = TEST. ##所有的test_state都将进行phase = TEST
// Other defaults are set according to the NetState defaults. ##其他默认值是根据NetState默认设置的。
optional NetState train_state = 26;
repeated NetState test_state = 27;
// The number of iterations for each test net. ## test网络的迭代次数:
repeated int32 test_iter = 3;
// The number of iterations between two testing phases.
// ## 两次test之间(train)的迭代次数
//## <训练test_interval个批次,再测试test_iter个批次,为一个回合(epoch), 合理设置应使得每个回合内,遍历覆盖到全部训练样本和测试样本 >
optional int32 test_interval = 4 [default = 0];
optional bool test_compute_loss = 19 [default = false]; // ## 默认不计算测试时损失
// If true, run an initial test pass before the first iteration,
// ensuring memory availability and printing the starting value of the loss.
// ##如设置为真,则在训练前运行一次测试,以确保内存足够,并打印初始损失值
optional bool test_initialization = 32 [default = true];
optional float base_lr = 5; // The base learning rate ##基本学习速率
// the number of iterations between displaying info. If display = 0, no info
// will be displayed. ##打印信息的遍历间隔,遍历多少个批次打印一次信息。设置为0则不打印。
optional int32 display = 6;
// Display the loss averaged over the last average_loss iterations ## 打印最后一个迭代批次下的平均损失
optional int32 average_loss = 33 [default = 1];
optional int32 max_iter = 7; // the maximum number of iterations ##train的最大迭代次数
// accumulate gradients over `iter_size` x `batch_size` instances
// ## 累积梯度误差基于“iter_size×batchSize”个样本实例,< “批次数×批量数”=“遍历的批次数×每批的样本数”个样本实例 >
optional int32 iter_size = 36 [default = 1];
// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
//##学习率衰退策略.目前实行的学习率策略如下:
// - fixed: always return base_lr. ##保持base_lr不变.
// - step: return base_lr * gamma ^ (floor(iter / step)) ##返回 base_lr * gamma ^(floor(iter / stepsize)),
// - exp: return base_lr * gamma ^ iter ##返回base_lr * gamma ^ iter, iter为当前迭代次数
// - inv: return base_lr * (1 gamma * iter) ^ (- power) ##如果设置为inv,还需设置一个power,返回return 后的内容
// - multistep: similar to step but it allows non uniform steps defined by ##这个参数和step很相似,还需要设置一个stepvalue。
// stepvalue ##但step是均匀等间隔变化,而此参数根据stepvalue变化
// - poly: the effective learning rate follows a polynomial decay, to be ##学习率进行多项式衰减,由max_iter变为0
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) , ##返回 base_lr (1- iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay ##学习率进行sigmod衰减,
// return base_lr ( 1/(1 exp(-gamma * (iter - stepsize)))) ##返回return 后的内容
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
// ## 在上述参数中,base_lr, max_iter, gamma, step, stepvalue and power 被定义
// 在solver.prototxt文件中,iter是当前迭代次数。
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value. ## 动量
optional float weight_decay = 12; // The weight decay. ##权值衰减系数
// regularization types supported: L1 and L2
// controlled by weight_decay
// ## 由权值衰减系数所控制的正则化类型:L1或L2范数,默认L2
optional string regularization_type = 29 [default = "L2"];
// the stepsize for learning rate policy "step" ##"step"策略下,学习率的步长值
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep" ## "multistep"策略下的步长值
repeated int32 stepvalue = 34;
// Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
// whenever their actual L2 norm is larger.
optional float clip_gradients = 35 [default = -1];
optional int32 snapshot = 14 [default = 0]; // The snapshot interval ##快照间隔<遍历多少次对模型和求解器状态保存一次>
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
// ## 是否对diff快照,有助调试,但最终的protocol buffer尺寸会很大
optional bool snapshot_diff = 16 [default = false];
// ## 快照数据保存格式{ hdf5,binaryproto(默认) }
enum SnapshotFormat {
HDF5 = 0;
BINARYPROTO = 1;
}
optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. ##选CPU或GPU模式,默认是GPU
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [default = GPU];
// the device_id will that be used in GPU mode. Use device_id = 0 in default. ##如果选了GPU模式,此参数指定哪个GPU,默认是0号GPU
optional int32 device_id = 18 [default = 0];
// If non-negative, the seed with which the Solver will initialize the Caffe
// random number generator -- useful for reproducible results. Otherwise,
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1];
// type of the solver ## 求解器类型=SGD(默认),目前一共有6种
optional string type = 40 [default = "SGD"];
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
//## RMSProp,AdaGrad和AdaDelta和Adam的数值稳定性
optional float delta = 31 [default = 1e-8];
// parameters for the Adam solver ## Adam类型时的参数
optional float momentum2 = 39 [default = 0.999];
// RMSProp decay value ##RMSProp的衰减值
// MeanSquare(t) = rms_decay*MeanSquare(t-1) (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38;
// If true, print information about the state of the net that may help with
// debugging learning problems.
//## 此参数默认为false,若为true,则打印网络状态信息,有助于调试问题
optional bool debug_info = 23 [default = false];
// If false, don't save a snapshot after training finishes.
//## 此参数默认为true,若为false,则不会在训练后保存快照
optional bool snapshot_after_train = 28 [default = true];
// DEPRECATED: old solver enum types, use string instead ##已经弃用,本来表示6种sovler类型,现在用string type中的string代替
enum SolverType {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
ADAM = 5;
}
// DEPRECATED: use type instead of solver_type ##已经弃用,用string type中的type代替
optional SolverType solver_type = 30 [default = SGD];
}
3、举例说明
我仍以caffe/examples/mnist/lenet_solver.prototxt这个文件为例,下图是我的截图
我把上图的内容复制过来看的清楚一些,并把注释翻译了一下:
---------------这一部分可以对照着2中proto中的描述看,你会发现其实solver的编写也就是对着模版填参数的一个过程,----------
# 我们需要的Net的模型,这个模型定义在另一个prototxt文件中,这个就是我上一篇博文举的Net的例子
# 显然这里根据需要你可以选择其他的一些Net
net: "examples/mnist/lenet_train_test.prototxt"
#test_iter 设置了test一共迭代多少次,这里是100
# 至于test每一次迭代处理多少张图片,在Net那个prototxt里面batch_size规定了的
test_iter: 100
# 训练每迭代500次,测试一次(这每一次测试要迭代100次).
test_interval:500
#设置学习率。base_lr用于设置基础学习率,在迭代的过程中,可以对基础学习率进行调整。怎么样进行调整,就是调整的策略,由lr_policy来设置。
#momentum称为动量,使得权重更新更为平缓
#weight_decay称为衰减率因子,防止过拟合的一个参数
base_lr: 0.01
momentum: 0.9
# 这里省略了一个内容 type: SGD ,就是solver方法的选择,因为默认就是SGD,所以这个 solver. prototxt 中省略没写, 如果你想用其他的sovler方法就要指明写出来
weight_decay:0.0005
# 学习率调整的策略,详细见我下面的补充
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# train每迭代100次就显示一次
display: 100
#train最大迭代次数
max_iter: 10000
#快照。将训练出来的model和solver状态进行保存,snapshot用于设置训练多少次后进行保存,默认为0,不保存。snapshot_prefix设置保存路径。
还可以设置snapshot_diff,是否保存梯度值,默认为false,不保存。
也可以设置snapshot_format,保存的类型。有两种选择:HDF5和BINARYPROTO,默认为BINARYPROTO
#这里设置train每迭代5000次就存储一次数据
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
#设置运行模式,默认为GPU,如果你没有GPU,则需要改成CPU,否则会出错.
solver_mode: GPU
4、solver方法
Solver方法就是计算最小化损失值(loss)的方法,也就是我上面解析中说的省略掉的一行,其实一共有6种sovler方法:
· Stochastic Gradient Descent (type: "SGD"),
· AdaDelta (type: "AdaDelta"),
· Adaptive Gradient (type: "AdaGrad"),
· Adam (type: "Adam"),
· Nesterov’s Accelerated Gradient (type: "Nesterov") and
· RMSprop (type: "RMSProp")
默认设置的是SGD 随机梯度下降,所以就可以不写,但是如果想用其他的,就必须要写出来,比如type:Adam
这个方法对于我这种小白来说暂时没有研究的必要,而且SGD方法的数学原理至少我是知道的,所以我这里就只把这几种方法列出来了,没有详细解读,如果有兴趣可以参考下面这篇博客:
http://www.cnblogs.com/denny402/p/5074212.html,
这篇关于sovler讲解的博文就写完了,下面一篇就准备来将一下caffe中的hello world--使用Lenet来识别mnist手写数据。