作者 | Aymeric Damien
编辑 | 奇予纪
出品 | 磐创AI团队
线性回归示例:
本示例使用TensorFlow v2库实现线性回归,此示例使用简单方法来更好地理解训练过程背后的所有机制。
代码语言:javascript复制from __future__ import absolute_import, division, print_function
import tensorflow as tf
import numpy as np
rng = np.random
代码语言:javascript复制# 参数
learning_rate = 0.01
training_steps = 1000
display_step = 50
代码语言:javascript复制# 训练数据
X = np.array([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
7.042,10.791,5.313,7.997,5.654,9.27,3.1])
Y = np.array([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
2.827,3.465,1.65,2.904,2.42,2.94,1.3])
n_samples = X.shape[0]
代码语言:javascript复制# 随机初始化权重,偏置
W = tf.Variable(rng.randn(),name="weight")
b = tf.Variable(rng.randn(),name="bias")
# 线性回归(Wx b)
def linear_regression(x):
return W * x b
# 均方差
def mean_square(y_pred,y_true):
return tf.reduce_sum(tf.pow(y_pred-y_true,2)) / (2 * n_samples)
# 随机梯度下降优化器
optimizer = tf.optimizers.SGD(learning_rate)
代码语言:javascript复制# 优化过程
def run_optimization():
# 将计算封装在GradientTape中以实现自动微分
with tf.GradientTape() as g:
pred = linear_regression(X)
loss = mean_square(pred,Y)
# 计算梯度
gradients = g.gradient(loss,[W,b])
# 按gradients更新 W 和 b
optimizer.apply_gradients(zip(gradients,[W,b]))
代码语言:javascript复制# 针对给定训练步骤数开始训练
for step in range(1,training_steps 1):
# 运行优化以更新W和b值
run_optimization()
if step % display_step == 0:
pred = linear_regression(X)
loss = mean_square(pred, Y)
print("step: %i, loss: %f, W: %f, b: %f" % (step, loss, W.numpy(), b.numpy()))
output:
代码语言:javascript复制step: 50, loss: 0.210631, W: 0.458940, b: -0.670898
step: 100, loss: 0.195340, W: 0.446725, b: -0.584301
step: 150, loss: 0.181797, W: 0.435230, b: -0.502807
step: 200, loss: 0.169803, W: 0.424413, b: -0.426115
step: 250, loss: 0.159181, W: 0.414232, b: -0.353942
step: 300, loss: 0.149774, W: 0.404652, b: -0.286021
step: 350, loss: 0.141443, W: 0.395636, b: -0.222102
step: 400, loss: 0.134064, W: 0.387151, b: -0.161949
step: 450, loss: 0.127530, W: 0.379167, b: -0.105341
step: 500, loss: 0.121742, W: 0.371652, b: -0.052068
step: 550, loss: 0.116617, W: 0.364581, b: -0.001933
step: 600, loss: 0.112078, W: 0.357926, b: 0.045247
step: 650, loss: 0.108058, W: 0.351663, b: 0.089647
step: 700, loss: 0.104498, W: 0.345769, b: 0.131431
step: 750, loss: 0.101345, W: 0.340223, b: 0.170753
step: 800, loss: 0.098552, W: 0.335003, b: 0.207759
step: 850, loss: 0.096079, W: 0.330091, b: 0.242583
step: 900, loss: 0.093889, W: 0.325468, b: 0.275356
step: 950, loss: 0.091949, W: 0.321118, b: 0.306198
step: 1000, loss: 0.090231, W: 0.317024, b: 0.335223
代码语言:javascript复制import matplotlib.pyplot as plt
# 绘制图
plt.plot(X, Y, 'ro', label='Original data')
plt.plot(X, np.array(W * X b), label='Fitted line')
plt.legend()
plt.show()
output:
mark