“ Tensorflow的bug太多了,我只能转投Pytorch的怀抱”
01
—
最近Tensorflow(下称TF)已死的言论不知道大家是否接收到:
放弃支持Windows GPU、bug多,TensorFlow被吐槽:2.0后慢慢死去 https://zhuanlan.zhihu.com/p/656241342
主要是谷歌放弃了在Windows上对TF的支持。对普通开发者而言,顶层信息其实并没有太大的波澜,随波逐流就是。
但是,如果我们嗅到一丝丝警觉而不管不顾的话,早晚要被抛弃!
所以,Pytorch(下称torch)还不得不信手拈来。同时,让我们顺带复习一下基本的求导、前馈、权重、Loss等词汇在深度学习里是怎么运作的吧:
正文开始:
学习torch之前,容我们思考一下,深度学习模型的学习思维和逻辑过程。假如,面对我们的是一个线性模型:Y=wX。那我们最关键的是学习(训练、调整)权重w的值。
02
—
以下代码能让我们直观的感受到w的粗略学习过程:
代码语言:javascript复制import numpy as np
w_list=[]
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1 # Random value
def forward(x):
return x*w
def loss(x, y):
y_pred = forward(x)
return (y_pred-y)*(y_pred-y)
for w in np.arange(0.0,4.1,0.1):
print("w=", w)
l_sum=0
for x_val, y_val in zip (x_data, y_data):
y_pred_val = forward(x_val)
l = loss(x_val, y_val)
l_sum =l
print("t", x_val, y_val, y_pred_val, l)
print("MSE=", l_sum/3)
w_list.append(w)
mse_list.append(l_sum/3)
import matplotlib.pyplot as plt
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
上述内容就是一个w的变化过程。从原始数据中我们可以简单判断出,w应该等于2。权重不断的在改变中经过了2,但并没有停止的意思。因此我们的模型并不能给我们最终结果为2。
03
—
由此,我们需要优化:
优化的过程需要涉及到求导,导数为0的时候就是我们线性函数的最优解(暂时)。
代码语言:javascript复制import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
w_list = []
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
# Function for forward pass to predict y
def forward(x):
return x*w
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)**2
def gradient(x,y):
return 2*x*(x*w-y)
# Training loop
print('Predict (before training)', 4, forward(4))
# Training loop
for epoch in range(100):
l_sum=0
for x_val, y_val in zip(x_data, y_data):
grad = gradient(x_val, y_val)
w = w-0.01*grad
print('tgrad: ', x_val, y_val, grad)
l=loss(x_val, y_val)
l_sum =l
print('Progress: ', epoch, 'w=', w, 'loss=', l)
w_list.append(w)
mse_list.append(l_sum/3)
print('Predict (After training)', '4 hours', forward(4))
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
那现在,我们就能顺利得到循环n次后的最优解w=2。
这就是这个学习过程的基本思路,但它其实并不需要涉及到torch,这是因为我们目前还没涉及到自动微分的过程。
04
—
torch其实就是集成了许多核心运算形式,方便我们调用。这点TF其实也是一样的。只不过在使用过程中,许多开发者发现TF版本兼容性较差,动不动就因为版本原因产生bug。解决bug的成本太高了,所以许多人才转投torch等其他开源框架。
下期,我们将重点描述torch的基本入门操作。