# yolov8/ultralytics/utils/benchmarks.py
# Import the glob module, used for wildcard matching of file paths
import glob
# 导入 os 模块,提供了许多与操作系统交互的函数
import os
# 导入 platform 模块,用于获取系统平台信息
import platform
# 导入 re 模块,支持正则表达式操作
import re
# 导入 shutil 模块,提供了高级的文件操作功能
import shutil
# 导入 time 模块,提供时间相关的功能
import time
# 从 pathlib 模块中导入 Path 类,用于操作文件路径
from pathlib import Path
# 导入 numpy 库,用于数值计算
import numpy as np
# 导入 torch.cuda 模块,用于 CUDA 相关操作
import torch.cuda
# 导入 yaml 库,用于处理 YAML 格式的文件
import yaml
# 从 ultralytics 包中导入 YOLO 和 YOLOWorld 类
from ultralytics import YOLO, YOLOWorld
# 从 ultralytics.cfg 模块中导入 TASK2DATA 和 TASK2METRIC 变量
from ultralytics.cfg import TASK2DATA, TASK2METRIC
# 从 ultralytics.engine.exporter 模块中导入 export_formats 函数
from ultralytics.engine.exporter import export_formats
# 从 ultralytics.utils 模块中导入 ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI 等变量
from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
# 从 ultralytics.utils.checks 模块中导入 IS_PYTHON_3_12, check_requirements, check_yolo 等函数和变量
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
# 从 ultralytics.utils.downloads 模块中导入 safe_download 函数
from ultralytics.utils.downloads import safe_download
# 从 ultralytics.utils.files 模块中导入 file_size 函数
from ultralytics.utils.files import file_size
# 从 ultralytics.utils.torch_utils 模块中导入 select_device 函数
from ultralytics.utils.torch_utils import select_device
def benchmark(
    model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
    """
    Benchmark a YOLO model across different formats for speed and accuracy.

    This definition only declares the signature and documents the parameters; its body is a
    placeholder that does no work and returns None.

    Args:
        model (str | Path | optional): Path to the model file or directory. Default is
            WEIGHTS_DIR / 'yolov8n.pt'.
        data (str, optional): Dataset to evaluate on, inherited from TASK2DATA if not passed. Default is None.
        imgsz (int, optional): Image size for the benchmark. Default is 160.
        half (bool, optional): Use half-precision for the model if True. Default is False.
        int8 (bool, optional): Use int8-precision for the model if True. Default is False.
        device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
        verbose (bool | float | optional): If True or a float, assert benchmarks pass with given metric.
            Default is False.
    """
    # Placeholder body: the parameters are described in the docstring above and the real
    # implementation is meant to be developed from here.
    # NOTE: a later `def benchmark` in this file rebinds the same name.
    pass
def benchmark(model='yolov8n.pt', imgsz=640, data=None, device="cpu", verbose=False):
    """
    Build, log and optionally validate the benchmark summary table for a YOLO model.

    Args:
        model (str | Path | YOLO): Path to the model checkpoint, or an already constructed model object.
        imgsz (int): Image size reported in the benchmark summary.
        data (str, optional): Dataset name reported in the summary. When not passed it is looked up in
            TASK2DATA using the model's task. Default is None.
        device (str, optional): Device specifier forwarded to select_device(). Default is 'cpu'.
        verbose (bool | float, optional): When a float, it is the minimum metric value every non-NaN
            result must exceed. Default is False.

    Returns:
        df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
            metric, and inference time.

    Raises:
        AssertionError: If `verbose` is a float and any non-NaN metric is not greater than it.

    Example:
        ```python
        from ultralytics.utils.benchmarks import benchmark

        benchmark(model='yolov8n.pt', imgsz=640)
        ```
    """
    import pandas as pd  # function-scope import: pandas is only needed once a benchmark is run

    pd.options.display.max_columns = 10
    pd.options.display.width = 120

    device = select_device(device, verbose=False)
    if isinstance(model, (str, Path)):
        model = YOLO(model)

    # Flag read from the last layer of the model; it is not consumed by the code in this block.
    # NOTE(review): presumably used by a per-format export loop that is not present here - confirm.
    is_end2end = getattr(model.model.model[-1], "end2end", False)

    # Resolve the dataset name and the metric column from the model's task.
    # NOTE(review): assumes the model object exposes `.task`, as Ultralytics models do - confirm.
    data = data or TASK2DATA[model.task]
    key = TASK2METRIC[model.task]

    # NOTE(review): nothing in this block appends rows to `y`, so the table below is empty unless
    # the per-format export/validation step is added between here and the DataFrame construction.
    y = []
    t0 = time.time()

    check_yolo(device=device)
    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])

    name = Path(model.ckpt_path).name
    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
    LOGGER.info(s)
    # Append the same summary to a log file in the current working directory
    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
        f.write(s)

    # Only a float `verbose` acts as a pass/fail floor; True merely satisfies the truthiness check
    if verbose and isinstance(verbose, float):
        metrics = df[key].array
        floor = verbose
        # Kept as an assert so callers that expect AssertionError keep working
        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"

    return df
class RF100Benchmark:
"""Benchmark YOLO model performance across formats for speed and accuracy."""
def __init__(self):
    """Initialize RF100Benchmark with empty dataset lists, no Roboflow client and the metric names."""
    # Names of the datasets
    self.ds_names = []
    # Paths of the dataset configuration files
    self.ds_cfg_list = []
    # Roboflow client object; stays None until set_key() assigns it
    self.rf = None
    # Names of the validation metrics
    self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]
def set_key(self, api_key):
    """
    Set Roboflow API key for processing.

    Creates a Roboflow client with the given key and stores it on `self.rf`.

    Args:
        api_key (str): The API key.
    """
    # Called before the import below; presumably verifies the `roboflow` package is usable - confirm
    check_requirements("roboflow")
    # Function-scope import, so roboflow is only needed when this method is called
    from roboflow import Roboflow

    self.rf = Roboflow(api_key=api_key)
def parse_dataset(self, ds_link_txt="datasets_links.txt"):
    """
    Parse dataset links and download the datasets they point to.

    Recreates an empty `rf-100` directory, makes it the current working directory, fetches the
    links file and downloads every listed dataset through the Roboflow client stored in `self.rf`.
    Lines that cannot be parsed or downloaded are skipped with a warning.

    Args:
        ds_link_txt (str): Path to dataset_links file.

    Returns:
        (tuple): `(self.ds_names, self.ds_cfg_list)`, the dataset names and the matching `data.yaml`
            paths. An entry is added to both lists only after its dataset is available, so the
            two lists stay index-aligned.
    """
    # Always start from an empty rf-100 directory
    if os.path.exists("rf-100"):
        shutil.rmtree("rf-100")
    os.mkdir("rf-100")
    # NOTE: this changes the working directory of the whole process; the relative paths below
    # and the Path.cwd() based config paths depend on it.
    os.chdir("rf-100")
    os.mkdir("ultralytics-benchmarks")
    safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")

    with open(ds_link_txt, "r", encoding="utf-8") as file:
        for line in file:
            link = line.strip()
            if not link:
                continue  # blank line, nothing to parse

            # Splitting on runs of '/' must give exactly five fields:
            # scheme, host, workspace, project, version
            try:
                _, _host, workspace, project, version = re.split(r"/+", link)
            except ValueError:
                LOGGER.warning(f"Skipping malformed dataset link: {link!r}")
                continue

            proj_version = f"{project}-{version}"
            try:
                if not Path(proj_version).exists():
                    # "yolov8" is the argument passed to Roboflow's download();
                    # presumably the export format - confirm against the Roboflow docs
                    self.rf.workspace(workspace).project(project).version(version).download("yolov8")
                else:
                    print("Dataset already downloaded.")
            except Exception as e:
                # Best effort: one failing dataset must not abort the remaining downloads
                LOGGER.warning(f"Skipping dataset {proj_version}: download failed ({e})")
                continue

            self.ds_names.append(project)
            self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
    return self.ds_names, self.ds_cfg_list
@staticmethod
def fix_yaml(path):
    """
    Function to fix YAML train and val path.

    Loads the YAML file, sets its `train` and `val` entries to fixed relative image directories
    and writes the result back to the same file.

    Args:
        path (str): YAML file path.
    """
    with open(path, "r", encoding="utf-8") as file:
        # safe_load returns None for an empty document; fall back to an empty mapping
        yaml_data = yaml.safe_load(file) or {}
    yaml_data["train"] = "train/images"
    yaml_data["val"] = "valid/images"
    with open(path, "w", encoding="utf-8") as file:
        yaml.safe_dump(yaml_data, file)
def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
"""
Model evaluation on validation results.
Args:
yaml_path (str): YAML file path.
val_log_file (str): val_log_file path.
eval_log_file (str): eval_log_file path.
list_ind (int): Index for current dataset.
"""
# 定义跳过的符号列表,这些符号出现在日志行中时将被跳过
skip_symbols = ["