Personalizzazione avanzata
Entrambe le interfacce Ultralytics YOLO a riga di comando e Python sono semplicemente un'astrazione di alto livello sugli esecutori del motore di base. Diamo un'occhiata al motore Trainer.
Guarda: Mastering Ultralytics YOLOv8 : Personalizzazione avanzata
Allenatore di base
BaseTrainer contiene la routine di formazione generica. Può essere personalizzato per qualsiasi attività , sovrascrivendo le funzioni o le operazioni necessarie, purché vengano rispettati i formati corretti. Ad esempio, puoi supportare un modello e un dataloader personalizzati semplicemente sovrascrivendo queste funzioni:
get_model(cfg, weights)
- La funzione che costruisce il modello da addestrareget_dataloader()
- La funzione che costruisce il dataloader Maggiori dettagli e codice sorgente sono disponibili inBaseTrainer
Riferimento
Addestratore di rilevamento
Ecco come puoi utilizzare la funzione YOLOv8 DetectionTrainer
e personalizzarlo.
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
Personalizzare il DetectionTrainer
Personalizziamo il trainer per addestrare un modello di rilevamento personalizzato che non è supportato direttamente. Puoi farlo semplicemente sovraccaricando la funzione esistente the get_model
funzionalità :
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
Ora ti rendi conto che devi personalizzare ulteriormente il trainer:
- Personalizza il sito
loss function
. - Aggiungi
callback
che carica il modello su Google Drive dopo ogni 10epochs
Ecco come puoi farlo:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
Per saperne di più sugli eventi di attivazione delle Callback e sul punto di ingresso, consulta la nostra Guida alle Callback.
Altri componenti del motore
Ci sono altri componenti che possono essere personalizzati in modo simile, come ad esempio Validators
e Predictors
. Per ulteriori informazioni, consulta la sezione Riferimenti.
Creato 2023-11-12, Aggiornato 2024-02-03
Autori: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)