Zum Inhalt springen

Erweiterte Anpassung

Sowohl die Kommandozeilenschnittstelle Ultralytics YOLO als auch die Schnittstelle Python sind lediglich eine Abstraktion der Basis-Engine-Executors. Werfen wir einen Blick auf die Trainer-Engine.



Pass auf: Mastering Ultralytics YOLOv8 : Erweiterte Anpassungen

BaseTrainer

Der BaseTrainer enth├Ąlt die generische Trainingsroutine. Er kann f├╝r jede Aufgabe angepasst werden, indem die erforderlichen Funktionen oder Operationen ├╝berschrieben werden, solange die richtigen Formate eingehalten werden. Du kannst zum Beispiel dein eigenes benutzerdefiniertes Modell und deinen eigenen Dataloader unterst├╝tzen, indem du diese Funktionen einfach ├╝berschreibst:

  • get_model(cfg, weights) - Die Funktion, die das zu trainierende Modell erstellt
  • get_dataloader() - Die Funktion, die den Dataloader aufbaut Weitere Details und Quellcode findest du in BaseTrainer Referenz

DetectionTrainer

Hier erf├Ąhrst du, wie du die YOLOv8 DetectionTrainer und passe sie an.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Anpassen des DetectionTrainers

Lass uns den Trainer anpassen um ein benutzerdefiniertes Erkennungsmodell zu trainieren die nicht direkt unterst├╝tzt wird. Du kannst dies tun, indem du einfach die bestehende Methode get_model Funktionalit├Ąt:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Du merkst jetzt, dass du den Trainer weiter anpassen musst, um:

  • Anpassen der loss function.
  • hinzuf├╝gen callback das Modell nach jeweils 10 Minuten auf dein Google Drive hochl├Ądt. epochs So kannst du es tun:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Mehr ├╝ber Callback-Ereignisse und den Einstiegspunkt erf├Ąhrst du in unserem Callbacks Guide

Andere Motorkomponenten

Es gibt andere Komponenten, die ├Ąhnlich angepasst werden k├Ânnen, wie Validators und Predictors. Weitere Informationen dazu findest du im Abschnitt Referenz.



Erstellt 2023-11-12, Aktualisiert 2024-02-03
Autoren: glenn-jocher (4), chr043416@gmail.com (1), AyushExel (1), Laughing-q (1)

Kommentare