深度使用卷积--使用tensorflow实现卷积

2021-01-29 10:35:24 浏览数 (1)

在上一篇我们了解了卷积的概念,并且使用numpy实现了卷积。但是现在其实很少会需要自己去实现卷积操作了,因为很多深度学习框架已经很好的封装了基础API甚至封装好了网络模型,比如tensorflow,pytorch,caffe等等。今天我们就使用tensorflow来实现卷积,顺便和我们自己实现的卷积结果对比,验证我们的实现是否正确。

tensorflow实现卷积

API介绍

tensorflow是一个数据流图,tf基础使用后面有时间会再从基础进行介绍,今天直接上卷积的使用了,主要用到的API就是tf.nn.conv2d

对参数进行简单介绍

代码语言:javascript复制
tf.nn.conv2d(
    input,
    filter=None,
    strides=None,
    padding=None,
    use_cudnn_on_gpu=True,
    data_format='NHWC',
    name=None
)

input:输入数据shape为batch, in_height, in_width, in_channels

对应

输入batch_size,输入图片宽度,输入图片高度,输入图片通道数(RGB就是3通道)

filter:卷积核shape为 filter_height, filter_weight, in_channel, out_channels

对应

卷积核对高,卷积核对宽,上一层输入的通道数,卷积核的个数

strides:计算卷积时的步长, 1, strides, strides, 1

padding:填充方式,通过上一篇应该有比较深入的了解了,两种填充方式,“SAME” 和 “VALID”

use_cudnn_on_gpu:是否使用cudnn加速

data_format:数据格式,一般使用默认的NHWC,通道在最后

tensorflow代码实现

数据处理

我们还是用和上一篇一样的数据,回顾下在numpy里面我们使用的输入shapebatch,C,H,W通道数是在前面,但是在tensorflow提供的API中默认是使用的NHWC,同理filter我们在使用numpy实现时shape是C_out,C_in,H,W在tf里要求的是H,W,C_in,C_out。

所以我们第一步需要对数据进行下处理,让他和我们要求对shape一致。

代码语言:javascript复制
    #inputs输入为[3,9,9]先把batch那一维度expand,[1,3,9,9]
    inputs = np.expand_dims(inputs, axis=0)
    #使用transpose变形为tf需要的[1,9,9,3]
    inputs = inputs.transpose((0,2,3,1))
    # filter:[2,3,3,3]--[3,3,3,2]
    filter= filter.transpose((2,3,1,0))

tf实现卷积

代码语言:javascript复制
def tf_conv(inputs,filter,padding='SAME'):
	#定义输入和输出placeholder
	# input = [batch, in_height, in_width, in_channels]
    # filter = [filter_height, filter_width, in_channels, out_channels]
    tf_inputs = tf.placeholder(shape=(1, 9, 9, 3), dtype=tf.float32)
    tf_filter = tf.placeholder(shape=(3, 3, 3, 2), dtype=tf.float32)
    conv = tf.nn.conv2d(tf_inputs, tf_filter, strides=[1, 1, 1, 1], padding=padding)
    #通过sess.run计算conv值
    with tf.Session() as sess:
        out = sess.run(conv,feed_dict={tf_inputs: inputs,tf_filter:filter})
        out = np.squeeze(out)
        return out

numpy实现结果

我们运行下在上一篇的代码结果:

代码语言:javascript复制
numpy conv 
 [[[ 110.  186.  249.  312.  375.  438.  501.  564.  338.]
  [ 186.  297.  378.  459.  540.  621.  702.  783.  456.]
  [ 249.  378.  459.  540.  621.  702.  783.  864.  501.]
  [ 312.  459.  540.  621.  702.  783.  864.  945.  546.]
  [ 375.  540.  621.  702.  783.  864.  945. 1026.  591.]
  [ 438.  621.  702.  783.  864.  945. 1026. 1107.  636.]
  [ 501.  702.  783.  864.  945. 1026. 1107. 1188.  681.]
  [ 564.  783.  864.  945. 1026. 1107. 1188. 1269.  726.]
  [ 338.  456.  501.  546.  591.  636.  681.  726.  398.]]

 [[ 134.  231.  312.  393.  474.  555.  636.  717.  446.]
  [ 231.  378.  486.  594.  702.  810.  918. 1026.  627.]
  [ 312.  486.  594.  702.  810.  918. 1026. 1134.  690.]
  [ 393.  594.  702.  810.  918. 1026. 1134. 1242.  753.]
  [ 474.  702.  810.  918. 1026. 1134. 1242. 1350.  816.]
  [ 555.  810.  918. 1026. 1134. 1242. 1350. 1458.  879.]
  [ 636.  918. 1026. 1134. 1242. 1350. 1458. 1566.  942.]
  [ 717. 1026. 1134. 1242. 1350. 1458. 1566. 1674. 1005.]
  [ 446.  627.  690.  753.  816.  879.  942. 1005.  590.]]] 

tensorflow运行结果

运行下今天介绍的调用tensorflow的卷积API运行的结果:

代码语言:javascript复制
tf conv 
 [[[ 110.  186.  249.  312.  375.  438.  501.  564.  338.]
  [ 186.  297.  378.  459.  540.  621.  702.  783.  456.]
  [ 249.  378.  459.  540.  621.  702.  783.  864.  501.]
  [ 312.  459.  540.  621.  702.  783.  864.  945.  546.]
  [ 375.  540.  621.  702.  783.  864.  945. 1026.  591.]
  [ 438.  621.  702.  783.  864.  945. 1026. 1107.  636.]
  [ 501.  702.  783.  864.  945. 1026. 1107. 1188.  681.]
  [ 564.  783.  864.  945. 1026. 1107. 1188. 1269.  726.]
  [ 338.  456.  501.  546.  591.  636.  681.  726.  398.]]

 [[ 134.  231.  312.  393.  474.  555.  636.  717.  446.]
  [ 231.  378.  486.  594.  702.  810.  918. 1026.  627.]
  [ 312.  486.  594.  702.  810.  918. 1026. 1134.  690.]
  [ 393.  594.  702.  810.  918. 1026. 1134. 1242.  753.]
  [ 474.  702.  810.  918. 1026. 1134. 1242. 1350.  816.]
  [ 555.  810.  918. 1026. 1134. 1242. 1350. 1458.  879.]
  [ 636.  918. 1026. 1134. 1242. 1350. 1458. 1566.  942.]
  [ 717. 1026. 1134. 1242. 1350. 1458. 1566. 1674. 1005.]
  [ 446.  627.  690.  753.  816.  879.  942. 1005.  590.]]] 

可以看到两个结果一致~

OK,通过两篇文章的介绍,相信对卷积的实现有很好的理解了。在理解的基础才能更好的去使用各种框架的封装接口。

0 人点赞