Reference for ultralytics/models/yolo/segment/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/segment/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.segment.train.SegmentationTrainer
SegmentationTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: DetectionTrainer
A class extending DetectionTrainer for training segmentation models.
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific functionality including model initialization, validation, and visualization.
Attributes:
Name | Type | Description |
---|---|---|
loss_names | Tuple[str] | Names of the loss components used during training. |
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
>>> trainer = SegmentationTrainer(overrides=args)
>>> trainer.train()
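The class description says the segmentation trainer reuses the detection trainer and swaps in segmentation-specific model initialization, validation, and visualization. The stdlib sketch below illustrates that override-the-hooks pattern. `ToyDetectionTrainer`, `ToySegmentationTrainer`, and `describe_pipeline` are hypothetical names, not ultralytics classes or methods; the shared code only calls the hooks, so the subclass changes behavior without touching it.

```python
# Hypothetical sketch of the hook pattern; ToyDetectionTrainer and
# ToySegmentationTrainer are illustrative names, not ultralytics classes.
from typing import Any, Dict, Optional


class ToyDetectionTrainer:
    """Base trainer: the shared code only calls overridable hooks."""

    def __init__(self, overrides: Optional[Dict[str, Any]] = None) -> None:
        self.args = dict(overrides or {})

    def get_model(self, cfg=None, weights=None, verbose=True) -> str:
        return f"DetectionModel(cfg={cfg})"

    def get_validator(self) -> str:
        return "DetectionValidator"

    def describe_pipeline(self) -> Dict[str, str]:
        # Task-agnostic code: it never checks which task it is running.
        return {"model": self.get_model(cfg=self.args.get("model")), "validator": self.get_validator()}


class ToySegmentationTrainer(ToyDetectionTrainer):
    """Overrides only the task-specific hooks; everything else is inherited."""

    def get_model(self, cfg=None, weights=None, verbose=True) -> str:
        return f"SegmentationModel(cfg={cfg})"

    def get_validator(self) -> str:
        return "SegmentationValidator"


print(ToySegmentationTrainer({"model": "yolo11n-seg.yaml"}).describe_pipeline())
```

The design choice is that only the task-specific pieces live in the subclass, so the training loop, optimizer setup, and logging stay shared across tasks.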
This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Configuration dictionary with default training settings. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
overrides | dict | Dictionary of parameter overrides for the default configuration. | None |
_callbacks | list | List of callback functions to be executed during training. | None |
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
>>> args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
>>> trainer = SegmentationTrainer(overrides=args)
>>> trainer.train()
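The constructor merges `overrides` on top of `cfg` and, as described above, sets the task to 'segment'. A minimal stdlib sketch of that merge, assuming plain dictionaries: `DEFAULT_CFG_SKETCH` and `build_segment_args` are hypothetical stand-ins (the real defaults ship with the library), and rejecting unknown keys is an assumption added for illustration.

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for DEFAULT_CFG; the real defaults ship with the library.
DEFAULT_CFG_SKETCH: Dict[str, Any] = {"task": "detect", "model": None, "data": None, "epochs": 100, "imgsz": 640}


def build_segment_args(
    cfg: Optional[Dict[str, Any]] = None, overrides: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Merge overrides on top of the defaults and pin the task to 'segment' (sketch)."""
    base = DEFAULT_CFG_SKETCH if cfg is None else cfg
    merged = dict(overrides or {})  # copy so the caller's dict is left untouched
    merged["task"] = "segment"  # the segmentation trainer always forces this task
    unknown = set(merged) - set(base)
    if unknown:  # assumption: unknown keys are rejected rather than silently kept
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**base, **merged}


args = build_segment_args(overrides=dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3))
print(args["task"], args["epochs"], args["imgsz"])
```

Keys the caller does not override (here `imgsz`) keep their default values, and any `task` value passed in is replaced by 'segment'.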
get_model
get_model(cfg=None, weights=None, verbose=True)
Initialize and return a SegmentationModel with specified configuration and weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict \| str \| None | Model configuration. Can be a dictionary, a path to a YAML file, or None. | None |
weights | str \| Path \| None | Path to pretrained weights file. | None |
verbose | bool | Whether to display model information during initialization. | True |
Returns:
Type | Description |
---|---|
SegmentationModel | Initialized segmentation model with loaded weights if specified. |
Examples:
>>> trainer = SegmentationTrainer()
>>> model = trainer.get_model(cfg="yolov8n-seg.yaml")
>>> model = trainer.get_model(weights="yolov8n-seg.pt", verbose=False)
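The control flow of `get_model` can be sketched without the library: build the model from `cfg`, then load weights only when a path is given. This is a hypothetical stand-in, not the ultralytics code. `ToySegmentationModel` and `get_model_sketch` are illustrative names, and two details are assumptions based on our reading of recent library versions that you should verify against your install: the class count comes from the dataset description (`data["nc"]`), and the model summary is printed only when not running distributed training (rank -1).

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Union


@dataclass
class ToySegmentationModel:
    """Hypothetical stand-in for SegmentationModel that records how it was built."""

    cfg: Union[Dict[str, Any], str, None]
    nc: int
    verbose: bool
    loaded_from: Optional[str] = None

    def load(self, weights: Union[str, Path]) -> "ToySegmentationModel":
        self.loaded_from = str(weights)  # a real model would copy matching tensors here
        return self


def get_model_sketch(
    data: Dict[str, Any],
    cfg: Union[Dict[str, Any], str, None] = None,
    weights: Union[str, Path, None] = None,
    verbose: bool = True,
    rank: int = -1,
) -> ToySegmentationModel:
    # Assumption: the class count comes from the dataset description, not the model cfg.
    # Assumption: the summary prints only outside distributed training (rank -1).
    model = ToySegmentationModel(cfg=cfg, nc=data["nc"], verbose=verbose and rank == -1)
    if weights:  # pretrained weights are optional
        model.load(weights)
    return model


model = get_model_sketch({"nc": 80}, cfg="yolo11n-seg.yaml", weights="yolo11n-seg.pt")
print(model)
```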
get_validator
get_validator()
Return a SegmentationValidator instance for validating the YOLO segmentation model.
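The `loss_names` attribute names each component of the training loss so per-component values can be logged under readable keys. The helper below is a hypothetical sketch of that labeling step. The four names are an assumption (check `trainer.loss_names` in your installed version), and the `train/` and `val/` prefixes are illustrative.

```python
from typing import Dict, List, Optional, Sequence, Union

# Assumed segmentation loss components; confirm against trainer.loss_names in your install.
SEG_LOSS_NAMES = ("box_loss", "seg_loss", "cls_loss", "dfl_loss")


def label_losses(
    values: Optional[Sequence[float]] = None,
    names: Sequence[str] = SEG_LOSS_NAMES,
    prefix: str = "train",
) -> Union[List[str], Dict[str, float]]:
    """Return prefixed loss keys, or a key -> value dict when values are given (hypothetical helper)."""
    keys = [f"{prefix}/{name}" for name in names]
    if values is None:
        return keys  # e.g. column headers for a log file
    if len(values) != len(keys):
        raise ValueError(f"expected {len(keys)} loss values, got {len(values)}")
    return {key: round(float(v), 5) for key, v in zip(keys, values)}


print(label_losses())
print(label_losses([1.2, 2.5, 0.8, 1.1], prefix="val"))
```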
plot_metrics
plot_metrics()
Plot training and validation metrics.
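Plotting per-epoch metrics starts by gathering one series per metric from the run's log. The stdlib sketch below shows only that data-gathering step, under two labeled assumptions: the results are logged to a CSV, and mask metrics use column names ending in "(M)" (box metrics in "(B)"), such as `metrics/mAP50(M)`. `mask_metric_series` is a hypothetical helper, and the log string is synthetic, not real training output.

```python
import csv
import io
from typing import Dict, List


def mask_metric_series(csv_text: str) -> Dict[str, List[float]]:
    """Collect per-epoch values for every column whose name ends in '(M)'.

    Assumption: mask metrics are logged under names such as 'metrics/mAP50(M)',
    while box metrics end in '(B)'.
    """
    series: Dict[str, List[float]] = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        for key, value in row.items():
            name = key.strip()  # tolerate headers padded with spaces
            if name.endswith("(M)"):
                series.setdefault(name, []).append(float(value))
    return series


# Synthetic two-epoch log, for illustration only.
log = "epoch,metrics/mAP50(B),metrics/mAP50(M)\n1,0.50,0.40\n2,0.60,0.45\n"
print(mask_metric_series(log))
```

Each resulting list can then be drawn as one line per metric against the epoch index.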
plot_training_samples
plot_training_samples(batch, ni)
Plot training sample images with labels, bounding boxes, and masks.
This method creates a visualization of training batch images with their corresponding labels, bounding boxes, and segmentation masks, saving the result to a file for inspection and debugging.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | dict | Dictionary containing batch data with the keys 'img' (images tensor), 'batch_idx' (batch index for each box), 'cls' (class labels tensor, squeezed to remove the last dimension), 'bboxes' (bounding box coordinates tensor), 'masks' (segmentation masks tensor), and 'im_file' (list of image file paths). | required |
ni | int | Current training iteration number, used for naming the output file. | required |
Examples:
>>> import torch
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
>>> trainer = SegmentationTrainer()
>>> batch = {
... "img": torch.rand(16, 3, 640, 640),
... "batch_idx": torch.zeros(16),
... "cls": torch.randint(0, 80, (16, 1)),
... "bboxes": torch.rand(16, 4),
... "masks": torch.rand(16, 640, 640),
... "im_file": [f"image{i}.jpg" for i in range(16)],
... }
>>> trainer.plot_training_samples(batch, ni=5)
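In this batch format, the boxes and class labels for all images are stored in one flat sequence, and `batch_idx` records which image each box belongs to. The stdlib sketch below regroups the flat annotations per image, the way a plotting routine must before drawing each tile. Plain lists stand in for tensors, and `group_by_image` is a hypothetical helper, not a library function.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def group_by_image(
    batch_idx: Sequence[float],
    cls: Sequence[Sequence[float]],
    bboxes: Sequence[Sequence[float]],
) -> Dict[int, List[Tuple[int, Box]]]:
    """Map image index -> list of (class_id, box) using the per-box batch_idx."""
    if not (len(batch_idx) == len(cls) == len(bboxes)):
        raise ValueError("batch_idx, cls and bboxes need one entry per box")
    grouped: Dict[int, List[Tuple[int, Box]]] = defaultdict(list)
    for image_index, class_row, box in zip(batch_idx, cls, bboxes):
        # class_row has one element, like a row of the (N, 1) cls tensor before squeeze(-1)
        grouped[int(image_index)].append((int(class_row[0]), tuple(box)))
    return dict(grouped)


# Three boxes: two belong to image 0, one to image 1.
print(group_by_image([0.0, 0.0, 1.0], [[3], [5], [7]], [(0.1, 0.1, 0.2, 0.2), (0.5, 0.5, 0.1, 0.1), (0.3, 0.3, 0.4, 0.4)]))
```

With `"batch_idx": torch.zeros(16)` as in the example above, all 16 boxes would be drawn on the first image.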