基于MobileNet的UNet语义分割模型
  vzuygR7xgfdK 2023年11月02日 248 0

Unet模型

U-net网络非常简单,前半部分作用是特征提取,后半部分是上采样。在一些文献中也把这样的结构叫做编码器-解码器结构。由于此网络整体结构类似于大写的英文字母U,故得名U-net。 U-net与其他常见的分割网络有一点非常不同的地方:U-net采用了完全不同的特征融合方式:拼接,U-net采用将特征在channel维度拼接在一起,形成更厚的特征。而FCN融合时使用的对应点相加,并不形成更厚的特征。

基于MobileNet的UNet语义分割模型_mobilenet

mobilenet模型

MobileNet是一种轻量级的卷积神经网络,它的主要目标是在保持模型准确性的同时,尽可能地减少模型的大小和计算复杂度。MobileNet的设计思想是使用深度可分离卷积层来代替传统的卷积层,以减少计算量和模型大小。 MobileNet的深度可分离卷积层是由深度卷积层和逐点卷积层组成的。深度卷积层只考虑每个通道内的空间关系,而逐点卷积层则只考虑每个位置的通道关系。这种分离的方式使得MobileNet可以用更少的参数和计算量来学习空间和通道的特征,从而减小了模型的大小和计算复杂度。

基于MobileNet的UNet语义分割模型_Unet_02

为什么选用mobilenet

MobileNet就是性价比极高的一个轻量级网络。而UNet的backbone,即特征提取网络为一个参数量极大的VGG16模型,可想而知很多嵌入式设备是带不动的,更不能得到实时的分割效果。因此,本人想通过使用MobileNet替换VGG16的方式来轻量化我们的UNet模型,使得参数量减少,来达到加速推理的效果。本文中,本人基于Tensorflow深度学习框架成功修改了网络的backbone,并进行模型融合,提高了模型特征提取的准确性。

实现代码:


def get_mobilenet_encoder(input_height=224, input_width=224,
pretrained='imagenet', channels=3):

# todo add more alpha and stuff

assert input_height % 32 == 0
assert input_width % 32 == 0

alpha = 1.0
depth_multiplier = 1
dropout = 1e-3

img_input = Input(shape=(input_height, input_width, channels))

x = _conv_block(img_input, 32, alpha, strides=(2, 2))
x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)
f1 = x

x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                          strides=(2, 2), block_id=2)
x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)
f2 = x

x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                          strides=(2, 2), block_id=4)
x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
f3 = x

x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                          strides=(2, 2), block_id=6)
x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
f4 = x

x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                          strides=(2, 2), block_id=12)
x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
f5 = x

if pretrained == 'imagenet':
    model_name = 'mobilenet_%s_%d_tf_no_top.h5' % ('1_0', 224)
    BASE_WEIGHT_PATH = ('https://github.com/fchollet/deep-learning-models/'
                        'releases/download/v0.6/')
    weight_path = BASE_WEIGHT_PATH + model_name
    weights_path = tf.keras.utils.get_file(model_name, weight_path)

    Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)

return img_input, [f1, f2, f3, f4, f5]


def _unet(classes, encoder, l1_skip_conn=True, input_height=416,
input_width=608, channels=3):

img_input, levels = encoder(
    input_height=input_height, input_width=input_width, channels=channels)
[f1, f2, f3, f4, f5] = levels

o = f4

o = (ZeroPadding2D((1, 1), data_format=imgchannel))(o)
o = (Conv2D(512, (3, 3), padding='valid' , activation='relu' , data_format=imgchannel))(o)
o = (BatchNormalization())(o)

o = (UpSampling2D((2, 2), data_format=imgchannel))(o)
o = (concatenate([o, f3], axis=-1))
o = (ZeroPadding2D((1, 1), data_format=imgchannel))(o)
o = (Conv2D(256, (3, 3), padding='valid', activation='relu' , data_format=imgchannel))(o)
o = (BatchNormalization())(o)

o = (UpSampling2D((2, 2), data_format=imgchannel))(o)
o = (concatenate([o, f2], axis=-1))
o = (ZeroPadding2D((1, 1), data_format=imgchannel))(o)
o = (Conv2D(128, (3, 3), padding='valid' , activation='relu' , data_format=imgchannel))(o)
o = (BatchNormalization())(o)

o = (UpSampling2D((2, 2), data_format=imgchannel))(o)

if l1_skip_conn:
    o = (concatenate([o, f1], axis=-1))

o = (ZeroPadding2D((1, 1), data_format=imgchannel))(o)
o = (Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=imgchannel, name="seg_feats"))(o)
o = (BatchNormalization())(o)

o = Conv2D(classes, (3, 3), padding='same',
           data_format=imgchannel)(o)

model = get_segmentation_model(img_input, o)

return model
【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月08日 0

暂无评论