注意力机制——SENet原理详解及源码解析
  OEpWoCKNwrr1 2023年11月05日 35 0


🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

SENet原理详解

  先来简单说说我们为什么需要使用注意力机制,这是因为我们希望网络可以专注于一些更加重要的东西,这对物体的识别定位都大有益处。enmmm,是不是够简单呢。🍦🍦🍦如果你是第一次学习注意力机制,我觉得你会充满疑惑,怎么让网络注意到一些更加重要的东西呢?那么带着疑问,和我一起来看看SENet的原理,等我介绍完后看你能否理解喔。🌿🌿🌿

  话不多说,我们直接来看SENet的关键结构,如下图所示:

注意力机制——SENet原理详解及源码解析_激活函数

  我们来介绍一下上图的网络,首先是输入X,其维度为注意力机制——SENet原理详解及源码解析_SENet_02,经过一系列卷积等维度变化操作后得到特征图U,其维度注意力机制——SENet原理详解及源码解析_激活函数_03【注:其实从特征图U开始向后才是真正的SENet的结构,这一步转换只是一些特征图维度变化】 当我们得到U后,会先将U经过全局平均池化的操作,即将U的维度由注意力机制——SENet原理详解及源码解析_激活函数_03变成注意力机制——SENet原理详解及源码解析_激活函数_05,此步骤对应着上图中的注意力机制——SENet原理详解及源码解析_激活函数_06。接着会执行步骤注意力机制——SENet原理详解及源码解析_SENet_07,此步骤包含两个全连接层已经两个激活函数,为方便大家理解,做此过程的图如下:

注意力机制——SENet原理详解及源码解析_注意力机制_08

  从上图我们可以看出,在第一次全连接层后我们使用Relu激活函数,此时得到的输出维度为注意力机制——SENet原理详解及源码解析_池化_09,通常情况下注意力机制——SENet原理详解及源码解析_全连接_10设置为注意力机制——SENet原理详解及源码解析_全连接_11注意力机制——SENet原理详解及源码解析_池化_12。第二个全连接层后使用Sigmoid函数,将每层数值归一化到0-1之间,以此表示每个通道的权重,第二个全连接的输出也为注意力机制——SENet原理详解及源码解析_激活函数_05。得到了最后注意力机制——SENet原理详解及源码解析_激活函数_05的输出后,我们将U和刚刚得到的注意力机制——SENet原理详解及源码解析_激活函数_05输出相乘,得到最终的特征图注意力机制——SENet原理详解及源码解析_全连接_16,最终特征图注意力机制——SENet原理详解及源码解析_全连接_16的维度和U一致,为注意力机制——SENet原理详解及源码解析_激活函数_03

  介绍到这里,大家是否明白了呢。如果你还没明白的话,再来看下图吧!!!首先下图左上角表示为两个通道的特征图,经平均池化后得到左下角的图;再次经过两次全连接层和激活函数后,转化成了右下角的图,最后用右下角的0.5、0.6分别乘原始的特质图,则得到最终的右上角的图。可以发现经过SENet特征图输入前后尺寸没有变化,其值发生变化。

注意力机制——SENet原理详解及源码解析_注意力机制_19

SENet代码详解

理解了上文所述的SENet原理,那么编写SENet的代码就非常简单了,如下:

def SENet(input):
    #全局平均池化
    x = nn.AdaptiveAvgPool2d((1,1))(input)
    x = x.view(1, -1)
    #第一个全连接层
    x = nn.Linear(2, 1)(x)
    x = nn.functional.relu(x)
    #第二个全连接层
    x = nn.Linear(1, 2)(x)
    x = nn.functional.sigmoid(x)

    return x


if __name__ == '__main__':
    input = torch.ones(1, 2 ,2 ,2)
    output = SENet(input)
    # 将SENet的输出维度进行变化,以便后面的乘机操作
    output = output.view(input.shape[0], input.shape[1],1, 1)
    SE_output = input*output
    
    print(input)
    print(input.shape)
    print(output)
    print(output.shape)
    print(SE_output)

我们可以来看一下上述代码的输出,如下:

input:

注意力机制——SENet原理详解及源码解析_池化_20

output:

注意力机制——SENet原理详解及源码解析_全连接_21

SE_output:

注意力机制——SENet原理详解及源码解析_激活函数_22

【注意:大家需要注意在最后一步相乘操作前需要先View一下输出output的尺寸,不然乘的结果不一样哦,这涉及到一些pytorch乘法的操作,这部分我也调试了很久,大家可以动手试试看。】

 
 

如若文章对你有所帮助,那就🛴🛴🛴

        

注意力机制——SENet原理详解及源码解析_激活函数_23


【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月08日 0

暂无评论

OEpWoCKNwrr1