class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.classify import ClassificationPredictor

        args = dict(model='yolov8n-cls.pt', source=ASSETS)
        predictor = ClassificationPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes ClassificationPredictor setting the task to 'classify'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"
        # Substring searched for in str(transform) by preprocess() to detect the legacy ToTensor
        # transform; when found, images are passed to self.transforms without the BGR->RGB PIL conversion.
        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

    def preprocess(self, img):
        """
        Converts input image(s) to a model-compatible tensor on the model's device.

        Args:
            img (torch.Tensor | Iterable): Either an already-batched tensor, which is only moved to the
                device and cast, or an iterable of images that are each passed through `self.transforms`
                and stacked along a new dimension 0.

        Returns:
            (torch.Tensor): The batch on `self.model.device`, as float16 if `self.model.fp16` is set,
                otherwise float32.
        """
        if not isinstance(img, torch.Tensor):
            is_legacy_transform = any(
                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
            )
            if is_legacy_transform:  # to handle legacy transforms
                img = torch.stack([self.transforms(im) for im in img], dim=0)
            else:
                # Non-legacy transforms are fed RGB PIL images; the cvtColor call assumes each input is a
                # BGR numpy array (cv2 convention) — confirm against the data loader.
                img = torch.stack(
                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
                    dim=0,
                )
        # Every path above leaves `img` as a torch.Tensor, so no numpy -> tensor conversion is needed.
        img = img.to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # cast to fp16/fp32 to match model precision

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes predictions to return Results objects.

        Args:
            preds (Iterable): Per-image classification outputs; each element is stored as `probs`.
            img (torch.Tensor): Preprocessed input batch. Unused here; presumably kept to match the
                base-class signature.
            orig_imgs (list | torch.Tensor): Original images; a non-list input is converted with
                `ops.convert_torch2numpy_batch`.

        Returns:
            (list[Results]): One Results object per prediction, paired with its original image and path.
        """
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
        # self.batch[0] supplies the per-image paths for the current batch.
        return [
            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes ClassificationPredictor setting the task to 'classify'.

    Args:
        cfg: Default configuration passed through to the base predictor.
        overrides: Optional configuration overrides passed through to the base predictor.
        _callbacks: Optional callbacks passed through to the base predictor.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "classify"
    # Substring searched for in str(transform) to detect the legacy ToTensor transform; when found,
    # images are passed to self.transforms without the BGR->RGB PIL conversion.
    self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
postprocess(preds, img, orig_imgs)
Verarbeitet die Vorhersagen nach und gibt sie als `Results`-Objekte zurück.
Quellcode in ultralytics/models/yolo/classify/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Post-processes predictions to return Results objects.

    Args:
        preds (Iterable): Per-image classification outputs; each element is stored as `probs`.
        img (torch.Tensor): Preprocessed input batch. Unused here; presumably kept to match the
            base-class signature.
        orig_imgs (list | torch.Tensor): Original images; a non-list input is converted with
            `ops.convert_torch2numpy_batch`.

    Returns:
        (list[Results]): One Results object per prediction, paired with its original image and path.
    """
    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
    # self.batch[0] supplies the per-image paths for the current batch.
    return [
        Results(orig_img, path=img_path, names=self.model.names, probs=pred)
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
    ]
preprocess(img)
Konvertiert das Eingabebild in einen modellkompatiblen Datentyp.
Quellcode in ultralytics/models/yolo/classify/predict.py
def preprocess(self, img):
    """
    Converts input image(s) to a model-compatible tensor on the model's device.

    Args:
        img (torch.Tensor | Iterable): Either an already-batched tensor, which is only moved to the
            device and cast, or an iterable of images that are each passed through `self.transforms`
            and stacked along a new dimension 0.

    Returns:
        (torch.Tensor): The batch on `self.model.device`, as float16 if `self.model.fp16` is set,
            otherwise float32.
    """
    if not isinstance(img, torch.Tensor):
        is_legacy_transform = any(
            self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
        )
        if is_legacy_transform:  # to handle legacy transforms
            img = torch.stack([self.transforms(im) for im in img], dim=0)
        else:
            # Non-legacy transforms are fed RGB PIL images; the cvtColor call assumes each input is a
            # BGR numpy array (cv2 convention) — confirm against the data loader.
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
                dim=0,
            )
    # Every path above leaves `img` as a torch.Tensor, so no numpy -> tensor conversion is needed.
    img = img.to(self.model.device)
    return img.half() if self.model.fp16 else img.float()  # cast to fp16/fp32 to match model precision
Erstellt am 2023-11-12, aktualisiert am 2024-05-08. Autoren: Burhan-Q (1), glenn-jocher (3)