class NAS(Model):
    """
    YOLO-NAS detection model wrapper.

    Accepts only pre-trained weights: either a '.pt' checkpoint path or a bare model name
    (no file suffix) that is resolved through `super_gradients`. YAML model definitions are
    rejected in `__init__`.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Initialize the wrapper with pre-trained weights and the 'detect' task.

        Args:
            model (str): Checkpoint path or model name. Must not end in '.yaml' or '.yml'.

        Raises:
            AssertionError: If `model` has a '.yaml' or '.yml' suffix.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so callers still see AssertionError.
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()
    def _load(self, weights: str, task: str):
        """
        Load a NAS model from `weights` and attach the attributes the rest of the pipeline reads.

        Args:
            weights (str): Path to a '.pt' checkpoint, or a model name without suffix that is
                fetched via `super_gradients.training.models.get(..., pretrained_weights='coco')`.
            task (str): Not used in this method; presumably kept to match the base-class signature.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' or ''.
        """
        # Function-scope import: super_gradients is only imported when a model is actually loaded.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
            self.model = torch.load(weights)
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Previously this case fell through with no model assigned and failed later on the
            # attribute writes below; fail here with a message that names the real problem.
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' checkpoint "
                f"or a model name without a file suffix, got suffix '{suffix}'.")

        # Standardize model
        self.model.fuse = lambda verbose=True: self.model  # fuse() is a no-op returning the model itself
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            Whatever `model_info` returns for the wrapped model at imgsz=640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Map the 'detect' task to its predictor and validator classes."""
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
def info(self, detailed=False, verbose=True):
    """
    Report information about the wrapped model.

    Args:
        detailed (bool): Show detailed information about model.
        verbose (bool): Controls verbosity.

    Returns:
        The result of `model_info` called on `self.model` with a fixed imgsz of 640.
    """
    summary = model_info(
        self.model,
        detailed=detailed,
        verbose=verbose,
        imgsz=640,
    )
    return summary