Overslaan naar inhoud

Geavanceerd aanpassen

Zowel de Ultralytics YOLO command-line en Python interfaces zijn eenvoudigweg een abstractie op hoog niveau van de basis engine executors. Laten we eens kijken naar de Trainer engine.



Kijken: Mastering Ultralytics YOLOv8 : Geavanceerd aanpassen

Basistrainer

BaseTrainer bevat de generieke boilerplate trainingsroutine. Deze kan voor elke taak worden aangepast door de vereiste functies of bewerkingen te overschrijven, zolang de juiste formaten worden gevolgd. Je kunt bijvoorbeeld je eigen aangepaste model en dataloader ondersteunen door deze functies te overschrijven:

  • get_model(cfg, weights) - De functie die het te trainen model bouwt
  • get_dataloader() - De functie die de dataloader bouwt Meer details en broncode zijn te vinden in BaseTrainer Referentie

Opsporingstrainer

Hier lees je hoe je de YOLOv8 DetectionTrainer en pas het aan.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

De Detectietrainer aanpassen

Laten we de trainer aanpassen om een aangepast detectiemodel te trainen die niet direct wordt ondersteund. Je kunt dit doen door eenvoudigweg de bestaande de get_model functionaliteit:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Je realiseert je nu dat je de trainer verder moet aanpassen:

  • Pas de loss function.
  • Voeg toe callback dat model uploadt naar je Google Drive na elke 10 epochs Hier lees je hoe je dat kunt doen:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Als je meer wilt weten over Callback-gebeurtenissen en ingangspunten, bekijk dan onze Callbacks-gids

Andere motoronderdelen

Er zijn andere onderdelen die op dezelfde manier kunnen worden aangepast, zoals Validators en Predictors. Zie het gedeelte Verwijzingen voor meer informatie hierover.



Aangemaakt 2023-11-12, Bijgewerkt 2024-02-03
Auteurs: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Reacties