TensorFlow Basics: Softmax Classification


Softmax

Softmax is used for multi-class classification. Like logistic regression, it uses cross entropy as the loss function; the underlying theory is not repeated here.
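
As a quick reference, here is a minimal NumPy sketch of what softmax and cross entropy compute (the input numbers are purely illustrative):

import numpy as np

def softmax(logits):
    # subtract the max logit for numerical stability, then normalize
    e = np.exp(logits - np.max(logits))
    return e / np.sum(e)

def cross_entropy(y_one_hot, y_prob):
    # negative log-likelihood of the true class
    return -np.sum(y_one_hot * np.log(y_prob))

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)                           # ~[0.659 0.242 0.099], sums to 1
print(cross_entropy(np.array([1, 0, 0]), probs))  # ~0.417, the loss when the true class is A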

In addition, multi-class classification uses one-hot encoding to represent the classes. For example, objects of classes A, B, and C are represented as [1, 0, 0], [0, 1, 0], and [0, 0, 1] respectively; this representation is convenient for matrix computation.
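
Building a one-hot matrix amounts to indexing into an identity matrix, which is exactly what the convert_to_one_hot helper in the code below does. A quick sketch:

import numpy as np

labels = np.array([0, 1, 2])  # integer ids for classes A, B, C
one_hot = np.eye(3)[labels]   # row i of the 3x3 identity matrix encodes class i
print(one_hot)
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]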

TensorFlow Implementation

import tensorflow as tf
import numpy as np

def convert_to_one_hot(Y, C):
    # index into the identity matrix: row Y[i] is the one-hot row for class Y[i]
    Y = np.eye(C)[Y.reshape(-1)]
    return Y

# training data
x_data = [[1, 2, 1, 1], [2, 1, 3, 2], [3, 1, 3, 4], [4, 1, 5, 5], [1, 7, 5, 5], 
          [1, 2, 5, 6], [1, 6, 6, 6], [1, 7, 7, 7]]
y_data = [[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]]

X = tf.placeholder("float", [None, 4])
Y = tf.placeholder("float", [None, 3])

# number of classes
n_class = 3

# define trainable parameters (weights and bias)
W = tf.Variable(tf.random_normal([4, n_class]), name="weight")
b = tf.Variable(tf.random_normal([n_class]), name="bias")

# define hypothesis using the built-in softmax
# softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)

# cross entropy loss
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))

# specify optimizer method
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for step in range(2001):
        _, cost_val = sess.run([optimizer, cost], feed_dict={X: x_data, Y: y_data})
        if step % 200 == 0:
            print(step, cost_val)
    
    # test by making some prediction
    a = sess.run(hypothesis, feed_dict={X: [[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]})
    p = sess.run(tf.argmax(a, 1))

    p_one_hot = convert_to_one_hot(p, n_class)
    
    print()
    print(a)
    print()
    print(p)
    print()
    print(p_one_hot)

Output:
0 5.5841656
200 0.4481054
400 0.36178762
600 0.28431925
800 0.23740968
1000 0.21427636
1200 0.19528978
1400 0.17937106
1600 0.16581444
1800 0.15412535
2000 0.1439426

[[2.0620492e-03 9.9792922e-01 8.6969685e-06]
 [9.0453833e-01 8.4767073e-02 1.0694573e-02]
 [5.9199765e-09 2.7693674e-04 9.9972302e-01]]

[1 0 2]

[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]

Because this is a simple toy example, the loss decreases steadily throughout training.
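
One caveat: computing log(softmax(...)) by hand, as above, can be numerically unstable when the logits are large. A sketch of the same cost using TensorFlow 1.x's fused op, assuming the same X, Y, W, and b defined above:

logits = tf.matmul(X, W) + b
# the fused op applies softmax and cross entropy in one numerically stable step
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits))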
