Bỏ để qua phần nội dung

Tùy chỉnh nâng cao

Cả hai Ultralytics YOLO dòng lệnh và Python Giao diện chỉ đơn giản là một sự trừu tượng cấp cao trên các trình thực thi công cụ cơ sở. Chúng ta hãy nhìn vào động cơ Trainer.



Xem: Mastering Ultralytics YOLOv8: Tùy chỉnh nâng cao

BaseTrainer

BaseTrainer chứa thói quen đào tạo soạn sẵn chung. Nó có thể được tùy chỉnh cho bất kỳ tác vụ nào dựa trên việc ghi đè lên các chức năng hoặc hoạt động cần thiết miễn là tuân theo các định dạng chính xác. Ví dụ: bạn có thể hỗ trợ mô hình tùy chỉnh và bộ tải dữ liệu của riêng mình bằng cách ghi đè các chức năng sau:

  • get_model(cfg, weights) - Chức năng xây dựng mô hình cần đào tạo
  • get_dataloader() - Chức năng xây dựng bộ tải dữ liệu Thêm chi tiết và mã nguồn có thể được tìm thấy trong BaseTrainer Tham khảo

DetectionTrainer

Đây là cách bạn có thể sử dụng YOLOv8 DetectionTrainer và tùy chỉnh nó.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Tùy chỉnh DetectionTrainer

Hãy tùy chỉnh huấn luyện viên để đào tạo mô hình phát hiện tùy chỉnh không được hỗ trợ trực tiếp. Bạn có thể làm điều này bằng cách chỉ cần quá tải các get_model Chức năng:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Bây giờ bạn nhận ra rằng bạn cần tùy chỉnh huấn luyện viên hơn nữa để:

  • Tùy chỉnh loss function.
  • Thêm callback tải mô hình lên Google Drive của bạn sau mỗi 10 epochs Đây là cách bạn có thể làm điều đó:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Để biết thêm về các sự kiện kích hoạt Callback và điểm vào lệnh, hãy xem Hướng dẫn gọi lại của chúng tôi

Các thành phần động cơ khác

Có những thành phần khác có thể được tùy chỉnh tương tự như ValidatorsPredictors. Xem phần Tham khảo để biết thêm thông tin về những điều này.



Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (6), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Ý kiến