基于pytorch可视化alexnet卷积核和特征图

2020-11-13 09:57:26 浏览数 (1)

引子:

之前一篇我们使用paddle paddle实现了alexnet, 今天我们来对alexnet进行可视化,具体看下每个卷积层的卷积到底是个什么样的,以加深对深度卷积网络的理解。这次我们使用pytorch实现的alexnet实现作为网络,使用pretrain的权重是pytorch官方提供的。

使用一张图片进行前向传播和可视化的数据来源,来自:https://raw.githubusercontent.com/mrqwertyuiop/Dog-Cat-Image-Classification---Machine-Learning-Model-CNN/master/cat1.jpg

alexnet实现:

与之前一篇关于复现alexnet代码类似,这里的alexnet实现使用pytorch,具体代码如下:

代码语言:javascript复制
class AlexNet(nn.Module):

    def __init__(self, num_classes: int = 1000) -> None:
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

注意到,与之前不同,这里写了一个新类AlexNet 继承自torch的nn.Module:

init函数里面设置模型结构(这里的实现与之前稍有不同,主要是为了能够复用torch官网提供的模型权重,如果自己训练的话,也可以进行网络结构调整)。

forward函数首先将输入导入特征提取层,之后进行一次平均池化 avg pooling,然后将特征图拉平,(torch.flatten),并送入带有dropout的拥有两个全连接层作为分类器。

输入图像处理和前向传播

有了上面的网络结构, 我们下面吧图片输入到网络中,图片数据是上述的一张猫的图片,代码如下:

代码语言:javascript复制
 from PIL import Image
    from torchvision import transforms
    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
    # batch 1 , channel 3, length * width 224 * 224
    #print(input_batch.shape)
    #torch.Size([1, 3, 224, 224])

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes

这里的filename就是下载到本地的猫的图片,这里对图片的预处理包括 1> resize成 256 * 256; 2> 裁剪到224 * 224; 3>使用预设的均值和方差标准化处理。实际上核心就是output = model(input_batch)这行,将数据灌入模型的操作。

卷积可视化

将数据灌入模型后,pytorch框架会进行对应的前向传播,要对卷积核可视化,我们需要把卷积核从框架中提取出来。多谢torch提供的接口,我们可以直接把对应层的权重取出,主要代码如下:

代码语言:javascript复制
conv1 = dict(model.features.named_children())['0']
 localw = conv1.weight.cpu().clone()   
  print("total of number of filter : ", len(localw))
        for i in range(1,num):
            localw0 = localw[i]
            #print(localw0.shape)    
            # mean of 3 channel.
            #localw0 = torch.mean(localw0,dim=0)
            # there should be 3(3 channels) 11 * 11 filter.
            plt.figure(figsize=(20, 17))
            if (len(localw0)) > 1:
                for idx, filer in enumerate(localw0):
                    plt.subplot(9, 9, idx 1) 
                    plt.axis('off')
                    plt.imshow(filer[ :, :].detach(),cmap='gray')
            else:
                    plt.subplot(9, 9, idx 1) 
                    plt.axis('off')
                    plt.imshow(localw0[0, :, :].detach(),cmap='gray')

这里的model就是之前的模型,features就是模型中sequence 子模块,其中named_children()会取出sequence包含的子结构,标号为0 的就是第一个卷积层,标号为10就是最优一层卷积层。之后我们取出该层的卷积核,将卷积核画出来如下图:

将最后一层卷积层的卷积核(最后一层卷积核只画了部分):

分析:

可以看出第一层卷积核 人类还是可以比较容易理解,有些提取的是边缘,有些提取的是圆形,有些提取的是斑点等。

最后一层卷积层的卷积核就已经看不出来是提取的什么东西了,即卷积核提取的是更加抽象的特征。

特征图可视化:

除了可以可视化卷积核,来观察网络,还可以将网络中的特征图可视化出来。在zfnethttps://arxiv.org/abs/1311.2901一篇论文中,使用转置卷积将特征图映射回原始图像空间。来观察每层的特征图。我们这里偷个懒直接将特征图从网络中拿出来,可视化。 可视化实现是通过使用pytorch提供的hook机制,在卷积层中注册一个回调函数,把卷积层的输入输出存下载实现的,具体实现如下:

代码语言:javascript复制
class Hook(object):
    def __init__(self):
        self.module_name = []
        self.features_in_hook = []
        self.features_out_hook = []


    def __call__(self,module, fea_in, fea_out):
        print("hooker working", self)
        self.module_name.append(module.__class__)
        self.features_in_hook.append(fea_in)
        self.features_out_hook.append(fea_out)
        return None
    

def plot_feature(model, idx):
    
    hh = Hook()
    model.features[idx].register_forward_hook(hh)
    
    forward_model(model,False)
    print(hh.module_name)
    print((hh.features_in_hook[0][0].shape))
    print((hh.features_out_hook[0].shape))
    
    out1 = hh.features_out_hook[0]

    total_ft  = out1.shape[1]
    first_item = out1[0].cpu().clone()    

    plt.figure(figsize=(20, 17))
    

    for ftidx in range(total_ft):
        if ftidx > 99:
            break
        ft = first_item[ftidx]
        plt.subplot(10, 10, ftidx 1) 
        
        plt.axis('off')
        #plt.imshow(ft[ :, :].detach(),cmap='gray')
        plt.imshow(ft[ :, :].detach())

这里我们首先实现了一个hook类,之后再plot_feature函数中,将改hook类的对象注册到要进行可视化的网络中的某层中:

model.features[idx].register_forward_hook(hh)

model在进行前向传播的时候会调用hook的__call__函数,我们也就是在那里存储了当前层的输入和输出。这里的features_out_hook 是一个list,每次前向传播一次,都是调用一次,也就是features_out_hook 长度会增加1.

第一层卷积层的特征图如下:

最后一层卷积层特征图(部分):

分析:

可以看出第一层的卷积层输出,特征图里面还可以看出猫的形状,最后一层卷积网络的输出特征图,看着有点像热力图,并且完全没有猫的样子,是更加抽象的图片表达,这点与上面卷积核可视化结果类似。

0 人点赞