yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测
  XZVAVmlOqzo6 2023年11月02日 74 0

yolov5格式的香烟数据集

https://download.csdn.net/download/qq_42864343/88110620?spm=1001.2014.3001.5503

创建yolo-nas的运行环境

进入Pycharm的terminal,输入如下命令

conda create -n yolonas python=3.8

pip install super-gradients

使用自定义数据训练Yolo-nas

准备数据

在YOLO-NAS根目录下创建mydata文件夹(名字可以自定义),目录结构如下:

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据集

将自己数据集里用labelImg标注好的xml文件放到xml目录

图片放到images目录

划分数据集

把划分数据集代码 split_train_val.py放到yolo-nas目录下:

# coding:utf-8

import os
import random
import argparse

# 通过argparse模块创建一个参数解析器。该参数解析器可以接收用户输入的命令行参数,用于指定xml文件的路径和输出txt文件的路径。
parser = argparse.ArgumentParser()
# 指定xml文件的路径
parser.add_argument('--xml_path', default='mydata/xml', type=str, help='input xml label path')
# 设置输出txt文件的路径
parser.add_argument('--txt_path', default='mydata/dataSet', type=str, help='output txt label path')
opt = parser.parse_args()
# 训练集与验证集 占全体数据的比例
trainval_percent = 1.0
# 训练集 占训练集与验证集总体 的比例
train_percent = 0.9
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
# 获取到xml文件的数量
total_xml = os.listdir(xmlfilepath)
# 判断txtsavepath是否存在,若不存在,则创建该路径。
if not os.path.exists(txtsavepath):
    os.makedirs(txtsavepath)

# 统计xml文件的个数,即Image标签的个数
num = len(total_xml)
list_index = range(num)
# tv (训练集和测试集的个数) = 数据总数 * 训练集和数据集占全体数据的比例
tv = int(num * trainval_percent)
# 训练集的个数
tr = int(tv * train_percent)
#  按数量随机得到取训练集和测试集的索引
trainval = random.sample(list_index, tv)
#  打乱训练集 
train = random.sample(trainval, tr)
#  创建存放所有图片数据路径的文件
file_trainval = open(txtsavepath + '/trainval.txt', 'w')
#  创建存放所有测试图片数据的路径的文件
file_test = open(txtsavepath + '/test.txt', 'w')
# 创建存放所有训练图片数据的路径的文件
file_train = open(txtsavepath + '/train.txt', 'w')
# 创建存放所有测试图片数据的路径的文件
file_val = open(txtsavepath + '/val.txt', 'w')

# 遍历list_index列表,将文件名按照划分规则写入相应的txt文件中
for i in list_index:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        file_trainval.write(name)
        if i in train:
            file_train.write(name)
        else:
            file_val.write(name)
    else:
        file_test.write(name)

file_trainval.close()
file_train.close()
file_val.close()
file_test.close()

运行代码:dataSet中出现四个文件,里面是图片的名字

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据_02

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_xml_03

根据xml标注文件制作适合yolo的标签

即将每个xml标注提取bbox信息为txt格式,每个图像对应一个txt文件,文件每一行为一个目标的信息,包括class, x_center, y_center, width, height。创建make_labes.py,复制如下代码运行:

# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
import os
from os import getcwd
 
sets = ['train', 'val', 'test']
classes = ['smoke']   # 改成自己的类别
abs_path = os.getcwd()
print(abs_path)
 
def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h
 
def convert_annotation(image_id):
    in_file = open('mydata/xml/%s.xml' % (image_id), encoding='UTF-8')
    out_file = open('mydata/label/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text))
        b1, b2, b3, b4 = b
        # 标注越界修正
        if b2 > w:
            b2 = w
        if b4 > h:
            b4 = h
        b = (b1, b2, b3, b4)
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
 
wd = getcwd()
for image_set in sets:
    if not os.path.exists('mydata/label/'):
        os.makedirs('mydata/label/')
    image_ids = open('mydata/dataSet/%s.txt' % (image_set)).read().strip().split()
    list_file = open('mydata/%s.txt' % (image_set), 'w')
    for image_id in image_ids:
        list_file.write(abs_path + '/mydata/images/%s.jpg\n' % (image_id))
        convert_annotation(image_id)
    list_file.close()

运行完成:

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据集_04

label目录下出现了图片对应的标记位置(好像是标记框的左上角和由上角的坐标)与类别

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_xml_05

mydata目录下,出现了训练集train.txt,测试集test.txt,里面是对应的图片路径

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_xml_06

将划分好的数据集转成适合yolo-nas要求的数据集

创建data目录

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据_07

error目录: 存放格式有问题的图片,格式有问题的图片会中断训练 images/train目录:存放训练集图片 images/val目录:存放测试集图片 labels/train目录:存放训练集图片的标签 labels/val目录:存放测试集图片的标签

根据 trian.txt和val.txt中的图片路径移动图片到train和val文件中

项目目录下创建 mv_img.py,运行代码移动图片

import shutil

# 读取txt文件根据里面的图片路径移动图片到新目录
def move_img(img_txt_path,new_img_path):
    with open(img_txt_path, 'r') as file:
        image_paths = file.readlines()

    # image_new_paths = []
    # for path in image_paths:
    #     path = path.strip()
    #     path = 'mydata/label/'+path+'.txt'
    #     image_new_paths.append(path)

    # 去除路径中的换行符
    image_paths = [path.strip() for path in image_paths]

    # 复制图片到新的文件夹
    # for path in image_new_paths:
    for path in image_paths:
        shutil.copy(path, new_img_path)
img_txt_paths =['mydata/val.txt','mydata/train.txt']
new_img_paths =['data/images/val','data/images/train']

for img_txt_path,new_img_path in zip(img_txt_paths,new_img_paths):
    print(f'img_txt_path,new_img_path = {img_txt_path,new_img_path}')
    move_img(img_txt_path,new_img_path)
# move_img('mydata/train.txt','data/images/train')

项目目录下创建 mv_labels.py,运行代码移动标签

import shutil
# 根据图片名,读取标签的txt文件,并移动到data中label文件夹中
def move_label(img_txt_path,new_label_path):
    with open(img_txt_path, 'r') as file:
        image_paths = file.readlines()

    image_new_paths = []
    for path in image_paths:
        path = path.strip()
        path = 'mydata/label/'+path+'.txt'
        image_new_paths.append(path)
    print(f'image_new_paths = {image_new_paths}')

    # 去除路径中的换行符
    # image_paths = [path.strip() for path in image_paths]

    # 复制图片到新的文件夹
    # for path in image_new_paths:
    for path in image_new_paths:
        print(f'path={path}')
        print(f'new_label_path={new_label_path}')
        shutil.copy(path, new_label_path)

'''移动标签'''
img_name_paths =['mydata/dataSet/val.txt','mydata/dataSet/train.txt']
new_label_paths =['data/labels/val','data/labels/train']

for img_name_path,new_label_path in zip(img_name_paths,new_label_paths):
    print(f'img_name_path,new_img_path = {img_name_path,new_label_path}')
    move_label(img_name_path,new_label_path)

去除训练集测试集中格式错误的代码

在项目的根目录下创建 检查错误jpg.py,复制如下代码运行,这样会去除掉data/images/val/路径下格式错误的图片(错误图片会使训练中的程序中断运行)

import os

train_dir = 'data/images/val/'

def progress(percent, width=50):
    '''进度打印功能'''
    if percent >= 100:
        percent = 100

    show_str = ('[%%-%ds]' % width) % (int(width * percent / 100) * "#")  # 字符串拼接的嵌套使用
    print('\r%s %d%% ' % (show_str, percent), end='')


def is_valid_jpg(jpg_file):
    with open(jpg_file, 'rb') as f:
        f.seek(-2, 2)
        buf = f.read()
        f.close()
        return buf == b'\xff\xd9'  # 判定jpg是否包含结束字段


data_size = len([lists for lists in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, lists))])
recv_size = 0
incompleteFile = 0
print('file tall : %d' % data_size)

for file in os.listdir(train_dir):
    if os.path.splitext(file)[1].lower() == '.jpg':
        ret = is_valid_jpg(train_dir + file)
        if ret == False:
            incompleteFile = incompleteFile + 1
            os.remove(train_dir + file)

    recv_per = int(100 * recv_size / data_size)
    progress(recv_per, width=30)
    recv_size = recv_size + 1

progress(100, width=30)
print('\nincomplete file : %d' % incompleteFile)

然后将train_dir = 'data/images/val/'改为train_dir = 'data/images/train/',再运行一下代码,去除掉train文件下格式错误的图片

训练代码

import os

import requests
import torch
from PIL import Image

from super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (
    PPYoloEPostPredictionCallback
)
class config:
    # trainer params
    CHECKPOINT_DIR = 'checkpoints'  # specify the path you want to save checkpoints to
    EXPERIMENT_NAME = 'cars-from-above'  # specify the experiment name

    # dataset params
    DATA_DIR = 'data'  # parent directory to where data lives

    TRAIN_IMAGES_DIR = 'images/train'  # child dir of DATA_DIR where train images are
    TRAIN_LABELS_DIR = 'labels/train'  # child dir of DATA_DIR where train labels are

    VAL_IMAGES_DIR = 'images/val'  # child dir of DATA_DIR where validation images are
    VAL_LABELS_DIR = 'labels/val'  # child dir of DATA_DIR where validation labels are

    # TEST_IMAGES_DIR = 'images/test'  # child dir of DATA_DIR where validation images are
    # TEST_LABELS_DIR = 'labels/test'  # child dir of DATA_DIR where validation labels are
    CLASSES = ['smoke']  # 指定类名

    NUM_CLASSES = len(CLASSES) # 获取类个数

    # dataloader params - you can add whatever PyTorch dataloader params you have
    # could be different across train, val, and test
    DATALOADER_PARAMS = {
        'batch_size': 16,
        'num_workers': 2
    }

    # model params
    MODEL_NAME = 'yolo_nas_l'  # 可以选择 yolo_nas_s, yolo_nas_m, yolo_nas_l。分别是 小型,中型,大型
    PRETRAINED_WEIGHTS = 'coco'  # only one option here: coco
trainer = Trainer(experiment_name=config.EXPERIMENT_NAME, ckpt_root_dir=config.CHECKPOINT_DIR)

# 指定训练数据
train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': config.DATA_DIR,
        'images_dir': config.TRAIN_IMAGES_DIR,
        'labels_dir': config.TRAIN_LABELS_DIR,
        'classes': config.CLASSES
    },
    dataloader_params=config.DATALOADER_PARAMS
)

# 指定评估数据
val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': config.DATA_DIR,
        'images_dir': config.VAL_IMAGES_DIR,
        'labels_dir': config.VAL_LABELS_DIR,
        'classes': config.CLASSES
    },
    dataloader_params=config.DATALOADER_PARAMS
)

# test_data = coco_detection_yolo_format_val(
#     dataset_params={
#         'data_dir': config.DATA_DIR,
#         'images_dir': config.TEST_IMAGES_DIR,
#         'labels_dir': config.TEST_LABELS_DIR,
#         'classes': config.CLASSES
#     },
#
dataloader_params=config.DATALOADER_PARAMS
# )
# train_data.dataset.plot()

model = models.get(config.MODEL_NAME,
                   num_classes=config.NUM_CLASSES,
                   pretrained_weights=config.PRETRAINED_WEIGHTS
                   )
train_params = {
    # ENABLING SILENT MODE
    "average_best_models":True,
    "warmup_mode": "linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,
    "optimizer": "Adam",
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    # ONLY TRAINING FOR 10 EPOCHS FOR THIS EXAMPLE NOTEBOOK
    "max_epochs": 200,
    "mixed_precision": True,
    "loss": PPYoloELoss(
        use_static_assigner=False,
        # NOTE: num_classes needs to be defined here
        num_classes=config.NUM_CLASSES,
        reg_max=16
    ),
    "valid_metrics_list": [
        DetectionMetrics_050(
            score_thres=0.1,
            top_k_predictions=300,
            # NOTE: num_classes needs to be defined here
            num_cls=config.NUM_CLASSES,
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(
                score_threshold=0.01,
                nms_top_k=1000,
                max_predictions=300,
                nms_threshold=0.7
            )
        )
    ],
    "metric_to_watch": 'mAP@0.50'
}

trainer.train(model=model,
              training_params=train_params,
              train_loader=train_data,
              valid_loader=val_data)

best_model = models.get(config.MODEL_NAME,
                        num_classes=config.NUM_CLASSES,
                        checkpoint_path=os.path.join(config.CHECKPOINT_DIR, config.EXPERIMENT_NAME, 'average_model.pth'))

训练时仍然代码检测不到的错误格式图片

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据集_08

根据第一行红字我们发现它的id是338。

根据id找到图片路径

import os


def print_file_name(directory, file_index):
    files = os.listdir(directory)

    # 确保文件索引在有效范围内
    if file_index >= 0 and file_index < len(files):
        file_name = files[file_index]
        print("第{}个文件的文件名为:{}".format(file_index + 1, file_name))
    else:
        print("无法找到第{}个文件".format(file_index + 1))


# 指定目录路径
directory_path = "data/images/train"

# 要打印的文件索引
file_index = 338

# 调用函数打印指定文件的文件名
print_file_name(directory_path, file_index)

发现是:data/images/train/smoke_b002192.jpg

根据文件路径将图片移动到error文件夹

import os
import shutil

def move_file_to_new_directory(file_name):
    # 源目录的路径
    train_dir = "data/images/train"
    # 目标目录的路径
    new_dir = "data/error"

    # 构造文件的完整路径
    file_path = os.path.join(train_dir, file_name)

    if os.path.exists(file_path):
        # 移动文件到new目录
        shutil.move(file_path, new_dir)
        print(f"成功将文件{file_name}移动到new目录")
    else:
        print(f"找不到文件{file_name}")

# 测试 ,输入图片名,将图片移动
move_file_to_new_directory("smoke_b002192.jpg")

再次运行训练代码

成功开始训练。

yolo-nas对自定义数据集进行训练,测试详解 & 香烟数据集 & 处理损坏的图片数据 & 对网络摄像头,视频,图片预测_数据集_09

连接网络摄像头用训练好的模型参数进行预测

import torch
from super_gradients.training import models
import cv2
import time
def get_video_capture(video, width=None, height=None, fps=None):
    """
     获得视频读取对象
     --   7W   Pix--> width=320,height=240
     --   30W  Pix--> width=640,height=480
     720P,100W Pix--> width=1280,height=720
     960P,130W Pix--> width=1280,height=1024
    1080P,200W Pix--> width=1920,height=1080
    :param video: video file or Camera ID
    :param width:   图像分辨率width
    :param height:  图像分辨率height
    :param fps:  设置视频播放帧率
    :return:
    """
    video_cap = cv2.VideoCapture(video)
    # 如果指定了宽度,高度,fps,则按照制定的值来设置,此处并没有指定
    if width:
        video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    if height:
        video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    if fps:
        video_cap.set(cv2.CAP_PROP_FPS, fps)
    return video_cap

# 此处连接网络摄像头进行测试
video_file = 'rtsp://账号:密码@ip/Streaming/Channels/1'
# video_file = 'data/output.mp4'
num_classes = 1
# best_pth = '/home/computer_vision/code/my_code/checkpoints/cars-from-above/ckpt_best.pth'
best_pth = 'checkpoints/cars-from-above/smoke_small_ckpt_best.pth'
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
best_model = models.get("yolo_nas_s", num_classes=num_classes, checkpoint_path=best_pth).to(device)

'''开始计时'''
start_time = time.time()
video_cap = get_video_capture(video_file)
while True:
    isSuccess, frame = video_cap.read()
    if not isSuccess:
        break
    result_image = best_model.predict(frame, conf=0.45, fuse_model=False)
    result_image = result_image._images_prediction_lst[0]
    result_image = result_image.draw()
    '''改动'''
    result_image = cv2.resize(result_image, (960, 540))
    '''end'''
    cv2.namedWindow('result', flags=cv2.WINDOW_NORMAL)
    cv2.imshow('result', result_image)
    kk = cv2.waitKey(1)
    if kk == ord('q'):
        break
video_cap.release()
'''时间结束'''
end_time = time.time()
run_time = end_time - start_time
print(run_time)

补充

对视频进行预测

import torch
from super_gradients.training import models

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = models.get("yolo_nas_l", pretrained_weights="coco").to(device)
model.predict("data/output.mp4",conf=0.4).save("output/output_lianzhang.mp4")

对图片进行预测

import torch
from super_gradients.training import models

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = models.get("yolo_nas_s", pretrained_weights="coco").to(device)
out = model.predict("camera01.png", conf=0.6)
out.show()
out.save("output")

预测data目录下的视频并保存预测结果

model.predict("data/output.mp4").save("output/output_lianzhang.mp4")
【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月08日 0

暂无评论

推荐阅读
XZVAVmlOqzo6