日拱一卒,伯克利教你牛顿法,而我只想逃课……

2022-09-21 12:42:58 浏览数 (1)

作者 | 梁唐

出品 | 公众号:Coder梁(ID:Coder_LT)

大家好,日拱一卒,我是梁唐。

今天我们继续来看伯克利CS61A,我们来看作业5的最后一道附加题。这道题非常有意思,涉及很多知识,因此想要完整讲明白,需要很多篇幅,所以单独写了一篇。

这道题是一道数学题:给定我们一个区间[l, r],以及一个数组c,让我们找到c数组对应的方程在区间上的取值范围。

方程的表达式为:

f(x) = c[0] c[1]x c[2]x^2 cdots c[k-1]x^{k-1}

题目中给了我们一个提示,可以使用牛顿法来预估区间的极值点。

如果你没有听说过什么是牛顿法,先不要绝望,且听我一点点将它说清楚。

严格说起来这道题非常不严谨,因为既没有说传入的区间范围有多大,也没有对c数组的长度进行限制。在这种情况下,函数在区间上可能存在的极值点的数量是无法预估的,极端情况下可能会非常大。

所以我当时非常困惑也有一点点不满,因为在我看来牛顿法是解决不了这么大不确定程度的问题。

于是,在这种情况下,我想出了蒙特卡洛法。

蒙特卡洛

蒙特卡洛听起来非常高大上,但实际上思想非常朴素,就是基于大数定理,尽可能多地采样,通过多次采样的统计结果来逼近某个很难直接计算的值。

比如AlphaGo当中就用到了蒙特卡洛方法对一局围棋的某一个局面进行预估,围棋的局面是有优势还是劣势是很难估计的。即使围棋大师能够看出谁占上风,也没办法把赢面这个概念量化。所以AlphaGo中用蒙特卡洛方法来预估,即基于某一个盘面对弈数万甚至数十万次,统计获胜的概率,将这个概率当做赢面。

所以大家看到这里,应该都能理解蒙特卡洛是一种怎样简单粗暴的方法。

在这道题当中使用蒙特卡罗法,其实就是在区间当中采样若干个点,然后针对每一个点计算函数值,然后用采样得到的最大最小值来近似最终的解。在这题当中,给定的函数是连续的。对于连续函数来说,我们采样的

x'

距离真实的极值点

x

越接近,得到的

f(x')

自然也就越接近

f(x)

理论上来说,只要我们在区间上取的点足够多,我们就可以采样到距离极值点尽可能接近的点,来逼近真实的答案。代码也非常好写,使用numpy只需要几行就可以搞定。

我在代码当中采样了1000个点,可以得到距离真实结果误差小于小数点后6个0的近似值。

代码语言:javascript复制
def polynomial(x, c):
    """Return the interval that is the range of the polynomial defined by
    coefficients c, for domain interval x.

    >>> str_interval(polynomial(interval(0, 2), [-1, 3, -2]))
    '-3 to 0.125'
    >>> str_interval(polynomial(interval(1, 3), [1, -3, 2]))
    '0 to 10'
    >>> str_interval(polynomial(interval(0.5, 2.25), [10, 24, -6, -8, 3]))
    '18.0 to 23.0'
    """
    "*** YOUR CODE HERE ***"
    import numpy as np
    x = np.linspace(lower_bound(x), upper_bound(x), 1000)
    ret = np.zeros_like(x)
    for i in range(len(c)):
        ret  = c[i] * np.power(x, i)
    return interval(np.min(ret), np.max(ret))

如果采样更多的点,精度还可以进一步增加。

牛顿法

看完了逃课的方法,我们再来看看官方正解中用到的牛顿法。

牛顿法本身并不复杂,它是一种迭代求根的方法,以下图为例:

假设曲线是某方程的函数图像,我们要求它和的根。我们随意选择一个点x,它对应的函数值为f(x),它距离x轴的距离为-f(x)。我们在(x, f(x))处做切线与x轴交于

x_0

,切线的斜率即为f'(x)

因此,我们可以写出等式:

f'(x) = frac{-f(x)}{x_0 - x}

通分、移项之后得:

x_0 = x - f(x) / f'(x)

求出了

x_0

的坐标之后,我们接下来选择

x_0

再做切线计算导数进行同样的迭代过程,直到找到的x对应的

f(x)

无限逼近于0为止。

了解完了牛顿法,我们再回过来看看原问题。

我们要求函数在区间上的范围,其实就是找出函数在区间上的极值点,并且求出极值点处对应的函数值,再返回极值点和端点对应的函数值的最大值和最小值即可。

连续函数求极值也就是求一阶导数等于0的点,我们可以套用一下牛顿法,也就是使用牛顿法迭代求解一阶导数等于0处的点。

虽然说起来顺理成章,但这里面其实藏了一些细节,比如说如果一阶导数有可能没有根,那么使用牛顿法不论怎么迭代也不可能收敛的。所以这个时候我们就需要通过最大迭代次数来限制,当迭代次数到达一定次数之后, 无论是否收敛,都会退出执行。

我们先来看下牛顿迭代法的代码,这段代码是老师的课件里给的,不得不说写得真的是非常6.

代码语言:javascript复制
def improve(update, close, guess=1, max_updates=100):
    """Iteratively improve guess with update until close(guess) is true or
    max_updates have been applied."""
    k = 0
    while not close(guess) and k < max_updates:
        guess = update(guess)
        k = k   1
    return guess


def approx_eq(x, y, tolerance=1e-15):
    return abs(x - y) < tolerance


def find_zero(f, df, guess=1):
    """Return a zero of the function f with derivative df."""
    def near_zero(x):
        return approx_eq(f(x), 0)
    return improve(newton_update(f, df), near_zero, guess)


def newton_update(f, df):
    """Return an update function for f with derivative df,
    using Newton's method."""
    def update(x):
        return x - f(x) / df(x)
    return update

我们一个一个来说,先说最简单的approx_eq,这个函数用来判断浮点数x和y的差值是否小于tolerance

需要这样的操作是因为浮点数判断相等是非常麻烦的,比如1.49999999和1.5就不相等,可能1.5和1.50也不一定会相等。这个和浮点数的存储方式有关,因此如果直接判断两个浮点数相等非常可能出现问题。因此常用的方法是当两个浮点数差值的绝对值小于我们指定的某一个精度时,就认为它们相等。

再来看newton_update函数,这个函数接收函数f和它的一阶导函数df,返回一个函数updateupdate接收一个x,返回根据牛顿法计算出下一次的迭代位置。

再来看improve函数,这个函数是执行迭代的函数。它接收四个传参,分别是update函数,close函数,guessmax_updateupdate函数用来每次更新迭代的值,close函数用来判断迭代是否结束。guess是迭代的初始值,max_updates是最大迭代次数。

最后是find_zero函数,前面几个函数都看懂了,基本上这个函数也就没多大问题了。

理解了这些代码之后就可以给出我们自己的实现了,其实某种程度上来说也是一种蒙特卡洛算法。因为我们不知道可能存在多少个极值点,所以只能先假设比如说区间里最多会有20个或者是100个极值点。然后我们使用牛顿法对这些可能是极值点附近的点进行迭代,再从迭代之后的结果当中找到最大值和最小值。

核心思路和直接使用蒙特卡洛是一样的,只不过多了一层使用牛顿法迭代而已。由于我们要求的是一阶导数为0的点,所以还需要求出一阶导数的导数,也就是二阶导数。

代码如下:

代码语言:javascript复制
def polynomial(x, c):
    """Return the interval that is the range of the polynomial defined by
    coefficients c, for domain interval x.
    >>> str_interval(polynomial(interval(0, 2), [-1, 3, -2]))
    '-3 to 0.125'
    >>> str_interval(polynomial(interval(1, 3), [1, -3, 2]))
    '0 to 10'
    >>> str_interval(polynomial(interval(0.5, 2.25), [10, 24, -6, -8, 3]))
    '18.0 to 23.0'
    """
    # f函数本身
    f = lambda x: 0
    # 一阶导数
    df = lambda x: 0
    # 二阶导数
    ddf = lambda x: 0

    lower, upper = lower_bound(x), upper_bound(x)
    gap = (upper - lower) / 20
    # 均分区间,列举附近可能有极值点的位置
    guess = [lower   i * gap for i in range(20)]

    # 使用for循环,将函数中的k项拼接在一起
    def iter_f(cof, p, f):
        return lambda x: cof * pow(x, p)   f(x)

    def iter_df(cof, p, df):
        if p < 1:
            return df
        return lambda x: cof * p * pow(x, p-1)   df(x)

    def iter_ddf(cof, p, ddf):
        if p < 2:
            return ddf
        return lambda x: cof * p * (p-1) * pow(x, p-2)   ddf(x)

    for i, cof in enumerate(c):
        f = iter_f(cof, i, f)
        df = iter_df(cof, i, df)
        ddf = iter_ddf(cof, i, ddf)

    # 使用牛顿法进行迭代
    extremenums = [find_zero(df, ddf, g) for g in guess]
    extremenums = [g for g in extremenums if g > lower and g < upper]   [lower, upper]
    vals = [f(x) for x in extremenums]
    return interval(min(vals), max(vals))

其实写完之后回过头来看这些代码本身难度其实还好,并没有想象中那么难。计算导数以及牛顿法都是高等数学中的基础内容,但是拿到题目要能够想到这样的解法要难得多。

牛顿迭代法是一个非常牛的算法,但出现的频率实在是很低。所以,这也是熟悉和学习这个算法一个非常好的机会。

好了,关于这道题就聊到这里,感谢大家的阅读。

喜欢本文的话不要忘记三连~

max

0 人点赞