AI学习-查看ONNX模型的输入格式
  llt0tXqeaug8 2023年11月27日 47 0
import argparse
import onnx


def get_onnx_info(onnx_model_path):
    try:
        # 加载 ONNX 模型
        onnx_model = onnx.load(onnx_model_path)


        # 获取输入信息
        input_info = onnx_model.graph.input[0]  # 如果有多个输入,可以更改索引以获取其他输入信息


        # 获取输入的数据类型
        input_type = input_info.type
        
        # 获取输出信息
        output_info = onnx_model.graph.output[0]


        output_type = output_info.type


        # 定义数据类型映射
        data_type_mapping = {
            1: 'float32',
            7: 'int64',
            # 在实际模型中可能有其他类型
        }


        # 检查输入是否是张量类型
        if input_type.HasField("tensor_type"):
            input_data_type = data_type_mapping.get(input_type.tensor_type.elem_type, 'Unknown')
            input_shape = [dim.dim_value for dim in input_type.tensor_type.shape.dim]


            # 将数据类型和维度组合成一个列表
            input_info_list = [f"'{input_data_type}'"] + input_shape


            # 打印输出格式
            print(input_info_list)
        else:
            return None  # 非张量类型的输入
            
        if output_type.HasField("tensor_type"):
            output_data_type = data_type_mapping.get(output_type.tensor_type.elem_type, 'Unknown')
            output_shape = output_type.tensor_type.shape


            # 将数据类型和维度组合成一个列表
            output_info_list = [f"'{output_data_type}'"] + [dim.dim_value for dim in output_shape.dim]




            # 打印输出格式
            print(output_info_list)
        else:
            return None  # 非张量类型的输入
    except Exception as e:
        print(f"Error: {e}")
        return None




if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Get ONNX model input information')
    parser.add_argument('model_path', type=str, help='Path to the ONNX model file')
    # 解析命令行参数
    args = parser.parse_args()
            
    # 调用函数并打印 ONNX 模型的输入信息
    input_info = get_onnx_info(args.model_path)  # 替换为你的 ONNX 模型文件路径
【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月27日 0

暂无评论

推荐阅读
  X5zJxoD00Cah   2023年11月02日   31   0   0 调用函数Pythonhtml
  gBkHYLY8jvYd   2023年11月19日   82   0   0 输出格式变换规则数据
  gBkHYLY8jvYd   2023年11月19日   21   0   0 输出格式进制字符串
llt0tXqeaug8