Erweiterte Anpassung
Sowohl die Kommandozeilenschnittstelle Ultralytics YOLO als auch die Schnittstelle Python sind lediglich eine Abstraktion der Basis-Engine-Executors. Werfen wir einen Blick auf die Trainer-Engine.
Pass auf: Mastering Ultralytics YOLOv8 : Erweiterte Anpassungen
BaseTrainer
Der BaseTrainer enthält die generische Trainingsroutine. Er kann für jede Aufgabe angepasst werden, indem die erforderlichen Funktionen oder Operationen überschrieben werden, solange die richtigen Formate eingehalten werden. Du kannst zum Beispiel dein eigenes benutzerdefiniertes Modell und deinen eigenen Dataloader unterstützen, indem du diese Funktionen einfach überschreibst:
get_model(cfg, weights)
- Die Funktion, die das zu trainierende Modell erstelltget_dataloader()
- Die Funktion, die den Dataloader aufbaut Weitere Details und Quellcode findest du inBaseTrainer
Referenz
DetectionTrainer
Hier erfährst du, wie du die YOLOv8 DetectionTrainer
und passe sie an.
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
Anpassen des DetectionTrainers
Lass uns den Trainer anpassen um ein benutzerdefiniertes Erkennungsmodell zu trainieren die nicht direkt unterstützt wird. Du kannst dies tun, indem du einfach die bestehende Methode get_model
Funktionalität:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
Du merkst jetzt, dass du den Trainer weiter anpassen musst, um:
- Anpassen der
loss function
. - hinzufügen
callback
das Modell nach jeweils 10 Minuten auf dein Google Drive hochlädt.epochs
So kannst du es tun:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
Mehr über Callback-Ereignisse und den Einstiegspunkt erfährst du in unserem Callbacks Guide
Andere Motorkomponenten
Es gibt andere Komponenten, die ähnlich angepasst werden können, wie Validators
und Predictors
. Weitere Informationen dazu findest du im Abschnitt Referenz.
Erstellt 2023-11-12, Aktualisiert 2024-02-03
Autoren: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)