【PyTorch】实现SqueezeNet的Fire模块
  TEZNKK3IfmPf 2023年11月15日 12 0

问题

SqueezeNet是一款非常经典的CV网络,其设计理念对后续的很多网络都有非常强的指导意义,其核心思想包括:

  • 使用1x1卷积核替代3x3,主要原因是3x3的卷积核参数量是1x1的9倍多;
  • 降低3x3卷积核的通道数量;
  • 网络结构中延迟下采样的时机以获得较大尺寸的激活特征图;

方法

下面介绍PyTorch实现的SqueezeNet网络最核心的Fire模块,如下:

import torch
from torch import nn, Tensor
from typing import Any

class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: # 增加in_xxx和out_xxx的好处是,调用的时候可以省略参数名
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, **kwargs) # **容易漏掉
        self.relu = nn.ReLU()
    
    def forward(self, x: Tensor) -> Tensor:
        x = self.conv2d(x)
        out = self.relu(x)
    
        return out

class Fire(nn.Module):
  
    def __init__(self, in_channels: int, s_1x1: int, e_1x1: int, e_3x3: int) -> None:
        super().__init__()
    
        self.squeeze = BasicConv2d(in_channels, s_1x1, kernel_size=1)
    
        self.expand_1x1 = BasicConv2d(s_1x1, e_1x1, kernel_size = 1)
        self.expand_3x3 = BasicConv2d(s_1x1, e_3x3, kernel_size = 3, padding = 1) # p=1是为了保持3x3特征图不变
  
    def forward(self, x: Tensor) -> Tensor:
        x = self.squeeze(x)
    
        return torch.cat([
            self.expand_1x1(x), 
            self.expand_3x3(x)
        ], dim=1)


if __name__ == '__main__':
  
    x = torch.rand(size=(1, 3, 224, 224))
  
    conv2d = BasicConv2d(3,  64, kernel_size = 3, padding = 1, stride = 1)
    print(conv2d(x).shape) # torch.Size([1, 64, 224, 224])   
  
    fire = Fire(3, 32, 32, 48)
    print(fire(x).shape) # torch.Size([1, 80, 224, 224])

【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月15日 0

暂无评论

推荐阅读
  TEZNKK3IfmPf   2024年03月29日   35   0   0 pytorch
  TEZNKK3IfmPf   2023年11月14日   21   0   0 pytorch
  I7JaHrFMuDsU   24天前   21   0   0 pytorch
  TEZNKK3IfmPf   2023年11月14日   36   0   0 listpytorch
  TEZNKK3IfmPf   2023年11月14日   17   0   0 pytorch
  TEZNKK3IfmPf   2023年11月15日   15   0   0 pytorch
TEZNKK3IfmPf