Vai al contenuto

Personalizzazione avanzata

Entrambe le interfacce Ultralytics YOLO a riga di comando e Python sono semplicemente un'astrazione di alto livello sugli esecutori del motore di base. Diamo un'occhiata al motore Trainer.



Guarda: Mastering Ultralytics YOLOv8 : Personalizzazione avanzata

Allenatore di base

BaseTrainer contiene la routine di formazione generica. Può essere personalizzato per qualsiasi attività, sovrascrivendo le funzioni o le operazioni necessarie, purché vengano rispettati i formati corretti. Ad esempio, puoi supportare un modello e un dataloader personalizzati semplicemente sovrascrivendo queste funzioni:

  • get_model(cfg, weights) - La funzione che costruisce il modello da addestrare
  • get_dataloader() - La funzione che costruisce il dataloader Maggiori dettagli e codice sorgente sono disponibili in BaseTrainer Riferimento

Addestratore di rilevamento

Ecco come puoi utilizzare la funzione YOLOv8 DetectionTrainer e personalizzarlo.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Personalizzare il DetectionTrainer

Personalizziamo il trainer per addestrare un modello di rilevamento personalizzato che non è supportato direttamente. Puoi farlo semplicemente sovraccaricando la funzione esistente the get_model funzionalità:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Ora ti rendi conto che devi personalizzare ulteriormente il trainer:

  • Personalizza il sito loss function.
  • Aggiungi callback che carica il modello su Google Drive dopo ogni 10 epochs Ecco come puoi farlo:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Per saperne di più sugli eventi di attivazione delle Callback e sul punto di ingresso, consulta la nostra Guida alle Callback.

Altri componenti del motore

Ci sono altri componenti che possono essere personalizzati in modo simile, come ad esempio Validators e Predictors. Per ulteriori informazioni, consulta la sezione Riferimenti.



Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (6), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Commenti