Skip to content

Personnalisation avancée

Les interfaces Ultralytics YOLO en ligne de commande et Python sont simplement une abstraction de haut niveau sur les exécuteurs du moteur de base. Jetons un coup d'œil au moteur Trainer.



Regarde : Mastering Ultralytics YOLOv8 : Personnalisation avancée

Formateur de base

BaseTrainer contient la routine de formation générique. Elle peut être personnalisée pour n'importe quelle tâche en remplaçant les fonctions ou les opérations requises, tant que les formats corrects sont respectés. Par exemple, tu peux prendre en charge ton propre modèle et ton propre chargeur de données en remplaçant simplement ces fonctions :

  • get_model(cfg, weights) - La fonction qui construit le modèle Ă  entraĂ®ner.
  • get_dataloader() - La fonction qui construit le dataloader Plus de dĂ©tails et le code source peuvent ĂŞtre trouvĂ©s dans BaseTrainer RĂ©fĂ©rence

Formateur en détection

Voici comment tu peux utiliser le YOLOv8 DetectionTrainer et de la personnaliser.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Personnaliser le DetectionTrainer

Personnalisons le formateur pour former un modèle de détection personnalisé qui n'est pas pris en charge directement. Tu peux le faire en surchargeant simplement la fonction existante, la fonction get_model fonctionnalité :

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Tu réalises maintenant que tu dois personnaliser davantage le formateur pour :

  • Personnalise le loss function.
  • Ajouter callback qui tĂ©lĂ©charge un modèle sur ton Google Drive tous les 10 ans. epochs Voici comment tu peux le faire :
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Pour en savoir plus sur les événements déclencheurs et le point d'entrée des rappels, consulte notre guide sur les rappels.

Autres composants du moteur

Il existe d'autres composants qui peuvent être personnalisés de la même manière, comme par exemple Validators et Predictors. Voir la section Référence pour plus d'informations à ce sujet.



Créé le 2023-11-12, Mis à jour le 2024-02-03
Auteurs : glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Commentaires