Handles dynamic backend selection for running inference using Ultralytics YOLO models.
The AutoBackend class provides an abstraction layer over a variety of inference engines. It supports a wide range of export formats (PyTorch, TorchScript, ONNX, OpenVINO, TensorRT, CoreML, TensorFlow SavedModel/GraphDef/Lite/Edge TPU, PaddlePaddle, MNN, NCNN, IMX, RKNN, and Triton), selecting the backend from the weights file's suffix or directory naming convention.
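For orientation, here is a minimal usage sketch (the weights path and dummy input are placeholders, not from the source): the same construct-warmup-forward pattern applies regardless of which backend the weights resolve to.

```python
import torch

from ultralytics.nn.autobackend import AutoBackend

# The backend is chosen from the weights argument: a *.pt file takes the native PyTorch
# path, while e.g. an *.onnx file would route through ONNX Runtime instead.
model = AutoBackend(weights="yolo11n.pt", device=torch.device("cpu"), fp16=False)
model.warmup(imgsz=(1, 3, 640, 640))  # no-op on CPU; runs a dummy pass on GPU/Triton

im = torch.zeros(1, 3, 640, 640)  # BCHW float tensor standing in for a preprocessed image
y = model(im)  # raw, format-independent output tensor(s)
```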
@torch.no_grad()
def __init__(
    self,
    weights: Union[str, List[str], torch.nn.Module] = "yolo11n.pt",
    device: torch.device = torch.device("cpu"),
    dnn: bool = False,
    data: Optional[Union[str, Path]] = None,
    fp16: bool = False,
    batch: int = 1,
    fuse: bool = True,
    verbose: bool = True,
):
    """
    Initialize the AutoBackend for inference.

    Args:
        weights (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.
        device (torch.device): Device to run the model on.
        dnn (bool): Use OpenCV DNN module for ONNX inference.
        data (str | Path | optional): Path to the additional data.yaml file containing class names.
        fp16 (bool): Enable half-precision inference. Supported only on specific backends.
        batch (int): Batch-size to assume for inference.
        fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
        verbose (bool): Enable verbose logging.
    """
    super().__init__()
    w = str(weights[0] if isinstance(weights, list) else weights)
    nn_module = isinstance(weights, torch.nn.Module)
    (
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu,
        tfjs, paddle, mnn, ncnn, imx, rknn, triton,
    ) = self._model_type(w)
    fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
    nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCWH)
    stride, ch = 32, 3  # default stride and channels
    end2end, dynamic = False, False
    model, metadata, task = None, None, None

    # Set device
    cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu"  # use CUDA
    if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]):  # GPU dataloader formats
        device = torch.device("cpu")
        cuda = False

    # Download if not local
    if not (pt or triton or nn_module):
        w = attempt_download_asset(w)

    # In-memory PyTorch model
    if nn_module:
        model = weights.to(device)
        if fuse:
            model = model.fuse(verbose=verbose)
        if hasattr(model, "kpt_shape"):
            kpt_shape = model.kpt_shape  # pose-only
        stride = max(int(model.stride.max()), 32)  # model stride
        names = model.module.names if hasattr(model, "module") else model.names  # get class names
        model.half() if fp16 else model.float()
        ch = model.yaml.get("channels", 3)
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        pt = True

    # PyTorch
    elif pt:
        from ultralytics.nn.tasks import attempt_load_weights

        model = attempt_load_weights(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
        if hasattr(model, "kpt_shape"):
            kpt_shape = model.kpt_shape  # pose-only
        stride = max(int(model.stride.max()), 32)  # model stride
        names = model.module.names if hasattr(model, "module") else model.names  # get class names
        model.half() if fp16 else model.float()
        ch = model.yaml.get("channels", 3)
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()

    # TorchScript
    elif jit:
        import torchvision  # noqa - https://github.com/ultralytics/ultralytics/pull/19747

        LOGGER.info(f"Loading {w} for TorchScript inference...")
        extra_files = {"config.txt": ""}  # model metadata
        model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
        model.half() if fp16 else model.float()
        if extra_files["config.txt"]:  # load metadata dict
            metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))

    # ONNX OpenCV DNN
    elif dnn:
        LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
        check_requirements("opencv-python>=4.5.4")
        net = cv2.dnn.readNetFromONNX(w)

    # ONNX Runtime and IMX
    elif onnx or imx:
        LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
        check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
        if IS_RASPBERRYPI or IS_JETSON:
            # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson
            check_requirements("numpy==1.23.5")
        import onnxruntime

        providers = ["CPUExecutionProvider"]
        if cuda:
            if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
                providers.insert(0, "CUDAExecutionProvider")
            else:  # Only log warning if CUDA was requested but unavailable
                LOGGER.warning("Failed to start ONNX Runtime with CUDA. Using CPU...")
                device = torch.device("cpu")
                cuda = False
        LOGGER.info(f"Using ONNX Runtime {providers[0]}")
        if onnx:
            session = onnxruntime.InferenceSession(w, providers=providers)
        else:
            check_requirements(["model-compression-toolkit>=2.3.0", "sony-custom-layers[torch]>=0.3.0", "onnxruntime-extensions"])
            w = next(Path(w).glob("*.onnx"))
            LOGGER.info(f"Loading {w} for ONNX IMX inference...")
            import mct_quantizers as mctq
            from sony_custom_layers.pytorch.nms import nms_ort  # noqa

            session_options = mctq.get_ort_session_options()
            session_options.enable_mem_reuse = False  # fix the shape mismatch from onnxruntime
            session = onnxruntime.InferenceSession(w, session_options, providers=["CPUExecutionProvider"])
            task = "detect"
        output_names = [x.name for x in session.get_outputs()]
        metadata = session.get_modelmeta().custom_metadata_map
        dynamic = isinstance(session.get_outputs()[0].shape[0], str)
        fp16 = "float16" in session.get_inputs()[0].type
        if not dynamic:
            io = session.io_binding()
            bindings = []
            for output in session.get_outputs():
                out_fp16 = "float16" in output.type
                y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
                io.bind_output(
                    name=output.name,
                    device_type=device.type,
                    device_id=device.index if cuda else 0,
                    element_type=np.float16 if out_fp16 else np.float32,
                    shape=tuple(y_tensor.shape),
                    buffer_ptr=y_tensor.data_ptr(),
                )
                bindings.append(y_tensor)

    # OpenVINO
    elif xml:
        LOGGER.info(f"Loading {w} for OpenVINO inference...")
        check_requirements("openvino>=2024.0.0")
        import openvino as ov

        core = ov.Core()
        device_name = "AUTO"
        if isinstance(device, str) and device.startswith("intel"):
            device_name = device.split(":")[1].upper()  # Intel OpenVINO device
            device = torch.device("cpu")
            if device_name not in core.available_devices:
                LOGGER.warning(f"OpenVINO device '{device_name}' not available. Using 'AUTO' instead.")
                device_name = "AUTO"
        w = Path(w)
        if not w.is_file():  # if not *.xml
            w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir
        ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
        if ov_model.get_parameters()[0].get_layout().empty:
            ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))

        # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
        inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
        LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...")
        ov_compiled_model = core.compile_model(
            ov_model,
            device_name=device_name,
            config={"PERFORMANCE_HINT": inference_mode},
        )
        input_name = ov_compiled_model.input().get_any_name()
        metadata = w.parent / "metadata.yaml"

    # TensorRT
    elif engine:
        LOGGER.info(f"Loading {w} for TensorRT inference...")
        if IS_JETSON and check_version(PYTHON_VERSION, "<=3.8.0"):
            # fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 with Python <= 3.8.0
            check_requirements("numpy==1.23.5")
        try:  # https://developer.nvidia.com/nvidia-tensorrt-download
            import tensorrt as trt  # noqa
        except ImportError:
            if LINUX:
                check_requirements("tensorrt>7.0.0,!=10.1.0")
            import tensorrt as trt  # noqa
        check_version(trt.__version__, ">=7.0.0", hard=True)
        check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
        if device.type == "cpu":
            device = torch.device("cuda:0")
        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        logger = trt.Logger(trt.Logger.INFO)
        # Read file
        with open(w, "rb") as f, trt.Runtime(logger) as runtime:
            try:
                meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
                metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
            except UnicodeDecodeError:
                f.seek(0)  # engine file may lack embedded Ultralytics metadata
            dla = metadata.get("dla", None)
            if dla is not None:
                runtime.DLA_core = int(dla)
            model = runtime.deserialize_cuda_engine(f.read())  # read engine

        # Model context
        try:
            context = model.create_execution_context()
        except Exception as e:  # model is None
            LOGGER.error(f"TensorRT model exported with a different version than {trt.__version__}\n")
            raise e

        bindings = OrderedDict()
        output_names = []
        fp16 = False  # default updated below
        dynamic = False
        is_trt10 = not hasattr(model, "num_bindings")
        num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
        for i in num:
            if is_trt10:
                name = model.get_tensor_name(i)
                dtype = trt.nptype(model.get_tensor_dtype(name))
                is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                if is_input:
                    if -1 in tuple(model.get_tensor_shape(name)):
                        dynamic = True
                        context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
                    if dtype == np.float16:
                        fp16 = True
                else:
                    output_names.append(name)
                shape = tuple(context.get_tensor_shape(name))
            else:  # TensorRT < 10.0
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                is_input = model.binding_is_input(i)
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                        dynamic = True
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
                    if dtype == np.float16:
                        fp16 = True
                else:
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
        batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size

    # CoreML
    elif coreml:
        LOGGER.info(f"Loading {w} for CoreML inference...")
        import coremltools as ct

        model = ct.models.MLModel(w)
        metadata = dict(model.user_defined_metadata)

    # TF SavedModel
    elif saved_model:
        LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
        import tensorflow as tf

        keras = False  # assume TF1 saved_model
        model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
        metadata = Path(w) / "metadata.yaml"

    # TF GraphDef
    elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
        LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
        import tensorflow as tf

        from ultralytics.engine.exporter import gd_outputs

        def wrap_frozen_graph(gd, inputs, outputs):
            """Wrap frozen graphs for deployment."""
            x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
            ge = x.graph.as_graph_element
            return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(w, "rb") as f:
            gd.ParseFromString(f.read())
        frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
        try:  # find metadata in SavedModel alongside GraphDef
            metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
        except StopIteration:
            pass

    # TFLite or TFLite Edge TPU
    elif tflite or edgetpu:  # https://ai.google.dev/edge/litert/microcontrollers/python
        try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
            from tflite_runtime.interpreter import Interpreter, load_delegate
        except ImportError:
            import tensorflow as tf

            Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
        if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
            device = device[3:] if str(device).startswith("tpu") else ":0"
            LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
            delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[platform.system()]
            interpreter = Interpreter(
                model_path=w,
                experimental_delegates=[load_delegate(delegate, options={"device": device})],
            )
            device = "cpu"  # Required, otherwise PyTorch will try to use the wrong device
        else:  # TFLite
            LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
            interpreter = Interpreter(model_path=w)  # load TFLite model
        interpreter.allocate_tensors()  # allocate
        input_details = interpreter.get_input_details()  # inputs
        output_details = interpreter.get_output_details()  # outputs
        # Load metadata
        try:
            with zipfile.ZipFile(w, "r") as model:
                meta_file = model.namelist()[0]
                metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
        except zipfile.BadZipFile:
            pass

    # TF.js
    elif tfjs:
        raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")

    # PaddlePaddle
    elif paddle:
        LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
        check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle>=3.0.0")
        import paddle.inference as pdi  # noqa

        w = Path(w)
        model_file, params_file = None, None
        if w.is_dir():
            model_file = next(w.rglob("*.json"), None)
            params_file = next(w.rglob("*.pdiparams"), None)
        elif w.suffix == ".pdiparams":
            model_file = w.with_name("model.json")
            params_file = w
        if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
            raise FileNotFoundError(f"Paddle model not found in {w}. Both .json and .pdiparams files are required.")
        config = pdi.Config(str(model_file), str(params_file))
        if cuda:
            config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
        predictor = pdi.create_predictor(config)
        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
        output_names = predictor.get_output_names()
        metadata = w / "metadata.yaml"

    # MNN
    elif mnn:
        LOGGER.info(f"Loading {w} for MNN inference...")
        check_requirements("MNN")  # requires MNN
        import os

        import MNN

        config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2}
        rt = MNN.nn.create_runtime_manager((config,))
        net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)

        def torch_to_mnn(x):
            return MNN.expr.const(x.data_ptr(), x.shape)

        metadata = json.loads(net.get_info()["bizCode"])

    # NCNN
    elif ncnn:
        LOGGER.info(f"Loading {w} for NCNN inference...")
        check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires NCNN
        import ncnn as pyncnn

        net = pyncnn.Net()
        net.opt.use_vulkan_compute = cuda
        w = Path(w)
        if not w.is_file():  # if not *.param
            w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir
        net.load_param(str(w))
        net.load_model(str(w.with_suffix(".bin")))
        metadata = w.parent / "metadata.yaml"

    # NVIDIA Triton Inference Server
    elif triton:
        check_requirements("tritonclient[all]")
        from ultralytics.utils.triton import TritonRemoteModel

        model = TritonRemoteModel(w)
        metadata = model.metadata

    # RKNN
    elif rknn:
        if not is_rockchip():
            raise OSError("RKNN inference is only supported on Rockchip devices.")
        LOGGER.info(f"Loading {w} for RKNN inference...")
        check_requirements("rknn-toolkit-lite2")
        from rknnlite.api import RKNNLite

        w = Path(w)
        if not w.is_file():  # if not *.rknn
            w = next(w.rglob("*.rknn"))  # get *.rknn file from *_rknn_model dir
        rknn_model = RKNNLite()
        rknn_model.load_rknn(str(w))
        rknn_model.init_runtime()
        metadata = w.parent / "metadata.yaml"

    # Any other format (unsupported)
    else:
        from ultralytics.engine.exporter import export_formats

        raise TypeError(
            f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
            f"See https://docs.ultralytics.com/modes/predict for help."
        )

    # Load external metadata YAML
    if isinstance(metadata, (str, Path)) and Path(metadata).exists():
        metadata = yaml_load(metadata)
    if metadata and isinstance(metadata, dict):
        for k, v in metadata.items():
            if k in {"stride", "batch", "channels"}:
                metadata[k] = int(v)
            elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
                metadata[k] = eval(v)
        stride = metadata["stride"]
        task = metadata["task"]
        batch = metadata["batch"]
        imgsz = metadata["imgsz"]
        names = metadata["names"]
        kpt_shape = metadata.get("kpt_shape")
        end2end = metadata.get("args", {}).get("nms", False)
        dynamic = metadata.get("args", {}).get("dynamic", dynamic)
        ch = metadata.get("channels", 3)
    elif not (pt or triton or nn_module):
        LOGGER.warning(f"Metadata not found for 'model={weights}'")

    # Check names
    if "names" not in locals():  # names missing
        names = default_class_names(data)
    names = check_class_names(names)

    # Disable gradients
    if pt:
        for p in model.parameters():
            p.requires_grad = False

    self.__dict__.update(locals())  # assign all variables to self
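The tail of `__init__` normalizes backend-provided metadata before assigning everything to `self`. Below is a small standalone sketch of that normalization, using a hypothetical metadata dict in which non-scalar values arrive as strings (as exported metadata.yaml files typically store them):

```python
# Hypothetical metadata as a backend might supply it; string-encoded containers are typical.
metadata = {
    "stride": "32",
    "task": "detect",
    "batch": "1",
    "imgsz": "[640, 640]",
    "names": "{0: 'person', 1: 'bicycle'}",
    "args": "{'nms': False, 'dynamic': True}",
}

for k, v in metadata.items():
    if k in {"stride", "batch", "channels"}:
        metadata[k] = int(v)  # numeric fields arrive as strings
    elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
        metadata[k] = eval(v)  # Python literals stored as strings

stride, names = metadata["stride"], metadata["names"]
end2end = metadata.get("args", {}).get("nms", False)  # True if NMS was baked in at export time
dynamic = metadata.get("args", {}).get("dynamic", False)
print(stride, names, end2end, dynamic)  # 32 {0: 'person', 1: 'bicycle'} False True
```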
def forward(self, im, augment=False, visualize=False, embed=None, **kwargs):
    """
    Runs inference on the YOLOv8 MultiBackend model.

    Args:
        im (torch.Tensor): The image tensor to perform inference on.
        augment (bool): Whether to perform data augmentation during inference.
        visualize (bool): Whether to visualize the output predictions.
        embed (list | None): A list of feature vectors/embeddings to return.
        **kwargs (Any): Additional keyword arguments for model configuration.

    Returns:
        (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
    """
    b, ch, h, w = im.shape  # batch, channel, height, width
    if self.fp16 and im.dtype != torch.float16:
        im = im.half()  # to FP16
    if self.nhwc:
        im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

    # PyTorch
    if self.pt or self.nn_module:
        y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)

    # TorchScript
    elif self.jit:
        y = self.model(im)

    # ONNX OpenCV DNN
    elif self.dnn:
        im = im.cpu().numpy()  # torch to numpy
        self.net.setInput(im)
        y = self.net.forward()

    # ONNX Runtime
    elif self.onnx or self.imx:
        if self.dynamic:
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        else:
            if not self.cuda:
                im = im.cpu()
            self.io.bind_input(
                name="images",
                device_type=im.device.type,
                device_id=im.device.index if im.device.type == "cuda" else 0,
                element_type=np.float16 if self.fp16 else np.float32,
                shape=tuple(im.shape),
                buffer_ptr=im.data_ptr(),
            )
            self.session.run_with_iobinding(self.io)
            y = self.bindings
        if self.imx:
            # boxes, conf, cls
            y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)

    # OpenVINO
    elif self.xml:
        im = im.cpu().numpy()  # FP32
        if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
            n = im.shape[0]  # number of images in batch
            results = [None] * n  # preallocate list with None to match the number of images

            def callback(request, userdata):
                """Places result in preallocated list using userdata index."""
                results[userdata] = request.results

            # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
            async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model)
            async_queue.set_callback(callback)
            for i in range(n):
                # Start async inference with userdata=i to specify the position in results list
                async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
            async_queue.wait_all()  # wait for all inference requests to complete
            y = np.concatenate([list(r.values())[0] for r in results])
        else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
            y = list(self.ov_compiled_model(im).values())

    # TensorRT
    elif self.engine:
        if self.dynamic and im.shape != self.bindings["images"].shape:
            if self.is_trt10:
                self.context.set_input_shape("images", im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
            else:
                i = self.model.get_binding_index("images")
                self.context.set_binding_shape(i, im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
        s = self.bindings["images"].shape
        assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
        self.binding_addrs["images"] = int(im.data_ptr())
        self.context.execute_v2(list(self.binding_addrs.values()))
        y = [self.bindings[x].data for x in sorted(self.output_names)]

    # CoreML
    elif self.coreml:
        im = im[0].cpu().numpy()
        im_pil = Image.fromarray((im * 255).astype("uint8"))
        # im = im.resize((192, 320), Image.BILINEAR)
        y = self.model.predict({"image": im_pil})  # coordinates are xywh normalized
        if "confidence" in y:
            raise TypeError(
                "Ultralytics only supports inference of non-pipelined CoreML models exported with "
                f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
            )
            # TODO: CoreML NMS inference handling
            # from ultralytics.utils.ops import xywh2xyxy
            # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
            # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
            # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
        y = list(y.values())
        if len(y) == 2 and len(y[1].shape) != 4:  # segmentation model
            y = list(reversed(y))  # reversed for segmentation models (pred, proto)

    # PaddlePaddle
    elif self.paddle:
        im = im.cpu().numpy().astype(np.float32)
        self.input_handle.copy_from_cpu(im)
        self.predictor.run()
        y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]

    # MNN
    elif self.mnn:
        input_var = self.torch_to_mnn(im)
        output_var = self.net.onForward([input_var])
        y = [x.read() for x in output_var]

    # NCNN
    elif self.ncnn:
        mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
        with self.net.create_extractor() as ex:
            ex.input(self.net.input_names()[0], mat_in)
            # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
            y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]

    # NVIDIA Triton Inference Server
    elif self.triton:
        im = im.cpu().numpy()  # torch to numpy
        y = self.model(im)

    # RKNN
    elif self.rknn:
        im = (im.cpu().numpy() * 255).astype("uint8")
        im = im if isinstance(im, (list, tuple)) else [im]
        y = self.rknn_model.inference(inputs=im)

    # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
    else:
        im = im.cpu().numpy()
        if self.saved_model:  # SavedModel
            y = self.model(im, training=False) if self.keras else self.model(im)
            if not isinstance(y, list):
                y = [y]
        elif self.pb:  # GraphDef
            y = self.frozen_func(x=self.tf.constant(im))
        else:  # Lite or Edge TPU
            details = self.input_details[0]
            is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
            if is_int:
                scale, zero_point = details["quantization"]
                im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
            self.interpreter.set_tensor(details["index"], im)
            self.interpreter.invoke()
            y = []
            for output in self.output_details:
                x = self.interpreter.get_tensor(output["index"])
                if is_int:
                    scale, zero_point = output["quantization"]
                    x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                    # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                    # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                    if x.shape[-1] == 6 or self.end2end:  # end-to-end model
                        x[:, :, [0, 2]] *= w
                        x[:, :, [1, 3]] *= h
                        if self.task == "pose":
                            x[:, :, 6::3] *= w
                            x[:, :, 7::3] *= h
                    else:
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                        if self.task == "pose":
                            x[:, 5::3] *= w
                            x[:, 6::3] *= h
                y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                if y[1].shape[-1] == 6:  # end-to-end model
                    y = [y[1]]
                else:
                    y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

    # for x in y:
    #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
    if isinstance(y, (list, tuple)):
        if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
            nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
            self.names = {i: f"class{i}" for i in range(nc)}
        return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
    else:
        return self.from_numpy(y)
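For reference, here is a preprocessing sketch that produces the input `forward` expects (the image path is a placeholder): a contiguous BCHW float tensor in the 0-1 range. FP16 casting and the BHWC permutation needed by TensorFlow/CoreML-style backends are handled inside `forward` itself.

```python
import cv2
import numpy as np
import torch

img = cv2.imread("bus.jpg")                # HWC, BGR, uint8 (placeholder file name)
img = cv2.resize(img, (640, 640))          # match the model's expected imgsz
img = img[:, :, ::-1].transpose(2, 0, 1)   # BGR -> RGB, HWC -> CHW
im = torch.from_numpy(np.ascontiguousarray(img)).float() / 255.0
im = im.unsqueeze(0)                       # add batch dimension -> BCHW

y = model.forward(im)  # `model` is an AutoBackend instance, e.g. from the sketch above
```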
from_numpy
from_numpy(x)
Convert a numpy array to a tensor.
Parameters:

| Name | Type    | Description                 | Default  |
| ---- | ------- | --------------------------- | -------- |
| x    | ndarray | The array to be converted.  | required |
Returns:

| Type   | Description            |
| ------ | ---------------------- |
| Tensor | The converted tensor.  |
Source code in ultralytics/nn/autobackend.py
def from_numpy(self, x):
    """
    Convert a numpy array to a tensor.

    Args:
        x (np.ndarray): The array to be converted.

    Returns:
        (torch.Tensor): The converted tensor
    """
    return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
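A quick sketch of the pass-through behavior (assuming `model` is an AutoBackend instance):

```python
import numpy as np
import torch

out = np.zeros((1, 84, 8400), dtype=np.float32)
t = model.from_numpy(out)   # np.ndarray -> torch.Tensor on model.device
same = model.from_numpy(t)  # anything that is not an ndarray is returned unchanged
assert torch.is_tensor(t) and same is t
```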
warmup
warmup(imgsz=(1, 3, 640, 640))
Warm up the model by running one forward pass with a dummy input.
Parameters:

| Name  | Type  | Description                                                                               | Default          |
| ----- | ----- | ----------------------------------------------------------------------------------------- | ---------------- |
| imgsz | tuple | The shape of the dummy input tensor in the format (batch_size, channels, height, width).  | (1, 3, 640, 640) |
Source code in ultralytics/nn/autobackend.py
def warmup(self, imgsz=(1, 3, 640, 640)):
    """
    Warm up the model by running one forward pass with a dummy input.

    Args:
        imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
    if any(warmup_types) and (self.device.type != "cpu" or self.triton):
        im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
        for _ in range(2 if self.jit else 1):
            self.forward(im)  # warmup
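Warmup pays the one-time backend initialization cost (kernel compilation, allocations) before timed inference. A hedged sketch of how it might be used when benchmarking; this is only meaningful on a CUDA device or a Triton backend, since the dummy pass is skipped on CPU:

```python
import time

import torch

model.warmup(imgsz=(1, 3, 640, 640))  # absorb first-call overhead up front

im = torch.zeros(1, 3, 640, 640, device=model.device, dtype=torch.half if model.fp16 else torch.float)
t0 = time.perf_counter()
model.forward(im)
print(f"steady-state latency: {time.perf_counter() - t0:.3f}s")
```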
ultralytics.nn.autobackend.check_class_names
check_class_names(names)
Check class names and convert to dict format if needed.
Source code in ultralytics/nn/autobackend.py
def check_class_names(names):
    """Check class names and convert to dict format if needed."""
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        n = len(names)
        if max(names.keys()) >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
            )
        if isinstance(names[0], str) and names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
            names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names
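The conversions this performs, shown on small hand-written inputs (a sketch, not from the source):

```python
print(check_class_names(["person", "bicycle", "car"]))
# {0: 'person', 1: 'bicycle', 2: 'car'}  -- list converted to an index-keyed dict

print(check_class_names({"0": "person", "1": True}))
# {0: 'person', 1: 'True'}  -- string keys cast to int, non-string values cast to str

# Indices outside 0..n-1 raise a KeyError, e.g.:
# check_class_names({0: "person", 5: "car"})
```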
ultralytics.nn.autobackend.default_class_names
default_class_names(data=None)
Apply default class names read from an input YAML file, or fall back to numerical class names.
Source code in ultralytics/nn/autobackend.py
def default_class_names(data=None):
    """Applies default class names to an input YAML file or returns numerical class names."""
    if data:
        try:
            return yaml_load(check_yaml(data))["names"]
        except Exception:
            pass
    return {i: f"class{i}" for i in range(999)}  # return default if above errors
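A short sketch of both paths (the dataset YAML name is just an example):

```python
names = default_class_names()  # no data YAML given
# {0: 'class0', 1: 'class1', ..., 998: 'class998'}  -- 999 numeric placeholder names

names = default_class_names("coco8.yaml")  # reads the 'names' field from the dataset YAML
# falls back to the numeric placeholders if the file cannot be read
```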