机器学习从零开始系列连载(10)——最优化原理(下)

2020-02-29 17:09:06 浏览数 (1)

并行SGD

SGD相对简单并且被证明有较好的收敛性质和精度,所以自然而然就想到将其扩展到大规模数据集上,就像Hadoop/Spark的基本框架是MapReduce,并行机器学习的常见框架有两种:AllReduce 和 Parameter Server(PS)。

AllReduce

AllReduce的思想来源于MPI,它可以被看做Reduce操作 Broadcast操作,例如:

From MPI Tutorials

其他AllReduce的拓扑结构如下:

From Huasha Zhao & John Canny

非常好的开源实现有John Langfordvowpal wabbit陈天奇Rabit(轻量级、可容错)。并行计算的关键之一是如何在大规模数据集下计算目标函数的梯度值,AllReduce框架很适合这种任务,比如:vw通过构建一个二叉树来管理机器节点,其中一个节点会被当做master,其他节点作为slave,master管理着slave并定期接受它们的心跳,每个子节点的计算结果会被其父节点收集,到达根节点后累加并广播到其所有子节点,一个简单的例子如下:

使用mini-batch的并行SGD算法伪代码如下:

参数服务器(Parameter Server)

参数服务器强调模型训练时参数的并行异步更新,最早是由Google的Jeffrey Dean团队提出,为了解决深度学习的参数学习问题,其基本思想是:将数据集划分为若干子数据集,每个子数据集所在的节点都运行着一个模型的副本,通过独立部署的参数服务器组织模型的所有权重,其基本操作有:Fatching:每隔n次迭代,从参数服务器获取参数权重,Pushing:每隔m次迭代,向参数服务器推送本地梯度更新值,之后参数服务器会更新相关参数权重,其基本架构如下:

From Jeffrey Dean: Large Scale Distributed Deep Networks

每个模型的副本都是,为减少通信开销,每个模型副本在迭代次后向参数服务器请求参数跟新,反过来本地模型每迭代次后向参数服务器推送一次梯度更新值,当然,为了折中速度和效果,梯度的更新可以选择异步也可以是同。参数服务器是一个非常好的机器学习框架,尤其在深度学习的应用场景中,有篇不错的文章: 参数服务器——分布式机器学习的新杀器。开源的实现中比较好的是bosen项目和李沐ps-lite(现已集成到DMLC项目中)。下面是一个Go语言实现的多线程版本的参数服务器(用于Ftrl算法的优化),源码位置:Goline

代码语言:javascript复制
// data structure of ftrl solver.
type FtrlSolver struct {
    Alpha   float64 `json:"Alpha"`
    Beta    float64 `json:"Beta"`
    L1      float64 `json:"L1"`
    L2      float64 `json:"L2"`
    Featnum int     `json:"Featnum"`
    Dropout float64 `json:"Dropout"`
    N []float64 `json:"N"`
    Z []float64 `json:"Z"`
    Weights util.Pvector `json:"Weights"`
    Init bool `json:"Init"`
}
// data structure of parameter server.
type FtrlParamServer struct {
    FtrlSolver
    ParamGroupNum int
    LockSlots     []sync.Mutex
    log           log4go.Logger
}
// fetch parameter group for update n and z value.
func (fps *FtrlParamServer) FetchParamGroup(n []float64, z []float64, group int) error {
    if !fps.FtrlSolver.Init {
        fps.log.Error("[FtrlParamServer-FetchParamGroup] Initialize fast ftrl solver error.")
        return errors.New("[FtrlParamServer-FetchParamGroup] Initialize fast ftrl solver error.")
    }
    var start int = group * ParamGroupSize
    var end int = util.MinInt((group 1)*ParamGroupSize, fps.FtrlSolver.Featnum)
    fps.LockSlots[group].Lock()
    for i := start; i < end; i   {
        n[i] = fps.FtrlSolver.N[i]
        z[i] = fps.FtrlSolver.Z[i]
    }
    fps.LockSlots[group].Unlock()
    return nil
}
// fetch parameter from server.
func (fps *FtrlParamServer) FetchParam(n []float64, z []float64) error {
    if !fps.FtrlSolver.Init {
        fps.log.Error("[FtrlParamServer-FetchParam] Initialize fast ftrl solver error.")
        return errors.New("[FtrlParamServer-FetchParam] Initialize fast ftrl solver error.")
    }
    for i := 0; i < fps.ParamGroupNum; i   {
        err := fps.FetchParamGroup(n, z, i)
        if err != nil {
            fps.log.Error(fmt.Sprintf("[FtrlParamServer-FetchParam] Initialize fast ftrl solver error.", err.Error()))
            return errors.New(fmt.Sprintf("[FtrlParamServer-FetchParam] Initialize fast ftrl solver error.", err.Error()))
        }
    }
    return nil
}
// push parameter group for upload n and z value.
func (fps *FtrlParamServer) PushParamGroup(n []float64, z []float64, group int) error {
    if !fps.FtrlSolver.Init {
        fps.log.Error("[FtrlParamServer-PushParamGroup] Initialize fast ftrl solver error.")
        return errors.New("[FtrlParamServer-PushParamGroup] Initialize fast ftrl solver error.")
    }
    var start int = group * ParamGroupSize
    var end int = util.MinInt((group 1)*ParamGroupSize, fps.FtrlSolver.Featnum)
    fps.LockSlots[group].Lock()
    for i := start; i < end; i   {
        fps.FtrlSolver.N[i]  = n[i]
        fps.FtrlSolver.Z[i]  = z[i]
        n[i] = 0
        z[i] = 0
    }
    fps.LockSlots[group].Unlock()
    return nil
}
// push weight update to parameter server.
func (fw *FtrlWorker) PushParam(param_server *FtrlParamServer) error {
    if !fw.FtrlSolver.Init {
        fw.log.Error("[FtrlWorker-PushParam] Initialize fast ftrl solver error.")
        return errors.New("[FtrlWorker-PushParam] Initialize fast ftrl solver error.")
    }
    for i := 0; i < fw.ParamGroupNum; i   {
        err := param_server.PushParamGroup(fw.NUpdate, fw.ZUpdate, i)
        if err != nil {
            fw.log.Error(fmt.Sprintf("[FtrlWorker-PushParam] Initialize fast ftrl solver error.", err.Error()))
            return errors.New(fmt.Sprintf("[FtrlWorker-PushParam] Initialize fast ftrl solver error.", err.Error()))
        }
    }
    return nil
}
// to do update for all weights.
func (fw *FtrlWorker) Update(
    x util.Pvector,
    y float64,
    param_server *FtrlParamServer) float64 {
    if !fw.FtrlSolver.Init {
        return 0.
    }
    var weights util.Pvector = make(util.Pvector, fw.FtrlSolver.Featnum)
    var gradients []float64 = make([]float64, fw.FtrlSolver.Featnum)
    var wTx float64 = 0.
    for i := 0; i < len(x); i   {
        item := x[i]
        if util.UtilGreater(fw.FtrlSolver.Dropout, 0.0) {
            rand_prob := util.UniformDistribution()
            if rand_prob < fw.FtrlSolver.Dropout {
                continue
            }
        }
        var idx int = item.Index
        if idx >= fw.FtrlSolver.Featnum {
            continue
        }
        var val float64 = fw.FtrlSolver.GetWeight(idx)
        weights = append(weights, util.Pair{idx, val})
        gradients = append(gradients, item.Value)
        wTx  = val * item.Value
    }
    var pred float64 = util.Sigmoid(wTx)
    var grad float64 = pred - y
    util.VectorMultiplies(gradients, grad)
    for k := 0; k < len(weights); k   {
        var i int = weights[k].Index
        var g int = i / ParamGroupSize
        if fw.ParamGroupStep[g]%fw.FetchStep == 0 {
            param_server.FetchParamGroup(
                fw.FtrlSolver.N,
                fw.FtrlSolver.Z,
                g)
        }
        var w_i float64 = weights[k].Value
        var grad_i float64 = gradients[k]
        var sigma float64 = (math.Sqrt(fw.FtrlSolver.N[i] grad_i*grad_i) - math.Sqrt(fw.FtrlSolver.N[i])) / fw.FtrlSolver.Alpha
        fw.FtrlSolver.Z[i]  = grad_i - sigma*w_i
        fw.FtrlSolver.N[i]  = grad_i * grad_i
        fw.ZUpdate[i]  = grad_i - sigma*w_i
        fw.NUpdate[i]  = grad_i * grad_i
        if fw.ParamGroupStep[g]%fw.PushStep == 0 {
            param_server.PushParamGroup(fw.NUpdate, fw.ZUpdate, g)
        }
        fw.ParamGroupStep[g]  = 1
    }
    return pred
}

二阶优化方法

概览

大部分的优化算法都是基于梯度的迭代方法,其迭代式来源为泰勒展开式,迭代的一般式为:

其中步长,向量 被称作搜索方向,它一般要求是一个能使目标函数值(最小化问题)下降的方向,即满足:

进一步说, 的通项式有以下形式:

是一个对称非奇异矩阵(大家请问为什么?)。

这类优化方法大体分两种,要么是先确定优化方向后确定步长(line search),要么是先确定步长后确定优化方向(trust region)。

以常用的line search为例,如何找到较好的步长X呢?好的步长它需要满足以下条件:

牛顿法(Newton Method)

从泰勒展开式得到牛顿法的基本迭代式:

对牛顿法的改进之一是使用自适应步长X:

但总的来说牛顿法由于需要求解Hessian 矩阵,所以计算代价过大,对问题规模较大的优化问题力不从心。

拟牛顿法(Quasi-Newton Method)

为解决Hessian 矩阵计算代价的问题,想到通过一阶信息去估计它的办法,于是涌现出一类方法,其中最有代表性的是DFP和BFGS(L-BFGS),其原理如下:

思考一个问题:为什么通常二阶优化方法收敛速度快于一阶方法?

1.机器学习原来这么有趣!【第一章】

2.机器学习原来这么有趣!【第二章】:用机器学习制作超级马里奥的关卡

3.机器学习从零开始系列连载(1)——基本概念

4.机器学习从零开始系列连载(2)——线性回归

5.机器学习从零开始系列连载(3)——支持向量机

6.机器学习从零开始系列连载(4)——逻辑回归

7.机器学习从零开始系列连载(5)——Bagging and Boosting框架

0 人点赞