Geavanceerd aanpassen
Zowel de Ultralytics YOLO command-line en Python interfaces zijn eenvoudigweg een abstractie op hoog niveau van de basis engine executors. Laten we eens kijken naar de Trainer engine.
Kijken: Mastering Ultralytics YOLOv8 : Geavanceerd aanpassen
Basistrainer
BaseTrainer bevat de generieke boilerplate trainingsroutine. Deze kan voor elke taak worden aangepast door de vereiste functies of bewerkingen te overschrijven, zolang de juiste formaten worden gevolgd. Je kunt bijvoorbeeld je eigen aangepaste model en dataloader ondersteunen door deze functies te overschrijven:
get_model(cfg, weights)
- De functie die het te trainen model bouwtget_dataloader()
- De functie die de dataloader bouwt Meer details en broncode zijn te vinden inBaseTrainer
Referentie
Opsporingstrainer
Hier lees je hoe je de YOLOv8 DetectionTrainer
en pas het aan.
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
De Detectietrainer aanpassen
Laten we de trainer aanpassen om een aangepast detectiemodel te trainen die niet direct wordt ondersteund. Je kunt dit doen door eenvoudigweg de bestaande de get_model
functionaliteit:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
Je realiseert je nu dat je de trainer verder moet aanpassen:
- Pas de
loss function
. - Voeg toe
callback
dat model uploadt naar je Google Drive na elke 10epochs
Hier lees je hoe je dat kunt doen:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
Als je meer wilt weten over Callback-gebeurtenissen en ingangspunten, bekijk dan onze Callbacks-gids
Andere motoronderdelen
Er zijn andere onderdelen die op dezelfde manier kunnen worden aangepast, zoals Validators
en Predictors
. Zie het gedeelte Verwijzingen voor meer informatie hierover.
Aangemaakt 2023-11-12, Bijgewerkt 2024-02-03
Auteurs: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)