高级定制
Ultralytics YOLO 命令行接口和Python 接口都只是基础引擎执行器的高级抽象。让我们来看看 Trainer 引擎。
观看: 掌握Ultralytics YOLOv8 :高级定制
基础培训师
BaseTrainer 包含通用的模板培训例程。只要遵循正确的格式,就可以为任何任务定制所需的函数或操作。例如,只需重载这些函数,就能支持自己的自定义模型和数据加载器:
get_model(cfg, weights)
- 建立待训练模型的函数get_dataloader()
- 构建数据加载器的函数 更多详情和源代码请参见BaseTrainer
参考资料
探测训练器
以下是如何使用YOLOv8 DetectionTrainer
并进行定制。
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
自定义检测训练器
让我们定制培训师 来训练自定义检测模型 的重载。您只需重载现有的 get_model
功能性:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
您现在意识到,您需要进一步定制培训师:
- 自定义
loss function
. - 添加
callback
每隔 10 分钟就会将模型上传到您的 Google Driveepochs
具体方法如下
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
要了解有关回调触发事件和入口点的更多信息,请查阅我们的《回调指南》。
其他发动机部件
还有其他类似的可定制组件,如 Validators
和 Predictors
.有关更多信息,请参见参考资料部分。