A mixin class for YOLOE model validation that handles both text and visual prompt embeddings.
This mixin provides functionality to validate YOLOE models using either text or visual prompt embeddings.
It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and
running validation with different prompt types.
Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `device` | `device` | The device on which validation is performed. |
| `args` | `namespace` | Configuration arguments for validation. |
| `dataloader` | `DataLoader` | DataLoader for validation data. |
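In practice a mixin like this is combined with the task validators via multiple inheritance. Below is a minimal sketch of that pattern; the mixin name `YOLOEValidatorMixin` is an assumption (this page does not state it), and its body is stubbed out:

```python
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.segment import SegmentationValidator


class YOLOEValidatorMixin:
    """Stub standing in for the mixin documented on this page."""


# Listing the mixin first lets its __call__/preprocess overrides win in the MRO.
class YOLOEDetectValidator(YOLOEValidatorMixin, DetectionValidator):
    """YOLOE detection validation with text and visual prompt support."""


class YOLOESegValidator(YOLOEValidatorMixin, SegmentationValidator):
    """YOLOE segmentation validation with text and visual prompt support."""
```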
__init__

__init__(dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None)

Initialize detection validator with necessary variables and settings.

Source code in ultralytics/models/yolo/detect/val.py
```python
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Initialize detection validator with necessary variables and settings.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
        save_dir (Path, optional): Directory to save results.
        pbar (Any, optional): Progress bar for displaying progress.
        args (dict, optional): Arguments for the validator.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.nt_per_class = None
    self.nt_per_image = None
    self.is_coco = False
    self.is_lvis = False
    self.class_map = None
    self.args.task = "detect"
    self.metrics = DetMetrics(save_dir=self.save_dir)
    self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
    self.niou = self.iouv.numel()
```
__call__

__call__(trainer=None, model=None, refer_data=None, load_vp=False)

Run validation on the model using either text or visual prompt embeddings.

This method validates the model using either text prompts or visual prompts, depending on the `load_vp` flag. It supports validation during training (using a trainer object) or standalone validation with a provided model.

Source code in ultralytics/models/yolo/yoloe/val.py
```python
@smart_inference_mode()
def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
    """
    Run validation on the model using either text or visual prompt embeddings.

    This method validates the model using either text prompts or visual prompts, depending
    on the `load_vp` flag. It supports validation during training (using a trainer object)
    or standalone validation with a provided model.

    Args:
        trainer (object, optional): Trainer object containing the model and device.
        model (YOLOEModel, optional): Model to validate. Required if `trainer` is not provided.
        refer_data (str, optional): Path to reference data for visual prompts.
        load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

    Returns:
        (dict): Validation statistics containing metrics computed during validation.
    """
    if trainer is not None:
        self.device = trainer.device
        model = trainer.ema.ema
        names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]
        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # Directly use the same dataloader for visual embeddings extracted during training
            vpe = self.get_visual_pe(self.dataloader, model)
            model.set_classes(names, vpe)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
        stats = super().__call__(trainer, model)
    else:
        if refer_data is not None:
            assert load_vp, "Refer data is only used for visual prompt validation."
        self.device = select_device(self.args.device)

        if isinstance(model, str):
            from ultralytics.nn.tasks import attempt_load_weights

            model = attempt_load_weights(model, device=self.device, inplace=True)
        model.eval().to(self.device)
        data = check_det_dataset(refer_data or self.args.data)
        names = [name.split("/")[0] for name in list(data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # TODO: need to check if the names from refer data is consistent with the evaluated dataset
            # could use same dataset or refer to extract visual prompt embeddings
            dataloader = self.get_vpe_dataloader(data)
            vpe = self.get_visual_pe(dataloader, model)
            model.set_classes(names, vpe)
            stats = super().__call__(model=deepcopy(model))
        elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
            return super().__call__(trainer, model)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
            stats = super().__call__(model=deepcopy(model))
    return stats
```
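As a rough usage sketch of the standalone path (no trainer), the snippet below runs validation once with text prompts and once with visual prompts. The checkpoint and dataset names are placeholders, and the exact constructor arguments may differ across library versions:

```python
from ultralytics.models.yolo.yoloe import YOLOEDetectValidator

# Hypothetical checkpoint and dataset YAML; substitute your own files.
validator = YOLOEDetectValidator(args=dict(model="yoloe-v8l-seg.pt", data="lvis.yaml"))

# Text-prompt validation: class names are embedded via the text encoder.
text_stats = validator(model="yoloe-v8l-seg.pt", load_vp=False)

# Visual-prompt validation: embeddings are averaged from reference samples.
vp_stats = validator(model="yoloe-v8l-seg.pt", refer_data="lvis.yaml", load_vp=True)
```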
get_visual_pe
get_visual_pe(dataloader, model)
Extract visual prompt embeddings from training samples.
This function processes a dataloader to compute visual prompt embeddings for each class
using a YOLOE model. It normalizes the embeddings and handles cases where no samples
exist for a class.
Source code in ultralytics/models/yolo/yoloe/val.py

```python
@smart_inference_mode()
def get_visual_pe(self, dataloader, model):
    """
    Extract visual prompt embeddings from training samples.

    This function processes a dataloader to compute visual prompt embeddings for each class
    using a YOLOE model. It normalizes the embeddings and handles cases where no samples
    exist for a class.

    Args:
        dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
        model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

    Returns:
        (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
    """
    assert isinstance(model, YOLOEModel)
    names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
    visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
    cls_visual_num = torch.zeros(len(names))

    desc = "Get visual prompt embeddings from samples"

    # First pass: count samples per class across the whole dataloader.
    for batch in dataloader:
        cls = batch["cls"].squeeze(-1).to(torch.int).unique()
        count = torch.bincount(cls, minlength=len(names))
        cls_visual_num += count

    cls_visual_num = cls_visual_num.to(self.device)

    # Second pass: accumulate per-class embeddings as a running mean.
    pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
    for batch in pbar:
        batch = self.preprocess(batch)
        preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

        batch_idx = batch["batch_idx"]
        for i in range(preds.shape[0]):
            cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
            pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
            pad_cls[: len(cls)] = cls
            for c in cls:
                visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

    visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
    visual_pe[cls_visual_num == 0] = 0
    return visual_pe.unsqueeze(0)
```
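The core of the method is a running per-class mean followed by L2 normalization. Here is a self-contained sketch of just that arithmetic, using random vectors in place of real model outputs:

```python
import torch
import torch.nn.functional as F

num_classes, embed_dim = 4, 8
visual_pe = torch.zeros(num_classes, embed_dim)
cls_visual_num = torch.tensor([3.0, 0.0, 1.0, 2.0])  # samples per class; class 1 has none

# Dividing each sample embedding by the class count turns the running sum
# into a mean, mirroring the accumulation loop in get_visual_pe.
for c, n in enumerate(cls_visual_num):
    for _ in range(int(n)):
        visual_pe[c] += torch.randn(embed_dim) / n

# L2-normalize classes that had samples; classes without samples stay at zero.
visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
visual_pe[cls_visual_num == 0] = 0
print(visual_pe.norm(dim=-1))  # ~1.0 for populated classes, 0.0 for class 1
```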
get_vpe_dataloader
get_vpe_dataloader(data)
Create a dataloader for LVIS training visual prompt samples.
This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
It applies necessary transformations and configurations to the dataset and returns a dataloader
for validation purposes.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `dict` | Dataset configuration dictionary containing paths and settings. | required |

Returns:

| Type | Description |
| --- | --- |
| `DataLoader` | The dataloader for visual prompt samples. |
Source code in ultralytics/models/yolo/yoloe/val.py
```python
def get_vpe_dataloader(self, data):
    """
    Create a dataloader for LVIS training visual prompt samples.

    This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
    It applies necessary transformations and configurations to the dataset and returns a dataloader
    for validation purposes.

    Args:
        data (dict): Dataset configuration dictionary containing paths and settings.

    Returns:
        (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
    """
    dataset = build_yolo_dataset(
        self.args,
        data.get(self.args.split, data.get("val")),
        self.args.batch,
        data,
        mode="val",
        rect=False,
    )
    # Append the visual prompt transform to every underlying dataset.
    if isinstance(dataset, YOLOConcatDataset):
        for d in dataset.datasets:
            d.transforms.append(LoadVisualPrompt())
    else:
        dataset.transforms.append(LoadVisualPrompt())
    return build_dataloader(
        dataset,
        self.args.batch,
        self.args.workers,
        shuffle=False,
        rank=-1,
    )
```
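A sketch of driving this helper on its own, reusing the hypothetical checkpoint and YAML names from the `__call__` example above; the `visuals` key in the batch is populated by the appended `LoadVisualPrompt` transform:

```python
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.yoloe import YOLOEDetectValidator

# Hypothetical paths; substitute your own checkpoint and dataset YAML.
validator = YOLOEDetectValidator(args=dict(model="yoloe-v8l-seg.pt", data="lvis.yaml"))
data = check_det_dataset("lvis.yaml")

dataloader = validator.get_vpe_dataloader(data)
batch = next(iter(dataloader))
print(batch["img"].shape, batch["visuals"].shape)  # visuals added by LoadVisualPrompt
```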
preprocess
preprocess(batch)
Preprocess batch data, ensuring visuals are on the same device as images.
Source code in ultralytics/models/yolo/yoloe/val.py
```python
def preprocess(self, batch):
    """Preprocess batch data, ensuring visuals are on the same device as images."""
    batch = super().preprocess(batch)
    if "visuals" in batch:
        batch["visuals"] = batch["visuals"].to(batch["img"].device)
    return batch
```
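The device alignment this override adds can be seen with a toy batch; a minimal sketch with dummy tensors:

```python
import torch

batch = {"img": torch.zeros(2, 3, 640, 640), "visuals": torch.zeros(2, 80, 80)}
# Equivalent to what the override does after the parent preprocess():
if "visuals" in batch:
    batch["visuals"] = batch["visuals"].to(batch["img"].device)
assert batch["visuals"].device == batch["img"].device
```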
__init__

__init__(dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None)

Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

Source code in ultralytics/models/yolo/segment/val.py

```python
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
        save_dir (Path, optional): Directory to save results.
        pbar (Any, optional): Progress bar for displaying progress.
        args (namespace, optional): Arguments for the validator.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.plot_masks = None
    self.process = None
    self.args.task = "segment"
    self.metrics = SegmentMetrics(save_dir=self.save_dir)
```