Numpy简述神经网络模型权重搜索原理-Pytorch引文

2023-10-06 17:05:06 浏览数 (1)

Tensorflow的bug太多了,我只能转投Pytorch的怀抱

01

最近Tensorflow(下称TF)已死的言论不知道大家是否接收到:

放弃支持Windows GPU、bug多,TensorFlow被吐槽:2.0后慢慢死去 https://zhuanlan.zhihu.com/p/656241342

主要是谷歌放弃了在Windows上对TF的支持。对普通开发者而言,顶层信息其实并没有太大的波澜,随波逐流就是。

但是,如果我们嗅到一丝丝警觉而不管不顾的话,早晚要被抛弃!

所以,Pytorch(下称torch)还不得不信手拈来。同时,让我们顺带复习一下基本的求导、前馈、权重、Loss等词汇在深度学习里是怎么运作的吧:

正文开始:

学习torch之前,容我们思考一下,深度学习模型的学习思维和逻辑过程。假如,面对我们的是一个线性模型:Y=wX。那我们最关键的是学习(训练、调整)权重w的值。

02

以下代码能让我们直观的感受到w的粗略学习过程:

代码语言:javascript复制
import numpy as np
w_list=[]
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1 # Random value
def forward(x):
    return x*w
def loss(x, y):
    y_pred = forward(x)
    return (y_pred-y)*(y_pred-y)
for w in np.arange(0.0,4.1,0.1):
    print("w=", w)
    l_sum=0
    for x_val, y_val in zip (x_data, y_data):
        y_pred_val = forward(x_val)
        l = loss(x_val, y_val)
        l_sum =l
        print("t", x_val, y_val, y_pred_val, l)
        
    print("MSE=", l_sum/3)
    w_list.append(w)
    mse_list.append(l_sum/3)



import matplotlib.pyplot as plt
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

上述内容就是一个w的变化过程。从原始数据中我们可以简单判断出,w应该等于2。权重不断的在改变中经过了2,但并没有停止的意思。因此我们的模型并不能给我们最终结果为2。

03

由此,我们需要优化:

优化的过程需要涉及到求导,导数为0的时候就是我们线性函数的最优解(暂时)。

代码语言:javascript复制
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

w_list = []
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0

# Function for forward pass to predict y
def forward(x):
    return x*w
def loss(x,y):
    y_pred = forward(x)
    return (y_pred-y)**2
def gradient(x,y):
    return 2*x*(x*w-y)

# Training loop

print('Predict (before training)', 4, forward(4))

# Training loop

for epoch in range(100):
    l_sum=0
    for x_val, y_val in zip(x_data, y_data):
        grad = gradient(x_val, y_val)
        w = w-0.01*grad
        print('tgrad: ', x_val, y_val, grad)
        l=loss(x_val, y_val)
        l_sum =l
        
    print('Progress: ', epoch, 'w=', w, 'loss=', l)
    w_list.append(w)
    mse_list.append(l_sum/3)
    
    
print('Predict (After training)', '4 hours', forward(4))    

plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

那现在,我们就能顺利得到循环n次后的最优解w=2。

这就是这个学习过程的基本思路,但它其实并不需要涉及到torch,这是因为我们目前还没涉及到自动微分的过程。

04

torch其实就是集成了许多核心运算形式,方便我们调用。这点TF其实也是一样的。只不过在使用过程中,许多开发者发现TF版本兼容性较差,动不动就因为版本原因产生bug。解决bug的成本太高了,所以许多人才转投torch等其他开源框架。

下期,我们将重点描述torch的基本入门操作。

0 人点赞