pytorch网络结构查看
  m3bOGMt8pvLp 2023年12月08日 31 0

1、使用torchsummary来打印网络特征提取部分参数

实例:

import torch
from torchsummary import summary
from model import AlexNet
device=torch.device("cuda" if torch.cuda.is_available() else "cpu" )
net=AlexNet().to(device)


summary(net, (3, 224, 224),print=True)

2、使用Netron (onnx形式)

(1)使用模型类 除了定义一个模型外,还需要传入一个数据样例

import torchvision.models as models
import torch


import onnx
import onnx.utils
import onnx.version_converter




# 定义数据+网络
data = torch.randn(2, 3, 256, 256)
net = models.resnet34()


# 导出
torch.onnx.export(
    net,
    data,
    'model.onnx',
    export_params=True,
    opset_version=8,
)


# 增加维度信息
model_file = 'model.onnx'
onnx_model = onnx.load(model_file)
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), model_file)

(2)pytorch模型转换到onnx模型

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

 

model——需要导出的pytorch模型

args——模型的输入参数,满足输入层的shape正确即可。

path——输出的onnx模型的位置。例如‘yolov5.onnx’。

export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。

verbose——是否打印模型转换信息。default=False。

input_names——输入节点名称。default=None。

output_names——输出节点名称。default=None。

do_constant_folding——是否使用常量折叠,默认即可。default=True。

dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。

格式如下 :

1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}

2)仅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}}

3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]}

opset_version——opset的版本,低版本不支持upsample等操作。 

代码:

 

import torch
import torch.onnx
weights_path = "./AlexNet.pth"
net = AlexNet(num_classes=5)
net.load_state_dict(torch.load(weights_path))
net.eval()


# 创建一个输入张量
x = torch.randn(1, 3, 224, 224, requires_grad=True)


# 导出模型到 ONNX 格式
input_names = ['input']
output_names = ['output']
torch.onnx.export(net, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose=True)

 

collections.OrderedDict 类型的对象,而不是 torch.nn.Module 类型的模型。collections.OrderedDict 类型的对象是包含模型权重的字典,而不是模型本身。

如果你只是加载了模型的权重而不是整个模型,你需要重新创建模型的结构,然后加载权重。

 

 

 

 

【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年12月08日 0

暂无评论

m3bOGMt8pvLp
作者其他文章 更多