콘텐츠로 건너뛰기

참조 ultralytics/models/yolo/pose/train.py

참고

์ด ํŒŒ์ผ์€ https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/models/ yolo/pose/train .py์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฌธ์ œ๋ฅผ ๋ฐœ๊ฒฌํ•˜๋ฉด ํ’€ ๋ฆฌํ€˜์ŠคํŠธ ๐Ÿ› ๏ธ ์— ๊ธฐ์—ฌํ•˜์—ฌ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋„๋ก ๋„์™€์ฃผ์„ธ์š”. ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค ๐Ÿ™!



ultralytics.models.yolo.pose.train.PoseTrainer

기반 클래스: DetectionTrainer

포즈 모델에 기반한 훈련을 위해 DetectionTrainer 클래스를 확장한 클래스입니다.

예제
from ultralytics.models.yolo.pose import PoseTrainer

# Overrides handed to the trainer: model file, dataset YAML and number of epochs.
args = {"model": "yolov8n-pose.pt", "data": "coco8-pose.yaml", "epochs": 3}
trainer = PoseTrainer(overrides=args)
trainer.train()
์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a pose model.

    The trainer forces ``task="pose"`` in its overrides, builds a ``PoseModel`` from the dataset's ``nc`` and
    ``kpt_shape`` entries, and forwards keypoints to the plotting helpers.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseTrainer

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
        trainer = PoseTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a PoseTrainer object with specified configurations and overrides.

        Args:
            cfg: Configuration forwarded unchanged to the parent initializer. Defaults to ``DEFAULT_CFG``.
            overrides (dict | None): Configuration overrides. ``None`` is replaced by a new empty dict. The
                ``"task"`` key is always set to ``"pose"``; a dict supplied by the caller is updated in place.
            _callbacks: Forwarded unchanged to the parent initializer.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        # Only string device specs are checked; any other type skips the warning.
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Get pose estimation model with specified configuration and weights.

        Args:
            cfg: Model configuration passed as the first argument to ``PoseModel``.
            weights: When truthy, passed to ``model.load`` after construction.
            verbose (bool): Forwarded to ``PoseModel``.

        Returns:
            The constructed ``PoseModel`` instance.
        """
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Run the parent attribute setup, then copy the dataset's ``kpt_shape`` onto ``self.model``."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """
        Return an instance of the PoseValidator class for validation.

        As a side effect, ``self.loss_names`` is set to the five loss labels listed below.

        Returns:
            A ``yolo.pose.PoseValidator`` built from ``self.test_loader``, ``self.save_dir``, a ``copy`` of
            ``self.args`` and ``self.callbacks``.
        """
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """
        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

        Args:
            batch: Mapping read by the keys ``"img"``, ``"keypoints"``, ``"cls"``, ``"bboxes"``, ``"im_file"``
                and ``"batch_idx"``.
            ni: Value interpolated into the output file name ``train_batch{ni}.jpg`` under ``self.save_dir``
                (presumably an iteration counter; confirm against the caller).
        """
        images = batch["img"]
        kpts = batch["keypoints"]
        cls = batch["cls"].squeeze(-1)  # squeeze the last axis before plotting
        bboxes = batch["bboxes"]
        paths = batch["im_file"]
        batch_idx = batch["batch_idx"]
        plot_images(
            images,
            batch_idx,
            cls,
            bboxes,
            kpts=kpts,
            paths=paths,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/val metrics from ``self.csv`` with the pose flag enabled."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

지정된 구성과 오버라이드로 포즈 트레이너 오브젝트를 초기화합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize a PoseTrainer object with specified configurations and overrides.

    Args:
        cfg: Configuration forwarded unchanged to the parent initializer. Defaults to ``DEFAULT_CFG``.
        overrides (dict | None): Configuration overrides. ``None`` is replaced by a new empty dict. The
            ``"task"`` key is always set to ``"pose"``; a dict supplied by the caller is updated in place.
        _callbacks: Forwarded unchanged to the parent initializer.
    """
    if overrides is None:
        overrides = {}
    overrides["task"] = "pose"
    super().__init__(cfg, overrides, _callbacks)

    # Only string device specs are checked; any other type skips the warning.
    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_model(cfg=None, weights=None, verbose=True)

지정된 구성과 가중치로 포즈 추정 모델을 가져옵니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a pose estimation model and optionally load weights into it.

    Args:
        cfg: Model configuration passed as the first argument to ``PoseModel``.
        weights: When truthy, passed to the model's ``load`` method after construction.
        verbose (bool): Forwarded to ``PoseModel``.

    Returns:
        The constructed ``PoseModel`` instance.
    """
    pose_net = PoseModel(
        cfg,
        ch=3,
        nc=self.data["nc"],
        data_kpt_shape=self.data["kpt_shape"],
        verbose=verbose,
    )
    if not weights:
        return pose_net

    pose_net.load(weights)
    return pose_net

get_validator()

유효성 검사를 위해 PoseValidator 클래스의 인스턴스를 반환합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """
    Build the validator used for pose evaluation.

    As a side effect, ``self.loss_names`` is set to the five loss labels below before the validator is created.

    Returns:
        A ``yolo.pose.PoseValidator`` built from ``self.test_loader``, ``self.save_dir``, a ``copy`` of
        ``self.args`` and ``self.callbacks``.
    """
    self.loss_names = ("box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss")
    validator_cls = yolo.pose.PoseValidator
    return validator_cls(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),
        _callbacks=self.callbacks,
    )

plot_metrics()

트레이닝/검증 메트릭을 플로팅합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """Plot training/val metrics read from ``self.csv``, with the pose flag enabled."""
    results_csv = self.csv
    plot_results(
        file=results_csv,
        pose=True,
        on_plot=self.on_plot,
    )  # save results.png

plot_training_samples(batch, ni)

주석이 달린 클래스 레이블, 경계 상자, 키포인트로 훈련 샘플의 배치를 플롯합니다.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def plot_training_samples(self, batch, ni):
    """
    Render one training batch with class labels, bounding boxes and keypoints.

    Args:
        batch: Mapping read by the keys ``"img"``, ``"batch_idx"``, ``"cls"``, ``"bboxes"``, ``"keypoints"``
            and ``"im_file"``.
        ni: Value interpolated into the output file name ``train_batch{ni}.jpg`` under ``self.save_dir``
            (presumably an iteration counter; confirm against the caller).
    """
    out_file = self.save_dir / f"train_batch{ni}.jpg"
    plot_images(
        batch["img"],
        batch["batch_idx"],
        batch["cls"].squeeze(-1),  # squeeze the last axis before plotting
        batch["bboxes"],
        kpts=batch["keypoints"],
        paths=batch["im_file"],
        fname=out_file,
        on_plot=self.on_plot,
    )

set_model_attributes()

ํฌ์ฆˆ๋ชจ๋ธ์˜ ํ‚คํฌ์ธํŠธ ๋ชจ์–‘ ์†์„ฑ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.

์˜ ์†Œ์Šค ์ฝ”๋“œ ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """Run the parent attribute setup, then copy the dataset's ``kpt_shape`` entry onto ``self.model``."""
    super().set_model_attributes()
    keypoint_shape = self.data["kpt_shape"]
    self.model.kpt_shape = keypoint_shape





์ƒ์„ฑ๋จ 2023-11-12, ์—…๋ฐ์ดํŠธ๋จ 2023-11-25
์ž‘์„ฑ์ž: glenn-jocher (3)