Bỏ để qua phần nội dung

Tài liệu tham khảo cho ultralytics/models/yolo/world/train_world.py

Ghi

Tệp này có sẵn tại https://github.com/ultralytics/ultralytics/blob/main/ultralytics/Mô hình/yolo/thế giới/train_world.py. Nếu bạn phát hiện ra một vấn đề, vui lòng giúp khắc phục nó bằng cách đóng góp Yêu cầu 🛠️ kéo. Cảm ơn bạn 🙏 !



ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch

Căn cứ: WorldTrainer

Một lớp mở rộng lớp WorldTrainer để đào tạo một mô hình thế giới từ đầu trên tập dữ liệu tập mở.

Ví dụ
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld

data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],
        grounding_data=[
            dict(
                img_path="../datasets/flickr30k/images",
                json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
            ),
            dict(
                img_path="../datasets/GQA/images",
                json_file="../datasets/GQA/final_mixed_train_no_coco.json",
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),
)

model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, trainer=WorldTrainerFromScratch)
Mã nguồn trong ultralytics/models/yolo/world/train_world.py
class WorldTrainerFromScratch(WorldTrainer):
    """
    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
        from ultralytics import YOLOWorld

        data = dict(
            train=dict(
                yolo_data=["Objects365.yaml"],
                grounding_data=[
                    dict(
                        img_path="../datasets/flickr30k/images",
                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
                    ),
                    dict(
                        img_path="../datasets/GQA/images",
                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
                    ),
                ],
            ),
            val=dict(yolo_data=["lvis.yaml"]),
        )

        model = YOLOWorld("yolov8s-worldv2.yaml")
        model.train(data=data, trainer=WorldTrainerFromScratch)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (List[str] | str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        if mode == "train":
            dataset = [
                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
                if isinstance(im_path, str)
                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
                for im_path in img_path
            ]
            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
        else:
            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

    def get_dataset(self):
        """
        Get train, val path from data dict if it exists.

        Returns None if data format is not recognized.
        """
        final_data = dict()
        data_yaml = self.args.data
        assert data_yaml.get("train", False)  # object365.yaml
        assert data_yaml.get("val", False)  # lvis.yaml
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
        for d in data["val"]:
            if d.get("minival") is None:  # for lvis dataset
                continue
            d["minival"] = str(d["path"] / d["minival"])
        for s in ["train", "val"]:
            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
            # save grounding data if there's one
            grounding_data = data_yaml[s].get("grounding_data")
            if grounding_data is None:
                continue
            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
            for g in grounding_data:
                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
            final_data[s] += grounding_data
        # NOTE: to make training work properly, set `nc` and `names`
        final_data["nc"] = data["val"][0]["nc"]
        final_data["names"] = data["val"][0]["names"]
        self.data = final_data
        return final_data["train"], final_data["val"][0]

    def plot_training_labels(self):
        """DO NOT plot labels."""
        pass

    def final_eval(self):
        """Performs final evaluation and validation for object detection YOLO-World model."""
        val = self.args.data["val"]["yolo_data"][0]
        self.validator.args.data = val
        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
        return super().final_eval()

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Khởi tạo một đối tượng WorldTrainer với các đối số đã cho.

Mã nguồn trong ultralytics/models/yolo/world/train_world.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a WorldTrainer object with given arguments."""
    if overrides is None:
        overrides = {}
    super().__init__(cfg, overrides, _callbacks)

build_dataset(img_path, mode='train', batch=None)

Xây dựng YOLO Tập dữ liệu.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
img_path List[str] | str

Đường dẫn đến thư mục chứa hình ảnh.

bắt buộc
mode str

train chế độ hoặc val Chế độ, người dùng có thể tùy chỉnh các phần tăng cường khác nhau cho từng chế độ.

'train'
batch int

Kích thước của lô, cái này dành cho rect. Mặc định là Không có.

None
Mã nguồn trong ultralytics/models/yolo/world/train_world.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Build YOLO Dataset.

    Args:
        img_path (List[str] | str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    if mode == "train":
        dataset = [
            build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
            if isinstance(im_path, str)
            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
            for im_path in img_path
        ]
        return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
    else:
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

final_eval()

Thực hiện đánh giá cuối cùng và xác nhận để phát hiện đối tượng YOLO-Mô hình thế giới.

Mã nguồn trong ultralytics/models/yolo/world/train_world.py
def final_eval(self):
    """Performs final evaluation and validation for object detection YOLO-World model."""
    val = self.args.data["val"]["yolo_data"][0]
    self.validator.args.data = val
    self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
    return super().final_eval()

get_dataset()

Nhận train, val path từ dữ liệu ra lệnh nếu nó tồn tại.

Trả về Không có nếu định dạng dữ liệu không được nhận dạng.

Mã nguồn trong ultralytics/models/yolo/world/train_world.py
def get_dataset(self):
    """
    Get train, val path from data dict if it exists.

    Returns None if data format is not recognized.
    """
    final_data = dict()
    data_yaml = self.args.data
    assert data_yaml.get("train", False)  # object365.yaml
    assert data_yaml.get("val", False)  # lvis.yaml
    data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
    assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
    val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
    for d in data["val"]:
        if d.get("minival") is None:  # for lvis dataset
            continue
        d["minival"] = str(d["path"] / d["minival"])
    for s in ["train", "val"]:
        final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
        # save grounding data if there's one
        grounding_data = data_yaml[s].get("grounding_data")
        if grounding_data is None:
            continue
        grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
        for g in grounding_data:
            assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
        final_data[s] += grounding_data
    # NOTE: to make training work properly, set `nc` and `names`
    final_data["nc"] = data["val"][0]["nc"]
    final_data["names"] = data["val"][0]["names"]
    self.data = final_data
    return final_data["train"], final_data["val"][0]

plot_training_labels()

KHÔNG vẽ nhãn.

Mã nguồn trong ultralytics/models/yolo/world/train_world.py
def plot_training_labels(self):
    """DO NOT plot labels."""
    pass





Đã tạo 2024-03-31, Cập nhật 2024-05-08
Tác giả: Burhan-Q (1), Cười-q (1)