सामग्री पर जाएं

के लिए संदर्भ ultralytics/models/yolo/pose/predict.py

नोट

यह फ़ाइल यहाँ उपलब्ध है https://github.com/ultralytics/ultralytics/बूँद/मुख्य/ultralytics/मॉडल/yolo/मुद्रा/भविष्यवाणी.py का उपयोग करें। यदि आप कोई समस्या देखते हैं तो कृपया पुल अनुरोध का योगदान करके इसे ठीक करने में मदद करें 🛠️। 🙏 धन्यवाद !



ultralytics.models.yolo.pose.predict.PosePredictor

का रूप: DetectionPredictor

एक मुद्रा मॉडल के आधार पर भविष्यवाणी के लिए DetectionPredictor वर्ग का विस्तार करने वाला वर्ग।

उदाहरण
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.pose import PosePredictor

args = dict(model='yolov8n-pose.pt', source=ASSETS)
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
में स्रोत कोड ultralytics/models/yolo/pose/predict.py
 8 बांग्लादेश 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34353637383940414243 44 45 46 47 48 495051 525354 55 565758
class PosePredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on a pose model.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.pose import PosePredictor

        args = dict(model='yolov8n-pose.pt', source=ASSETS)
        predictor = PosePredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def postprocess(self, preds, img, orig_imgs):
        """Return detection results for a given input image or list of images."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
            nc=len(self.model.names),
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i]
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
            img_path = self.batch[0][i]
            results.append(
                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
            )
        return results

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

PosePredictor को इनरिमूइलाइज़ करता है, टास्क को 'पोज़' पर सेट करता है और डिवाइस के रूप में 'mps' का उपयोग करने के लिए चेतावनी लॉग करता है।

में स्रोत कोड ultralytics/models/yolo/pose/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "pose"
    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

postprocess(preds, img, orig_imgs)

किसी दिए गए इनपुट छवि या छवियों की सूची के लिए पता लगाने के परिणाम लौटाएं।

में स्रोत कोड ultralytics/models/yolo/pose/predict.py
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 5051 52 5354555657 58
def postprocess(self, preds, img, orig_imgs):
    """Return detection results for a given input image or list of images."""
    preds = ops.non_max_suppression(
        preds,
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        classes=self.args.classes,
        nc=len(self.model.names),
    )

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    for i, pred in enumerate(preds):
        orig_img = orig_imgs[i]
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
        pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
        img_path = self.batch[0][i]
        results.append(
            Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
        )
    return results





2023-11-12 बनाया गया, अपडेट किया गया 2023-11-25
लेखक: ग्लेन-जोचर (3)