TF-NN

2020-04-16 15:29:54 浏览数 (1)

TF-Neural Network

代码语言:python
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

# Build one fully connected (dense) layer.
def add_layter(inputs, in_size, out_size, activation_function=None):
    """Add a fully connected layer to the graph and return its output tensor.

    Args:
        inputs: Tensor fed into the layer; it is matrix-multiplied with a
            weight of shape [in_size, out_size].
        in_size: Number of input features (rows of the weight matrix).
        out_size: Number of units in this layer (columns of the weight matrix).
        activation_function: Optional callable applied to ``inputs @ W + b``.
            When ``None`` the layer is purely linear.

    Returns:
        The layer output: ``activation_function(inputs @ W + b)``, or the
        linear term itself when no activation is given.
    """
    # Weights: [in_size, out_size] matrix drawn from a normal distribution.
    weight = tf.Variable(tf.random_normal([in_size, out_size]))
    # Bias: one row of out_size values, started at 0.1 rather than 0.
    bias = tf.Variable(tf.zeros(shape=[1, out_size]) + 0.1)
    # Linear part of the layer; the [1, out_size] bias broadcasts over rows.
    wx_plus = tf.matmul(inputs, weight) + bias
    if activation_function is None:
        outputs = wx_plus
    else:
        # Apply the non-linearity to the linear term.
        outputs = activation_function(wx_plus)
    return outputs
# 3000 evenly spaced samples from -1 to 1, turned into a column vector.
x = np.linspace(-1, 1, 3000, dtype=np.float32)[:, np.newaxis]
# Gaussian noise (mean 0, std 0.05) with the same shape as x.
noise = np.random.normal(0, 0.05, x.shape).astype(np.float32)
# Training targets: a noisy parabola, y = x^2 + noise - 0.5.
y = np.square(x) + noise - 0.5

# Show the raw data as a scatter plot.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x, y)
plt.ion()  # interactive mode is switched on before the figure is shown
plt.show()

# Placeholders for the input feature and the regression target (one column each).
xx = tf.placeholder(tf.float32, [None, 1])
yy = tf.placeholder(tf.float32, [None, 1])


# Hidden layers: the number of layers and the number of units in each one.
l1 = add_layter(xx, 1, 10, activation_function=tf.nn.relu)
l2 = add_layter(l1, 10, 15, activation_function=tf.nn.relu)
l3 = add_layter(l2, 15, 10, activation_function=tf.nn.relu)

# Output layer: linear, fed by the LAST hidden layer so that l2 and l3 are
# actually part of the trained network.
outputs = add_layter(l3, 10, 1, activation_function=None)

# Loss for this regression problem: mean squared error. The squared error is
# summed over the output dimension, then averaged over the samples.
loss = tf.reduce_mean(tf.reduce_sum(tf.square(yy - outputs), axis=[1]))

# Training op: plain gradient descent on the loss.
train = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss=loss)

# Op that initializes every tf.Variable created above.
init = tf.global_variables_initializer()

# Run the graph.
with tf.Session() as sess:
    sess.run(init)
    lines = None  # fitted curve currently drawn on the axes, if any
    for i in range(1000):
        sess.run(train, feed_dict={xx: x, yy: y})
        if i % 50 == 0:
            print(sess.run(loss, feed_dict={xx: x, yy: y}))
            # Drop the previously drawn curve so only the latest fit is shown.
            if lines is not None:
                lines[0].remove()
            prediction = sess.run(outputs, feed_dict={xx: x})
            lines = ax.plot(x, prediction, 'r-', lw=5)
            plt.pause(0.5)

线性回归方程拟合动态图(matplotlib制作)

0 人点赞