सामग्री पर जाएं

के लिए संदर्भ ultralytics/models/yolo/classify/predict.py

नोट

यह फ़ाइल यहाँ उपलब्ध है https://github.com/ultralytics/ultralytics/बूँद/मुख्य/ultralytics/मॉडल/yolo/वर्गीकृत/भविष्यवाणी करें.py का उपयोग करें। यदि आप कोई समस्या देखते हैं तो कृपया पुल अनुरोध का योगदान करके इसे ठीक करने में मदद करें 🛠️। 🙏 धन्यवाद !



ultralytics.models.yolo.classify.predict.ClassificationPredictor

का रूप: BasePredictor

वर्गीकरण मॉडल के आधार पर भविष्यवाणी के लिए BasePredictor वर्ग का विस्तार करने वाला वर्ग।

नोट्स
  • टॉर्चविजन वर्गीकरण मॉडल को 'मॉडल' तर्क में भी पारित किया जा सकता है, यानी मॉडल = 'रेसनेट18'।
उदाहरण
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.classify import ClassificationPredictor

args = dict(model='yolov8n-cls.pt', source=ASSETS)
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
में स्रोत कोड ultralytics/models/yolo/classify/predict.py
12 बांग्लादेश 12 बांग्लादेश 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 373839 404142434445 46 47 48 49 50 51 52 53 54 55565758596061
class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.classify import ClassificationPredictor

        args = dict(model='yolov8n-cls.pt', source=ASSETS)
        predictor = ClassificationPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes ClassificationPredictor setting the task to 'classify'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"
        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

    def preprocess(self, img):
        """Converts input image to model-compatible data type."""
        if not isinstance(img, torch.Tensor):
            is_legacy_transform = any(
                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
            )
            if is_legacy_transform:  # to handle legacy transforms
                img = torch.stack([self.transforms(im) for im in img], dim=0)
            else:
                img = torch.stack(
                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
                )
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions to return Results objects."""
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i]
            img_path = self.batch[0][i]
            results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
        return results

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

ClassificationPredictor को कार्य को 'वर्गीकृत' पर सेट करता है।

में स्रोत कोड ultralytics/models/yolo/classify/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initializes ClassificationPredictor setting the task to 'classify'."""
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "classify"
    self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

postprocess(preds, img, orig_imgs)

परिणाम वस्तुओं को वापस करने के लिए भविष्यवाणियों को संसाधित करने के बाद।

में स्रोत कोड ultralytics/models/yolo/classify/predict.py
def postprocess(self, preds, img, orig_imgs):
    """Post-processes predictions to return Results objects."""
    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    for i, pred in enumerate(preds):
        orig_img = orig_imgs[i]
        img_path = self.batch[0][i]
        results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
    return results

preprocess(img)

इनपुट छवि को मॉडल-संगत डेटा प्रकार में कनवर्ट करता है।

में स्रोत कोड ultralytics/models/yolo/classify/predict.py
36 बांग्लादेश 37 38 3940 41 42 43 44 4546474849
def preprocess(self, img):
    """Converts input image to model-compatible data type."""
    if not isinstance(img, torch.Tensor):
        is_legacy_transform = any(
            self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
        )
        if is_legacy_transform:  # to handle legacy transforms
            img = torch.stack([self.transforms(im) for im in img], dim=0)
        else:
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
            )
    img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
    return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32





2023-11-12 बनाया गया, अपडेट किया गया 2023-11-25
लेखक: ग्लेन-जोचर (3)