介绍
这篇文章里,我们使用python numpy实现一个线性分类器,使用mnist的数据集对线性分类器进行训练与预测。文章会详细介绍线性分类器的实现细节包括,前向传播,反向传播实现。
测试数据
有很多方法加载mnist数据集,我们这里使用sklearn库提供的load_digits函数,下载mnist数据集,该函数会在当前目录下建立一个MNIST目录,数据都下载到该目录下面:
下面几行代码是mnist数据下载和查看:
代码语言:javascript复制digits = load_digits()
print(dir(digits))
print(digits.data.shape)
print("feature name: ",digits.feature_names[0])
print(digits.frame)
print("data: ",digits.images[0])
print(digits.images[0].shape)
print("target: ",digits.target[0])
print("data: ", digits.data[0])
print("data type: ", type(digits.data[0]))
sample_index = 0
plt.figure(figsize=(3, 3))
plt.imshow(digits.images[sample_index], cmap=plt.cm.gray_r,
interpolation='nearest')
plt.title("image label: %d" % digits.target[sample_index]);
结果:
代码语言:javascript复制['DESCR', 'data', 'feature_names', 'frame', 'images', 'target', 'target_names']
(1797, 64)
feature name: pixel_0_0
None
data: [[ 0. 0. 5. 13. 9. 1. 0. 0.]
[ 0. 0. 13. 15. 10. 15. 5. 0.]
[ 0. 3. 15. 2. 0. 11. 8. 0.]
[ 0. 4. 12. 0. 0. 8. 8. 0.]
[ 0. 5. 8. 0. 0. 9. 8. 0.]
[ 0. 4. 11. 0. 1. 12. 7. 0.]
[ 0. 2. 14. 5. 10. 12. 0. 0.]
[ 0. 0. 6. 13. 10. 0. 0. 0.]]
(8, 8)
target: 0
data: [ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3.
15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0.
0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12.
0. 0. 0. 0. 6. 13. 10. 0. 0. 0.]
data type: <class 'numpy.ndarray'>
线性分类器结构简介:
这里实现的线性分类器很简单,首先将输入拉平为一个向量,我们使用一个权重矩阵乘以该输入得到输出向量,使用softmax得到 不同类别的分数,最终挑选分数最大的类别作为当前输入所属类别的预测结果。
如上图,我们这里输出为10分类(0-9),输入为长度为64的向量。则权重矩阵的维度为 64 * 10(或者10 * 64,取决于是权重左乘 输入还是输入左乘权重),得到10维的输出向量后,使用softmax以下公式,计算输入对于每个分类的得分(也可以理解为属于该分类的概率),softmax这里会将所有在上一步得到的值缩放到大于0 的范围,然后给每个分类计算一个0-1的值,所有分类的值总和为1
前向传播
有个上述的分类器结构,我们的前向传播就很好理解了
代码语言:javascript复制input_size = 64
output_size = 10
W = np.random.uniform(size=(input_size, output_size), high=0.1, low=-0.1)
B = np.random.uniform(size=output_size, high=0.1, low=-0.1)
def softmax(X):
expV = np.exp(X)
return expV / np.sum(expV, axis=-1, keepdims=True)
def forward(X):
z = np.dot(X, W) B
return softmax(z)
def predict(X):
tmp = forward(X)
#print(tmp)
#print("shape of x: ", X.shape)
if len(X.shape) == 1:
return np.argmax(tmp)
else:
return np.argmax(tmp, axis=1)
主要三个函数:
forward:对于输入X,首先通过dot 方法左乘 权重矩阵W ,之后 加上偏置量 B 得到输出z
softmax: z经过softmax得到最终的每个类别的预测分数,根据上面公式,首先使用np.exp计算以e为底的指数,之后对每个指数求该指数与所有指数求和结果值的分数(0-1)作为输出。这里的axis=-1是指按照每行为一组,计算一组内的sum。这样softmax函数就可以一次处理一批数据(X为多行)
predict实际就是调用forward前向传播,之后求对应axis轴上取值最大的下标,作为对应分类
反向传播
反向传播的核心是按照梯度下降的方向进行权重更新。我们这里损失函数选择为交叉熵损失函数,关于所以训练阶段softmax输出需要代入如下交叉熵损失公式计算loss
这里yc是真是标签等于1或者0, pc是softmax输出,是0-1之间的浮点数。c为从1到M类,这里M为10. 这里的pc 就是softmax输出。
交叉熵损失函数对输入pc的导数较为复杂,但是 交叉熵加softmax整体对上面线性分类器的输出z的导数解析形式很简单:
这里dz = pc - yc
关于这个式子由来的详细解释很多,例如这篇博客:https://blog.csdn.net/u014313009/article/details/51045303
或者这篇:
https://blog.csdn.net/weixin_42156589/article/details/80518437
根据链式法则,求得对z的导数后求对w的导数就很简单了。
根据上面前向传播逻辑 z = x*W B , 则, dw = x * dz, dB = dz
代码如下:
代码语言:javascript复制def get_one_hot(n_classes, idx):
return np.eye(n_classes)[idx]
def compute_llk(Y_true, Y_pred):
""" should be a vector for both input parameter."""
EPS = 1e-8
Y_hat = np.atleast_2d(Y_true)
Y_pred = np.atleast_2d(Y_pred)
# lines are batch number
# sum for each row to compute the cross-entrophy score.
llk = np.sum(np.log(EPS Y_pred) * Y_hat, axis=-1)
# return negative mean as this batchs' cross-entryphy score.
return -np.mean(llk)
def compute_loss(X, y):
o = forward(X)
one_hot = get_one_hot(10, y)
return compute_llk(one_hot, o)
def train_algo(x,y,lr):
global W,B
y_pred = predict(x)
# see : https://www.cnblogs.com/wuliytTaotao/p/10787510.html
# loss relate to z
#print("y pred is: ", y_pred)
dz = y_pred - get_one_hot(output_size, y)
# x * dz -> gradW
gradW = np.outer(x, dz)
# 1 * dz -> dz
gradB = dz
W = W - lr * gradW
B = B - lr * gradB
上面主要四个函数:
get_one_hot: 很简单,给定类别数,和y(这里就就是小标)返回one-hot向量
compute_llk: 是交叉熵损失函数的实现,其中EPS为为了避免Y_pred为0 导致log求值返回-inf,这里的axis=-1和上面的类似,表示对一行内的所有列进行求和。
最后返回值求mean表示对这批数据的loss求平均值。
compute_loss: 逻辑很清晰,根据输入,进行前向传播,计算y标签的one_hot编码,计算llk 损失。
train_algo:
1> 首先对输入x执行预测函数predict,输出y_pred
2> 根据上面dloss的公式,计算dz
3> 使用np.outer函数,根据链式法则,计算损失对W的导数gradW(这个outer函数就是计算两个向量的外积,)下面给出一个简单的推导,字体很丑请忽略。
4> 计算损失对偏置项的导数gradB
5> 更新权重和偏置