使用YOLOv10进行自定义目标检测

2024-07-25 18:39:34 浏览数 (2)

视觉/图像重磅干货,第一时间送达

YOLO(You Only Look Once)是一种流行的物体检测算法,以其速度和准确性而闻名。与涉及生成区域提案然后对其进行分类的多阶段过程的传统方法不同,YOLO 将物体检测框架化为单个回归问题,只需一次评估即可直接从完整图像中预测边界框和类别概率。

YOLOv10是 YOLO 系列的一项进步,在速度、准确性和效率方面均比以前的版本有所改进。主要功能包括:

  • 单次检测:YOLO 在网络的单次前向传递中检测物体,因此速度极快。
  • 统一架构:模型在训练和测试期间看到整个图像,从而实现更好的上下文理解。
  • 改进的锚点和损失函数:YOLOv10 使用更好的锚框生成和改进的损失函数来实现更精确的边界框预测。

使用YOLOv10进行自定义对象检测

自定义对象检测涉及在特定数据集上训练 YOLOv10 模型,该数据集可能包含预训练模型未涵盖的各种对象。此过程涉及几个关键步骤:

  • 设置环境:安装必要的库和依赖项。
  • 准备数据集:以 YOLO 格式构建数据集并应用数据增强来提高模型鲁棒性。
  • 配置YOLOv10模型:准备配置文件并设置模型。
  • 训练 YOLOv10 模型:使用先进的训练技术来优化性能。
  • 评估模型性能:使用各种指标评估模型。
  • 推理和可视化:在新图像上测试模型并可视化结果。
  • 微调和超参数优化:进一步优化模型以获得更好的准确性和性能。

1. 设置环境

首先,确保您已经安装了必要的库:

代码语言:javascript复制
pip install ultralytics
pip install matplotlib
pip install albumentations

2. 准备并扩充数据集

确保您的数据集结构正确并实现数据增强。

代码语言:javascript复制
import os
from glob import glob
import albumentations as A
import cv2

# Example augmentation pipeline using Albumentations
augmentation_pipeline = A.Compose([
A.HorizontalFlip(p=0.5),
A.RandomBrightnessContrast(p=0.2),
A.Rotate(limit=10, p=0.2),
A.MotionBlur(p=0.2),
A.HueSaturationValue(p=0.2),
])

# Function to apply augmentations and save augmented images
def augment_and_save(image_dir, label_dir, output_dir):
image_paths = glob(os.path.join(image_dir, ‘*.jpg’))
for image_path in image_paths:
image = cv2.imread(image_path)
augmented = augmentation_pipeline(image=image)
augmented_image = augmented[‘image’]
image_name = os.path.basename(image_path)
cv2.imwrite(os.path.join(output_dir, image_name), augmented_image)

augment_and_save(‘path_to_train_images’, ‘path_to_train_labels’, ‘path_to_augmented_images’)

3. 配置YOLOv10模型

准备YOLO格式的数据集配置文件:

代码语言:javascript复制
# data.yaml
train: ../train/images
val: ../val/images

nc: 2 # number of classes
names: [‘class1’, ‘class2’]

4. 训练YOLOv10模型

利用混合精度训练和学习率调度器等先进的训练技术。

代码语言:javascript复制
from ultralytics import YOLO

# Load a YOLOv10 model pre-trained on COCO dataset
model = YOLO(‘yolov10.pt’)

# Train the model with advanced settings
results = model.train(
data=’data.yaml’,
epochs=100,
imgsz=640,
batch=16,
lr0=0.01,
momentum=0.9,
weight_decay=0.0005,
optimizer=’SGD’,
patience=10,
img_weights=True,
augment=True,
precision=’mixed’
)

# Save the model weights
model.save(‘custom_yolov10_advanced.pt’)

5. 评估模型性能

评估模型并使用 mAP、精度、召回率和 F1 分数等指标。

代码语言:javascript复制
# Evaluate the model
metrics = model.val()

# Print the metrics
print(f”mAP: {metrics[‘mAP’]}”)
print(f”Precision: {metrics[‘precision’]}”)
print(f”Recall: {metrics[‘recall’]}”)
print(f”F1-Score: {metrics[‘f1’]}”)

6. 推理和可视化

使用模型进行推理并可视化结果。

代码语言:javascript复制
import matplotlib.pyplot as plt
from PIL import Image

# Load the custom-trained model
model = YOLO(‘custom_yolov10_advanced.pt’)

# Perform inference on a new image
img_path = ‘path_to_new_image.jpg’
results = model.predict(img_path)

# Display the image with predictions
img = Image.open(img_path)
plt.imshow(img)
plt.axis(‘off’)
plt.show()

# Optionally, draw bounding boxes on the image
preds = results.pred[0].numpy()
for box in preds:
plt.gca().add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, edgecolor=’red’, linewidth=2))
plt.show()

7. 微调和超参数优化

尝试不同的超参数和微调技术。

代码语言:javascript复制
import numpy as np
from ultralytics import YOLO

# Define a function for hyperparameter optimization
def hyperparameter_optimization(trials=50):
for trial in range(trials):
# Randomly sample hyperparameters
lr0 = 10**np.random.uniform(-4, -2)
momentum = np.random.uniform(0.8, 0.99)
weight_decay = 10**np.random.uniform(-5, -3)

# Train the model with sampled hyperparameters
model = YOLO(‘yolov10.pt’)
results = model.train(
data=’data.yaml’,
epochs=50,
imgsz=640,
batch=16,
lr0=lr0,
momentum=momentum,
weight_decay=weight_decay,
optimizer=’SGD’,
patience=10,
img_weights=True,
augment=True
)

# Evaluate the model
metrics = model.val()
print(f”Trial {trial   1}/{trials} — mAP: {metrics[‘mAP’]}, Precision: {metrics[‘precision’]}, Recall: {metrics[‘recall’]}, F1-Score: {metrics[‘f1’]}”)

# Run hyperparameter optimization
hyperparameter_optimization()

0 人点赞