TensorFlow实战–Chapter04单变量线性回归
使用tensorflow实现单变量回归模型
文章目录
- TensorFlow实战--Chapter04单变量线性回归
- 监督式机器学习的基本术语
- 标签和特征
- 训练
- 损失
- 定义损失函数
- 模型训练与降低损失
- 样本和模型
- 线性回归问题TensorFlow实战
- 人工数据生成
- 利用matplotlib绘图
- 定义模型
- 模型训练
- 创建会话,变量初始化
- 迭代训练
- 打印结果
- 可视化
- 进行预测
- 显示损失值
- 图形现显示话损失值
使用TensorFlow进行算法设计与训练的核心步骤
- 准备数据
- 构建模型
- 训练模型
- 进行预测
上述步骤是我们使用TensorFlow进行算法设计与训练的核心步骤,贯穿于具体实践中。
监督式机器学习的基本术语
标签和特征
训练
损失
定义损失函数
模型训练与降低损失
样本和模型
线性回归问题TensorFlow实战
人工数据生成
代码语言:javascript复制import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
代码语言:javascript复制np.random.seed(5)
代码语言:javascript复制使用numpy生成等差数列的方法,生成100个点,每个点的取值在[-1,1]之间
y = 2x 1 噪声,其中,噪声的维度与x_data一致
np.random.randn(x_data.shape)实参的前面加上和**,就意味着拆包,单个*号表示将元组拆成一个个单独的实参
x_data = np.linspace(-1, 1, 100)
y_data = 2 * x_data 1.0 np.random.randn(*x_data.shape) * 0.4
利用matplotlib绘图
代码语言:javascript复制plt.scatter(x_data, y_data)
plt.show()
代码语言:javascript复制plt.plot(x_data, 2 * x_data 1.0, color="red", linewidth=3)
plt.show()
定义模型
定义训练数据的占位符,x是特征值,y是标签值
代码语言:javascript复制x = tf.placeholder("float", name="x")
y = tf.placeholder("float", name="y")
定义模型函数
代码语言:javascript复制def model(x, w, b):
return tf.multiply(x, w) b
创建变量
- TensorFlow变量的声明函数是tf.Variable
- tf,Variable的作用是保存和更新参数
- 变量的初始值可以是随机数、常数,或是通过其他变量的初始值计算得到
# 构建线性函数的斜率,变量2
w = tf.Variable(1.0, name="w0")
# 构建线性函数的截距,变量b
b = tf.Variable(0.0, name="b0")
# pred是预测值,前向计算
pred = model(x, w, b)
模型训练
- 设置训练参数
train_epochs = 10 # 迭代次数
learning_rate = 0.05 # 学习率
- 损失函数
- 损失函数用于描述预测值和真实值之间的误差,从而指导模型的收敛方向
- 常见的损失函数:
- 均方差(Mean Square Error)
- 交叉熵(cross-entropy)
loss_function = tf.reduce_mean(tf.square(y - pred)) # 采用均方差
- 定义优化器 优化器Optimizer,初始化一个GradientDescentOPtimizer 设置学习率和优化目标:最小化损失
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
loss_function)
创建会话,变量初始化
- 在真正进行计算之气,需将所有变量初始化
- 通过==tf.global_variables_initializer()==函数可实现对所有变量的初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
迭代训练
模型训练阶段,设置迭代轮次,通过将样本逐个输入模型,进行梯度优化操作。每轮迭代后,绘制出模型曲线
代码语言:javascript复制for epoch in range(train_epochs):
for xs, ys in zip(x_data, y_data):
_, loss = sess.run([optimizer, loss_function],
feed_dict={
x: xs,
y: ys
})
b0temp = b.eval(session=sess)
w0temp = w.eval(session=sess)
plt.plot(x_data, w0temp * x_data b0temp)
打印结果
代码语言:javascript复制print("w:",sess.run(w))
print("b:",sess.run(b))
代码语言:javascript复制w: 1.9822965
b: 1.0420128
可视化
代码语言:javascript复制plt.scatter(x_data,y_data,label="original data")
plt.plot(x_data,x_data*sess.run(w) sess.run(b),label="fitter line",color="r",linewidth=3)
plt.legend(loc=2)
plt.show()
进行预测
代码语言:javascript复制x_test = 3.21
predict = sess.run(pred,feed_dict={x:x_test})
print(f"预测值:{predict}")
target = 3* x_test 1.0
print(f"目标值:{target}")
代码语言:javascript复制预测值:7.405184268951416
目标值:10.629999999999999
显示损失值
代码语言:javascript复制step = 0
loss_list = []
display_step = 10
for eopch in range(train_epochs):
for xs, ys in zip(x_data, y_data):
_, loss = sess.run([optimizer, loss_function],
feed_dict={
x: xs,
y: ys
})
loss_list.append(loss)
step = 1
if step % display_step == 0:
print(f"Train Epoch:{eopch 1}", f"Step:{step}", f"loss={loss}")
b0temp = b.eval(session=sess)
w0temp = w.eval(session=sess)
plt.plot(x_data,w0temp*x_data b0temp)
代码语言:javascript复制Train Epoch:1 Step:10 loss=0.03621993958950043
Train Epoch:1 Step:20 loss=0.08414854109287262
Train Epoch:1 Step:30 loss=0.00047292912495322526
Train Epoch:1 Step:40 loss=0.32646191120147705
Train Epoch:1 Step:50 loss=0.027518080547451973
Train Epoch:1 Step:60 loss=0.01023304183036089
Train Epoch:1 Step:70 loss=0.12734712660312653
Train Epoch:1 Step:80 loss=0.0010273485677316785
Train Epoch:1 Step:90 loss=0.10288805514574051
Train Epoch:1 Step:100 loss=0.048337195068597794
Train Epoch:2 Step:110 loss=0.03621993958950043
Train Epoch:2 Step:120 loss=0.08414854109287262
Train Epoch:2 Step:130 loss=0.00047292912495322526
Train Epoch:2 Step:140 loss=0.32646191120147705
Train Epoch:2 Step:150 loss=0.027518080547451973
Train Epoch:2 Step:160 loss=0.01023304183036089
Train Epoch:2 Step:170 loss=0.12734712660312653
Train Epoch:2 Step:180 loss=0.0010273485677316785
Train Epoch:2 Step:190 loss=0.10288805514574051
Train Epoch:2 Step:200 loss=0.048337195068597794
Train Epoch:3 Step:210 loss=0.03621993958950043
Train Epoch:3 Step:220 loss=0.08414854109287262
Train Epoch:3 Step:230 loss=0.00047292912495322526
Train Epoch:3 Step:240 loss=0.32646191120147705
Train Epoch:3 Step:250 loss=0.027518080547451973
Train Epoch:3 Step:260 loss=0.01023304183036089
Train Epoch:3 Step:270 loss=0.12734712660312653
Train Epoch:3 Step:280 loss=0.0010273485677316785
Train Epoch:3 Step:290 loss=0.10288805514574051
Train Epoch:3 Step:300 loss=0.048337195068597794
Train Epoch:4 Step:310 loss=0.03621993958950043
Train Epoch:4 Step:320 loss=0.08414854109287262
Train Epoch:4 Step:330 loss=0.00047292912495322526
Train Epoch:4 Step:340 loss=0.32646191120147705
Train Epoch:4 Step:350 loss=0.027518080547451973
Train Epoch:4 Step:360 loss=0.01023304183036089
Train Epoch:4 Step:370 loss=0.12734712660312653
Train Epoch:4 Step:380 loss=0.0010273485677316785
Train Epoch:4 Step:390 loss=0.10288805514574051
Train Epoch:4 Step:400 loss=0.048337195068597794
Train Epoch:5 Step:410 loss=0.03621993958950043
Train Epoch:5 Step:420 loss=0.08414854109287262
Train Epoch:5 Step:430 loss=0.00047292912495322526
Train Epoch:5 Step:440 loss=0.32646191120147705
Train Epoch:5 Step:450 loss=0.027518080547451973
Train Epoch:5 Step:460 loss=0.01023304183036089
Train Epoch:5 Step:470 loss=0.12734712660312653
Train Epoch:5 Step:480 loss=0.0010273485677316785
Train Epoch:5 Step:490 loss=0.10288805514574051
Train Epoch:5 Step:500 loss=0.048337195068597794
Train Epoch:6 Step:510 loss=0.03621993958950043
Train Epoch:6 Step:520 loss=0.08414854109287262
Train Epoch:6 Step:530 loss=0.00047292912495322526
Train Epoch:6 Step:540 loss=0.32646191120147705
Train Epoch:6 Step:550 loss=0.027518080547451973
Train Epoch:6 Step:560 loss=0.01023304183036089
Train Epoch:6 Step:570 loss=0.12734712660312653
Train Epoch:6 Step:580 loss=0.0010273485677316785
Train Epoch:6 Step:590 loss=0.10288805514574051
Train Epoch:6 Step:600 loss=0.048337195068597794
Train Epoch:7 Step:610 loss=0.03621993958950043
Train Epoch:7 Step:620 loss=0.08414854109287262
Train Epoch:7 Step:630 loss=0.00047292912495322526
Train Epoch:7 Step:640 loss=0.32646191120147705
Train Epoch:7 Step:650 loss=0.027518080547451973
Train Epoch:7 Step:660 loss=0.01023304183036089
Train Epoch:7 Step:670 loss=0.12734712660312653
Train Epoch:7 Step:680 loss=0.0010273485677316785
Train Epoch:7 Step:690 loss=0.10288805514574051
Train Epoch:7 Step:700 loss=0.048337195068597794
Train Epoch:8 Step:710 loss=0.03621993958950043
Train Epoch:8 Step:720 loss=0.08414854109287262
Train Epoch:8 Step:730 loss=0.00047292912495322526
Train Epoch:8 Step:740 loss=0.32646191120147705
Train Epoch:8 Step:750 loss=0.027518080547451973
Train Epoch:8 Step:760 loss=0.01023304183036089
Train Epoch:8 Step:770 loss=0.12734712660312653
Train Epoch:8 Step:780 loss=0.0010273485677316785
Train Epoch:8 Step:790 loss=0.10288805514574051
Train Epoch:8 Step:800 loss=0.048337195068597794
Train Epoch:9 Step:810 loss=0.03621993958950043
Train Epoch:9 Step:820 loss=0.08414854109287262
Train Epoch:9 Step:830 loss=0.00047292912495322526
Train Epoch:9 Step:840 loss=0.32646191120147705
Train Epoch:9 Step:850 loss=0.027518080547451973
Train Epoch:9 Step:860 loss=0.01023304183036089
Train Epoch:9 Step:870 loss=0.12734712660312653
Train Epoch:9 Step:880 loss=0.0010273485677316785
Train Epoch:9 Step:890 loss=0.10288805514574051
Train Epoch:9 Step:900 loss=0.048337195068597794
Train Epoch:10 Step:910 loss=0.03621993958950043
Train Epoch:10 Step:920 loss=0.08414854109287262
Train Epoch:10 Step:930 loss=0.00047292912495322526
Train Epoch:10 Step:940 loss=0.32646191120147705
Train Epoch:10 Step:950 loss=0.027518080547451973
Train Epoch:10 Step:960 loss=0.01023304183036089
Train Epoch:10 Step:970 loss=0.12734712660312653
Train Epoch:10 Step:980 loss=0.0010273485677316785
Train Epoch:10 Step:990 loss=0.10288805514574051
Train Epoch:10 Step:1000 loss=0.048337195068597794
图形现显示话损失值
代码语言:javascript复制plt.plot(loss_list)
代码语言:javascript复制plt.plot(loss_list,"r ")
到这里就结束了,如果对你有帮助,欢迎点赞关注评论,你的点赞对我很重要