Saltar al contenido

Personalizaci贸n avanzada

Tanto la interfaz de l铆nea de comandos Ultralytics YOLO como Python son simplemente una abstracci贸n de alto nivel sobre los ejecutores del motor base. Echemos un vistazo al motor Entrenador.



Observa: Dominio de Ultralytics YOLOv8 : Personalizaci贸n avanzada

BaseTrainer

BaseTrainer contiene la rutina de entrenamiento gen茅rica. Se puede personalizar para cualquier tarea sobreescribiendo las funciones u operaciones necesarias, siempre que se sigan los formatos correctos. Por ejemplo, puedes utilizar tu propio modelo y cargador de datos personalizados simplemente sobreescribiendo estas funciones:

  • get_model(cfg, weights) - La funci贸n que construye el modelo a entrenar
  • get_dataloader() - La funci贸n que construye el cargador de datos Puedes encontrar m谩s detalles y el c贸digo fuente en BaseTrainer Referencia

Detecci贸nEntrenador

A continuaci贸n te explicamos c贸mo puedes utilizar la YOLOv8 DetectionTrainer y personal铆zalo.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Personalizar el DetectionTrainer

Personalicemos el entrenador para entrenar un modelo de detecci贸n personalizado que no se admite directamente. Puedes hacerlo simplemente sobrecargando la funci贸n existente get_model funcionalidad:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Ahora te das cuenta de que necesitas personalizar m谩s el entrenador para:

  • Personaliza el loss function.
  • A帽ade callback que sube modelos a tu Google Drive cada 10 epochs He aqu铆 c贸mo puedes hacerlo:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Para saber m谩s sobre los eventos desencadenantes de devoluci贸n de llamada y el punto de entrada, consulta nuestra Gu铆a de Devoluciones de llamada

Otros componentes del motor

Hay otros componentes que se pueden personalizar de forma similar, como Validators y Predictors. Consulta la secci贸n Referencia para obtener m谩s informaci贸n al respecto.



Creado 2023-11-12, Actualizado 2024-05-03
Autores: glenn-jocher (5), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Comentarios