콘텐츠로 건너뛰기

참조: ultralytics/models/yolo/world/train.py

참고

이 파일은 https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py 에서 확인할 수 있습니다. 문제를 발견하면 풀 리퀘스트(🛠️)를 통해 문제 해결을 도와주세요. 감사합니다 🙏!



ultralytics.models.yolo.world.train.WorldTrainer

기반 클래스: DetectionTrainer

클로즈드셋(closed-set) 데이터 세트에서 월드 모델을 미세 조정하는 클래스입니다.

예시
# Fine-tune YOLOv8s-World on the coco8 dataset for 3 epochs.
from ultralytics.models.yolo.world.train import WorldTrainer

args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
class WorldTrainer(yolo.detect.DetectionTrainer):
    """
    A class to fine-tune a world model on a closed-set dataset.

    The trainer builds a ``WorldModel``, uses multi-modal datasets in ``train`` mode, and attaches
    L2-normalized CLIP text embeddings of each batch's text prompts as ``batch["txt_feats"]``.

    Attributes:
        clip (module): The imported ``clip`` package, used to tokenize text prompts.
        text_model: CLIP model assigned by the ``on_pretrain_routine_end`` callback (parameters frozen);
            it must be set before ``preprocess_batch`` is called.

    Example:
        ```python
        from ultralytics.models.yolo.world.train import WorldTrainer

        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
        trainer = WorldTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a WorldTrainer object with given arguments.

        Args:
            cfg: Base configuration passed to the parent trainer. Defaults to ``DEFAULT_CFG``.
            overrides (dict, optional): Configuration overrides; ``None`` is treated as an empty dict.
            _callbacks (optional): Callbacks forwarded unchanged to the parent ``__init__``.
        """
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

        # CLIP is an optional dependency: install the Ultralytics CLIP package on demand, then import it.
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        self.clip = clip

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return WorldModel initialized with specified config and weights.

        Args:
            cfg (optional): Model config; when a dict, its ``"yaml_file"`` entry is used instead.
            weights (optional): Weights loaded into the model via ``model.load`` when truthy.
            verbose (bool): Forwarded to ``WorldModel``, but only takes effect when ``RANK == -1``.

        Returns:
            (WorldModel): The constructed (and optionally weight-loaded) model.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=3,
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        # This callback loads the frozen CLIP text model (self.text_model) that preprocess_batch relies on.
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

        Returns:
            The dataset returned by ``build_yolo_dataset``; rectangular batching is used only in `val` mode and
            multi-modal (text) samples only in `train` mode.
        """
        # Stride of the (possibly DDP-wrapped) model, never below 32; 0 falls back to 32 when no model exists yet.
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images for YOLOWorld training and attach normalized text features.

        Args:
            batch (dict): Batch with at least ``"img"`` and ``"texts"`` (a list of text-prompt lists,
                presumably one per image; every inner list is assumed to have the same length so the
                final reshape is valid — TODO confirm against the dataset collate function).

        Returns:
            (dict): The batch with ``"txt_feats"`` added, shaped ``(len(batch["texts"]), -1, embed_dim)``.
        """
        batch = super().preprocess_batch(batch)

        # NOTE: add text features. Flatten all prompts so CLIP encodes them in a single pass.
        texts = list(itertools.chain.from_iterable(batch["texts"]))
        text_token = self.clip.tokenize(texts).to(batch["img"].device)
        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
        # L2-normalize each embedding along the feature dimension.
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

주어진 인수를 사용하여 WorldTrainer 객체를 초기화합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up a WorldTrainer and expose the CLIP package as ``self.clip``.

    Args:
        cfg: Base configuration forwarded to the parent trainer. Defaults to ``DEFAULT_CFG``.
        overrides (dict, optional): Configuration overrides; an empty dict is used when ``None``.
        _callbacks (optional): Callbacks forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(cfg, {} if overrides is None else overrides, _callbacks)

    # CLIP is optional: when it is missing, install the Ultralytics CLIP package and import again.
    try:
        import clip as clip_module
    except ImportError:
        checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip as clip_module
    self.clip = clip_module

build_dataset(img_path, mode='train', batch=None)

YOLO 데이터 세트를 구축합니다.

매개변수:

이름 유형 설명 기본값
img_path str

이미지가 포함된 폴더의 경로입니다.

필수
mode str

train 모드 또는 val 모드입니다. 각 모드마다 서로 다른 증강(augmentation)을 사용자 지정할 수 있습니다.

'train'
batch int

๋ฐฐ์น˜์˜ ํฌ๊ธฐ, ์ด๊ฒƒ์€ ๋‹ค์Œ์„ ์œ„ํ•œ ๊ฒƒ์ž…๋‹ˆ๋‹ค. rect. ๊ธฐ๋ณธ๊ฐ’์€ ์—†์Œ์ž…๋‹ˆ๋‹ค.

None
์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Create the YOLO dataset for one split.

    Args:
        img_path (str): Folder holding the images.
        mode (str): `train` or `val`; augmentations can be customized separately for each mode.
        batch (int, optional): Batch size, used for `rect`. Defaults to None.

    Returns:
        Whatever ``build_yolo_dataset`` returns for these arguments.
    """
    # Largest stride of the (possibly DDP-wrapped) model, or 0 if no model is attached yet.
    stride_max = de_parallel(self.model).stride.max() if self.model else 0
    grid_size = max(int(stride_max), 32)
    use_rect = mode == "val"
    use_text = mode == "train"
    return build_yolo_dataset(
        self.args,
        img_path,
        batch,
        self.data,
        mode=mode,
        rect=use_rect,
        stride=grid_size,
        multi_modal=use_text,
    )

get_model(cfg=None, weights=None, verbose=True)

์ง€์ •๋œ ๊ตฌ์„ฑ๊ณผ ๊ฐ€์ค‘์น˜๋กœ ์ดˆ๊ธฐํ™”๋œ ์›”๋“œ๋ชจ๋ธ์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a WorldModel from ``cfg`` and optionally load ``weights`` into it.

    Args:
        cfg (optional): Model config; a dict is unwrapped to its ``"yaml_file"`` entry.
        weights (optional): Weights passed to ``load`` when truthy.
        verbose (bool): Forwarded to ``WorldModel`` only when ``RANK == -1``; otherwise False is passed.

    Returns:
        (WorldModel): The constructed model.
    """
    model_cfg = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
    # NOTE: `nc` here is the max number of distinct text samples per image, not the dataset's real `nc`;
    # following the official config it is capped at 80 for now.
    max_texts = min(self.data["nc"], 80)
    world_model = WorldModel(model_cfg, ch=3, nc=max_texts, verbose=verbose and RANK == -1)
    if weights:
        world_model.load(weights)
    # Registers the hook that loads the frozen CLIP text model used by preprocess_batch.
    self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
    return world_model

preprocess_batch(batch)

필요에 따라 형식과 크기를 조정하여 YOLOWorld 학습용 이미지 배치를 전처리합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images for YOLOWorld training and attach normalized text features.

    Args:
        batch (dict): Batch with at least ``"img"`` and ``"texts"`` (a list of text-prompt lists,
            presumably one per image; every inner list is assumed to have the same length so the
            final reshape is valid — TODO confirm against the dataset collate function).

    Returns:
        (dict): The batch with ``"txt_feats"`` added, shaped ``(len(batch["texts"]), -1, embed_dim)``.
    """
    batch = super().preprocess_batch(batch)

    # NOTE: add text features. Flatten all prompts so CLIP encodes them in a single pass.
    texts = list(itertools.chain.from_iterable(batch["texts"]))
    text_token = self.clip.tokenize(texts).to(batch["img"].device)
    txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
    # L2-normalize each embedding along the feature dimension.
    txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
    batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
    return batch



ultralytics.models.yolo.world.train.on_pretrain_routine_end(trainer)

콜백: 평가용 클래스 이름을 설정하고, 파라미터가 고정된 CLIP 텍스트 모델을 로드합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/world/train.py
def on_pretrain_routine_end(trainer):
    """
    Prepare a WorldTrainer once its pretrain routine has finished.

    When ``RANK`` is -1 or 0, the EMA model's classes are set from the test dataset's names for evaluation.
    On every rank, a CLIP "ViT-B/32" model is loaded onto the training model's device, stored in
    ``trainer.text_model``, and all of its parameters are frozen.

    Args:
        trainer: The running trainer; it must expose ``test_loader``, ``ema``, ``model`` and ``clip``.
    """
    if RANK in {-1, 0}:
        # NOTE: for evaluation. Keep only the text before the first '/' in each class name
        # (presumably '/'-separated alternative names).
        names = [name.split("/")[0] for name in trainer.test_loader.dataset.data["names"].values()]
        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
    device = next(trainer.model.parameters()).device
    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
    # The text encoder is only used to produce features; freeze it so it is not trained.
    for p in trainer.text_model.parameters():
        p.requires_grad_(False)





생성 2024-03-31, 업데이트 2024-03-31
작성자: Laughing-q (1)