Yolov8 源码解析(四十)

2024-09-13 17:19:43 浏览数 (1)

.yolov8ultralyticsutilsbenchmarks.py

代码语言:javascript复制
# 从 glob 模块中导入 glob 函数,用于文件路径的模糊匹配
import glob
# 导入 os 模块,提供了许多与操作系统交互的函数
import os
# 导入 platform 模块,用于获取系统平台信息
import platform
# 导入 re 模块,支持正则表达式操作
import re
# 导入 shutil 模块,提供了高级的文件操作功能
import shutil
# 导入 time 模块,提供时间相关的功能
import time
# 从 pathlib 模块中导入 Path 类,用于操作文件路径
from pathlib import Path

# 导入 numpy 库,用于数值计算
import numpy as np
# 导入 torch.cuda 模块,用于 CUDA 相关操作
import torch.cuda
# 导入 yaml 库,用于处理 YAML 格式的文件
import yaml

# 从 ultralytics 包中导入 YOLO 和 YOLOWorld 类
from ultralytics import YOLO, YOLOWorld
# 从 ultralytics.cfg 模块中导入 TASK2DATA 和 TASK2METRIC 变量
from ultralytics.cfg import TASK2DATA, TASK2METRIC
# 从 ultralytics.engine.exporter 模块中导入 export_formats 函数
from ultralytics.engine.exporter import export_formats
# 从 ultralytics.utils 模块中导入 ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI 等变量
from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
# 从 ultralytics.utils.checks 模块中导入 IS_PYTHON_3_12, check_requirements, check_yolo 等函数和变量
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
# 从 ultralytics.utils.downloads 模块中导入 safe_download 函数
from ultralytics.utils.downloads import safe_download
# 从 ultralytics.utils.files 模块中导入 file_size 函数
from ultralytics.utils.files import file_size
# 从 ultralytics.utils.torch_utils 模块中导入 select_device 函数
from ultralytics.utils.torch_utils import select_device


def benchmark(
    model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
    """
    Benchmark a YOLO model across different formats for speed and accuracy.

    Args:
        model (str | Path | optional): Path to the model file or directory. Default is
            Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
        data (str, optional): Dataset to evaluate on, inherited from TASK2DATA if not passed. Default is None.
        imgsz (int, optional): Image size for the benchmark. Default is 160.
        half (bool, optional): Use half-precision for the model if True. Default is False.
        int8 (bool, optional): Use int8-precision for the model if True. Default is False.
        device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
        verbose (bool | float | optional): If True or a float, assert benchmarks pass with given metric.
            Default is False.
    """
    # 函数主体,用于评估 YOLO 模型在不同格式下的速度和准确性,参数详细说明在函数文档字符串中给出
    pass  # 这里是示例,实际代码会在此基础上继续开发
    def benchmark(model='yolov8n.pt', imgsz=640):
        """
        Benchmark function to evaluate model performance.
    
        Args:
            model (str or Path): Path to the model checkpoint.
            imgsz (int): Image size for inference.
    
        Returns:
            df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
                metric, and inference time.
    
        Example:
            ```python
            from ultralytics.utils.benchmarks import benchmark
    
            benchmark(model='yolov8n.pt', imgsz=640)
            ```
        """
        import pandas as pd  # Import pandas library for DataFrame operations
        pd.options.display.max_columns = 10  # Set maximum display columns in pandas DataFrame
        pd.options.display.width = 120  # Set display width for pandas DataFrame
    
        device = select_device(device, verbose=False)  # Select device for model inference
        if isinstance(model, (str, Path)):
            model = YOLO(model)  # Initialize YOLO model if model is given as a string or Path
    
        is_end2end = getattr(model.model.model[-1], "end2end", False)  # Check if model supports end-to-end inference
    
        y = []  # Initialize an empty list to store benchmark results
        t0 = time.time()  # Record current time for benchmarking purposes
    
        check_yolo(device=device)  # Print system information relevant to YOLO
    
        # Create a pandas DataFrame 'df' with columns defined for benchmark results
        df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
    
        name = Path(model.ckpt_path).name  # Extract the name of the model checkpoint file
        # Construct a string 's' summarizing benchmark results and logging information
        s = f"nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)n{df}n"
        LOGGER.info(s)  # Log 's' to the logger file
    
        with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
            f.write(s)  # Append string 's' to the 'benchmarks.log' file
    
        if verbose and isinstance(verbose, float):
            metrics = df[key].array  # Extract the 'key' column values from the DataFrame 'df'
            floor = verbose  # Set the minimum metric floor to compare against
            # Assert that all metrics are greater than 'floor' if they are not NaN
            assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
    
        return df  # Return the pandas DataFrame 'df' containing benchmark results
class RF100Benchmark:
    """Benchmark YOLO model performance across formats for speed and accuracy."""

    def __init__(self):
        """Function for initialization of RF100Benchmark."""
        # 初始化空列表,用于存储数据集名称
        self.ds_names = []
        # 初始化空列表,用于存储数据集配置文件路径
        self.ds_cfg_list = []
        # 初始化 RF 对象为 None
        self.rf = None
        # 定义验证指标列表
        self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]

    def set_key(self, api_key):
        """
        Set Roboflow API key for processing.

        Args:
            api_key (str): The API key.
        """
        # 检查是否满足 Roboflow 相关的依赖
        check_requirements("roboflow")
        # 导入 Roboflow 模块
        from roboflow import Roboflow
        # 创建 Roboflow 对象并设置 API 密钥
        self.rf = Roboflow(api_key=api_key)

    def parse_dataset(self, ds_link_txt="datasets_links.txt"):
        """
        Parse dataset links and downloads datasets.

        Args:
            ds_link_txt (str): Path to dataset_links file.
        """
        # 如果存在 rf-100 目录,则删除并重新创建;否则直接创建
        (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
        # 切换当前工作目录至 rf-100
        os.chdir("rf-100")
        # 在 rf-100 目录下创建 ultralytics-benchmarks 目录
        os.mkdir("ultralytics-benchmarks")
        # 安全下载 datasets_links.txt 文件
        safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")

        # 打开数据集链接文件,逐行处理
        with open(ds_link_txt, "r") as file:
            for line in file:
                try:
                    # 使用正则表达式拆分数据集链接
                    _, url, workspace, project, version = re.split("/ ", line.strip())
                    # 将项目名称添加到数据集名称列表
                    self.ds_names.append(project)
                    # 组合项目和版本信息
                    proj_version = f"{project}-{version}"
                    # 如果该版本数据集尚未下载,则使用 Roboflow 对象下载到 yolov8 目录下
                    if not Path(proj_version).exists():
                        self.rf.workspace(workspace).project(project).version(version).download("yolov8")
                    else:
                        print("Dataset already downloaded.")
                    # 添加数据集配置文件路径到列表中
                    self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
                except Exception:
                    continue

        return self.ds_names, self.ds_cfg_list

    @staticmethod
    def fix_yaml(path):
        """
        Function to fix YAML train and val path.

        Args:
            path (str): YAML file path.
        """
        # 使用安全加载方式读取 YAML 文件
        with open(path, "r") as file:
            yaml_data = yaml.safe_load(file)
        # 修改 YAML 文件中的训练和验证路径
        yaml_data["train"] = "train/images"
        yaml_data["val"] = "valid/images"
        # 使用安全写入方式将修改后的 YAML 数据写回文件
        with open(path, "w") as file:
            yaml.safe_dump(yaml_data, file)
    def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
        """
        Model evaluation on validation results.

        Args:
            yaml_path (str): YAML file path.
            val_log_file (str): val_log_file path.
            eval_log_file (str): eval_log_file path.
            list_ind (int): Index for current dataset.
        """
        # 定义跳过的符号列表,这些符号出现在日志行中时将被跳过
        skip_symbols = ["


	

0 人点赞