Bỏ để qua phần nội dung

Tùy chỉnh nâng cao

Cả hai Ultralytics YOLO dòng lệnh và Python Giao diện chỉ đơn giản là một sự trừu tượng cấp cao trên các trình thực thi công cụ cơ sở. Chúng ta hãy nhìn vào động cơ Trainer.



Xem: Mastering Ultralytics YOLOv8: Tùy chỉnh nâng cao

BaseTrainer

BaseTrainer chứa thói quen đào tạo soạn sẵn chung. Nó có thể được tùy chỉnh cho bất kỳ tác vụ nào dựa trên việc ghi đè lên các chức năng hoặc hoạt động cần thiết miễn là tuân theo các định dạng chính xác. Ví dụ: bạn có thể hỗ trợ mô hình tùy chỉnh và bộ tải dữ liệu của riêng mình bằng cách ghi đè các chức năng sau:

  • get_model(cfg, weights) - Chức năng xây dựng mô hình cần đào tạo
  • get_dataloader() - Chức năng xây dựng bộ tải dữ liệu Thêm chi tiết và mã nguồn có thể được tìm thấy trong BaseTrainer Tham khảo

DetectionTrainer

Đây là cách bạn có thể sử dụng YOLOv8 DetectionTrainer và tùy chỉnh nó.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Tùy chỉnh DetectionTrainer

Hãy tùy chỉnh huấn luyện viên để đào tạo mô hình phát hiện tùy chỉnh không được hỗ trợ trực tiếp. Bạn có thể làm điều này bằng cách chỉ cần quá tải các get_model Chức năng:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Bây giờ bạn nhận ra rằng bạn cần tùy chỉnh huấn luyện viên hơn nữa để:

  • Tùy chỉnh loss function.
  • Thêm callback tải mô hình lên Google Lái xe sau mỗi 10 epochs Đây là cách bạn có thể làm điều đó:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Để biết thêm về các sự kiện kích hoạt Callback và điểm vào lệnh, hãy xem Hướng dẫn gọi lại của chúng tôi

Các thành phần động cơ khác

Có những thành phần khác có thể được tùy chỉnh tương tự như ValidatorsPredictors. Xem phần Tham khảo để biết thêm thông tin về những điều này.

FAQ

Làm cách nào để tùy chỉnh Ultralytics YOLOv8 DetectionTrainer cho các tác vụ cụ thể?

Để tùy chỉnh Ultralytics YOLOv8 DetectionTrainer Đối với một tác vụ cụ thể, bạn có thể ghi đè các phương thức của nó để thích ứng với mô hình tùy chỉnh và bộ tải dữ liệu của bạn. Bắt đầu bằng cách kế thừa từ DetectionTrainer và sau đó định nghĩa lại các phương thức như get_model để triển khai các chức năng tùy chỉnh của bạn. Đây là một ví dụ:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Để tùy chỉnh thêm như thay đổi loss function hoặc thêm một callback, bạn có thể tham khảo Hướng dẫn gọi lại.

Các thành phần chính của BaseTrainer trong là gì Ultralytics YOLOv8?

Các BaseTrainer trong Ultralytics YOLOv8 đóng vai trò là nền tảng cho các thói quen đào tạo và có thể được tùy chỉnh cho các nhiệm vụ khác nhau bằng cách ghi đè lên các phương pháp chung của nó. Các thành phần chính bao gồm:

  • get_model(cfg, weights) để xây dựng mô hình cần đào tạo.
  • get_dataloader() để xây dựng bộ tải dữ liệu.

Để biết thêm chi tiết về tùy chỉnh và mã nguồn, hãy xem BaseTrainer Tham khảo.

Làm cách nào để thêm callback vào Ultralytics YOLOv8 Phát hiệnTrainer?

Bạn có thể thêm callback để theo dõi và sửa đổi quá trình đào tạo trong Ultralytics YOLOv8 DetectionTrainer. Ví dụ: đây là cách bạn có thể thêm lệnh gọi lại để ghi lại trọng số mô hình sau mỗi kỷ nguyên đào tạo:

from ultralytics.models.yolo.detect import DetectionTrainer


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

Để biết thêm chi tiết về các sự kiện gọi lại và điểm vào lệnh, hãy tham khảo Hướng dẫn gọi lại của chúng tôi.

Tại sao tôi nên sử dụng Ultralytics YOLOv8 để đào tạo mô hình?

Ultralytics YOLOv8 Cung cấp một sự trừu tượng cấp cao trên các trình thực thi động cơ mạnh mẽ, làm cho nó trở nên lý tưởng để phát triển và tùy chỉnh nhanh chóng. Các lợi ích chính bao gồm:

  • Dễ sử dụng: Cả dòng lệnh và Python Giao diện đơn giản hóa các tác vụ phức tạp.
  • Hiệu suất: Được tối ưu hóa để phát hiện đối tượng theo thời gian thực và các ứng dụng AI trực quan khác nhau.
  • Tùy chỉnh: Dễ dàng mở rộng cho các mô hình tùy chỉnh, chức năng mất mát và bộ tải dữ liệu.

Tìm hiểu thêm về YOLOv8Khả năng của bằng cách truy cập Ultralytics YOLO.

Tôi có thể sử dụng Ultralytics YOLOv8 DetectionTrainer cho các mô hình không chuẩn?

Có Ultralytics YOLOv8 DetectionTrainer có tính linh hoạt cao và có thể được tùy chỉnh cho các mô hình không chuẩn. Bằng cách kế thừa từ DetectionTrainer, bạn có thể quá tải các phương pháp khác nhau để hỗ trợ nhu cầu của mô hình cụ thể của bạn. Đây là một ví dụ đơn giản:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model."""
        ...


trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

Để biết hướng dẫn và ví dụ toàn diện hơn, hãy xem lại tài liệu DetectionTrainer .



Đã tạo 2023-11-12, Cập nhật 2024-07-04
Tác giả: glenn-jocher (7), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Ý kiến