Personalizaci贸n avanzada
Tanto la interfaz de l铆nea de comandos Ultralytics YOLO como Python son simplemente una abstracci贸n de alto nivel sobre los ejecutores del motor base. Echemos un vistazo al motor Entrenador.
Observa: Dominio de Ultralytics YOLOv8 : Personalizaci贸n avanzada
BaseTrainer
BaseTrainer contiene la rutina de entrenamiento gen茅rica. Se puede personalizar para cualquier tarea sobreescribiendo las funciones u operaciones necesarias, siempre que se sigan los formatos correctos. Por ejemplo, puedes utilizar tu propio modelo y cargador de datos personalizados simplemente sobreescribiendo estas funciones:
get_model(cfg, weights)
- La funci贸n que construye el modelo a entrenarget_dataloader()
- La funci贸n que construye el cargador de datos Puedes encontrar m谩s detalles y el c贸digo fuente enBaseTrainer
Referencia
Detecci贸nEntrenador
A continuaci贸n te explicamos c贸mo puedes utilizar la YOLOv8 DetectionTrainer
y personal铆zalo.
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
Personalizar el DetectionTrainer
Personalicemos el entrenador para entrenar un modelo de detecci贸n personalizado que no se admite directamente. Puedes hacerlo simplemente sobrecargando la funci贸n existente get_model
funcionalidad:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
Ahora te das cuenta de que necesitas personalizar m谩s el entrenador para:
- Personaliza el
loss function
. - A帽ade
callback
que sube modelos a tu Google Drive cada 10epochs
He aqu铆 c贸mo puedes hacerlo:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
Para saber m谩s sobre los eventos desencadenantes de devoluci贸n de llamada y el punto de entrada, consulta nuestra Gu铆a de Devoluciones de llamada
Otros componentes del motor
Hay otros componentes que se pueden personalizar de forma similar, como Validators
y Predictors
. Consulta la secci贸n Referencia para obtener m谩s informaci贸n al respecto.
Creado 2023-11-12, Actualizado 2024-02-03
Autores: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)