Zum Inhalt springen

Erweiterte Anpassung

Sowohl die Kommandozeilenschnittstelle Ultralytics YOLO als auch die Schnittstelle Python sind lediglich eine Abstraktion der Basis-Engine-Executors. Werfen wir einen Blick auf die Trainer-Engine.



Pass auf: Mastering Ultralytics YOLOv8 : Erweiterte Anpassungen

BaseTrainer

Der BaseTrainer enthält die generische Trainingsroutine. Er kann für jede Aufgabe angepasst werden, indem die erforderlichen Funktionen oder Operationen überschrieben werden, solange die richtigen Formate eingehalten werden. Du kannst zum Beispiel dein eigenes benutzerdefiniertes Modell und deinen eigenen Dataloader unterstützen, indem du diese Funktionen einfach überschreibst:

  • get_model(cfg, weights) - Die Funktion, die das zu trainierende Modell erstellt
  • get_dataloader() - Die Funktion, die den Dataloader aufbaut Weitere Details und Quellcode findest du in BaseTrainer Referenz

DetectionTrainer

Hier erfährst du, wie du die YOLOv8 DetectionTrainer und passe sie an.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Anpassen des DetectionTrainers

Lass uns den Trainer anpassen um ein benutzerdefiniertes Erkennungsmodell zu trainieren die nicht direkt unterstützt wird. Du kannst dies tun, indem du einfach die bestehende Methode get_model Funktionalität:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Du merkst jetzt, dass du den Trainer weiter anpassen musst, um:

  • Anpassen der loss function.
  • hinzufügen callback das Modell nach jeweils 10 Minuten auf dein Google Drive hochlädt. epochs So kannst du es tun:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Mehr über Callback-Ereignisse und den Einstiegspunkt erfährst du in unserem Callbacks Guide

Andere Motorkomponenten

Es gibt andere Komponenten, die ähnlich angepasst werden können, wie Validators und Predictors. Weitere Informationen dazu findest du im Abschnitt Referenz.



Erstellt 2023-11-12, Aktualisiert 2024-02-03
Autoren: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Kommentare