import argparse
import onnx
def get_onnx_info(onnx_model_path):
try:
# 加载 ONNX 模型
onnx_model = onnx.load(onnx_model_path)
# 获取输入信息
input_info = onnx_model.graph.input[0] # 如果有多个输入,可以更改索引以获取其他输入信息
# 获取输入的数据类型
input_type = input_info.type
# 获取输出信息
output_info = onnx_model.graph.output[0]
output_type = output_info.type
# 定义数据类型映射
data_type_mapping = {
1: 'float32',
7: 'int64',
# 在实际模型中可能有其他类型
}
# 检查输入是否是张量类型
if input_type.HasField("tensor_type"):
input_data_type = data_type_mapping.get(input_type.tensor_type.elem_type, 'Unknown')
input_shape = [dim.dim_value for dim in input_type.tensor_type.shape.dim]
# 将数据类型和维度组合成一个列表
input_info_list = [f"'{input_data_type}'"] + input_shape
# 打印输出格式
print(input_info_list)
else:
return None # 非张量类型的输入
if output_type.HasField("tensor_type"):
output_data_type = data_type_mapping.get(output_type.tensor_type.elem_type, 'Unknown')
output_shape = output_type.tensor_type.shape
# 将数据类型和维度组合成一个列表
output_info_list = [f"'{output_data_type}'"] + [dim.dim_value for dim in output_shape.dim]
# 打印输出格式
print(output_info_list)
else:
return None # 非张量类型的输入
except Exception as e:
print(f"Error: {e}")
return None
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Get ONNX model input information')
parser.add_argument('model_path', type=str, help='Path to the ONNX model file')
# 解析命令行参数
args = parser.parse_args()
# 调用函数并打印 ONNX 模型的输入信息
input_info = get_onnx_info(args.model_path) # 替换为你的 ONNX 模型文件路径