Overslaan naar inhoud

Geavanceerd aanpassen

Zowel de Ultralytics YOLO command-line en Python interfaces zijn eenvoudigweg een abstractie op hoog niveau van de basis engine executors. Laten we eens kijken naar de Trainer engine.



Kijken: Mastering Ultralytics YOLOv8 : Geavanceerd aanpassen

Basistrainer

BaseTrainer bevat de generieke boilerplate trainingsroutine. Deze kan voor elke taak worden aangepast door de vereiste functies of bewerkingen te overschrijven, zolang de juiste formaten worden gevolgd. Je kunt bijvoorbeeld je eigen aangepaste model en dataloader ondersteunen door deze functies te overschrijven:

  • get_model(cfg, weights) - De functie die het te trainen model bouwt
  • get_dataloader() - De functie die de dataloader bouwt Meer details en broncode zijn te vinden in BaseTrainer Referentie

Opsporingstrainer

Hier lees je hoe je de YOLOv8 DetectionTrainer en pas het aan.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

De Detectietrainer aanpassen

Laten we de trainer aanpassen om een aangepast detectiemodel te trainen die niet direct wordt ondersteund. Je kunt dit doen door eenvoudigweg de bestaande de get_model functionaliteit:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Je realiseert je nu dat je de trainer verder moet aanpassen:

  • Pas de loss function.
  • Voeg toe callback dat model uploadt naar je Google Drive na elke 10 epochs Hier lees je hoe je dat kunt doen:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Als je meer wilt weten over Callback-gebeurtenissen en ingangspunten, bekijk dan onze Callbacks-gids

Andere motoronderdelen

Er zijn andere onderdelen die op dezelfde manier kunnen worden aangepast, zoals Validators en Predictors. Zie het gedeelte Verwijzingen voor meer informatie hierover.

FAQ

Hoe pas ik de Ultralytics YOLOv8 Detectietrainer aan voor specifieke taken?

Om de Ultralytics YOLOv8 DetectionTrainer voor een specifieke taak, kun je de methoden ervan overschrijven om aan te passen aan je aangepaste model en dataloader. Begin met erven van DetectionTrainer en definieer dan methoden als get_model om je aangepaste functionaliteiten te implementeren. Hier is een voorbeeld:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Voor verdere aanpassingen zoals het veranderen van de loss function of een callbackkun je verwijzen naar onze Gids voor terugbellen.

Wat zijn de belangrijkste onderdelen van de BaseTrainer in Ultralytics YOLOv8 ?

De BaseTrainer in Ultralytics YOLOv8 dient als basis voor trainingsroutines en kan worden aangepast voor verschillende taken door de generieke methoden te overschrijven. De belangrijkste onderdelen zijn:

  • get_model(cfg, weights) om het te trainen model op te bouwen.
  • get_dataloader() om de dataloader te bouwen.

Voor meer details over de aanpassing en de broncode, zie de BaseTrainer Referentie.

Hoe kan ik een callback toevoegen aan de Ultralytics YOLOv8 Detectietrainer?

Je kunt callbacks toevoegen om het trainingsproces te controleren en aan te passen in Ultralytics YOLOv8 DetectionTrainer. Zo kun je bijvoorbeeld een callback toevoegen om modelgewichten na elke trainingsepoch te loggen:

from ultralytics.models.yolo.detect import DetectionTrainer


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

Raadpleeg onze Callbacks-gids voor meer informatie over callback-gebeurtenissen en ingangspunten.

Waarom zou ik Ultralytics YOLOv8 gebruiken voor modeltraining?

Ultralytics YOLOv8 biedt een abstractie op hoog niveau op krachtige engine executors, waardoor het ideaal is voor snelle ontwikkeling en aanpassing. De belangrijkste voordelen zijn:

  • Gebruiksgemak: Zowel commandoregel als Python interfaces vereenvoudigen complexe taken.
  • Prestaties: Geoptimaliseerd voor real-time objectdetectie en diverse AI-toepassingen met vision.
  • Aanpassing: Eenvoudig uitbreidbaar voor aangepaste modellen, verliesfuncties en dataloaders.

Bezoek YOLOv8 voor meer informatie over de mogelijkheden. Ultralytics YOLO.

Kan ik de Ultralytics YOLOv8 DetectionTrainer gebruiken voor niet-standaard modellen?

Ja, Ultralytics YOLOv8 DetectionTrainer is zeer flexibel en kan worden aangepast voor niet-standaard modellen. Door te erven van DetectionTrainerkun je verschillende methoden overbelasten om aan de behoeften van je specifieke model te voldoen. Hier is een eenvoudig voorbeeld:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model."""
        ...


trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

Bekijk de documentatie van Detectietrainer voor uitgebreidere instructies en voorbeelden.



Aangemaakt 2023-11-12, Bijgewerkt 2024-07-04
Auteurs: glenn-jocher (7), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Reacties