def on_predict_start(predictor: object, persist: bool = False) -> None:
    """
    Create the tracker instances a predictor needs before prediction begins.

    Stream datasets get one tracker per batch slot; every other dataset mode shares a single tracker.

    Args:
        predictor (object): The predictor object to initialize trackers for.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.

    Raises:
        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
    """
    if hasattr(predictor, "trackers") and persist:
        return  # keep the existing trackers untouched

    cfg_path = check_yaml(predictor.args.tracker)
    cfg = IterableSimpleNamespace(**yaml_load(cfg_path))
    supported_types = {"bytetrack", "botsort"}
    if cfg.tracker_type not in supported_types:
        raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")

    batch_size = predictor.dataset.bs
    # Only stream mode keeps one tracker per batch slot; other modes need just one
    # (min() keeps the count at zero for an empty batch, as the original loop did).
    tracker_count = batch_size if predictor.dataset.mode == "stream" else min(batch_size, 1)
    predictor.trackers = [TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) for _ in range(tracker_count)]
    # Last seen video path per slot, for determining when to reset tracker on new video.
    predictor.vid_path = [None] * batch_size
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    """
    Postprocess detected boxes and update with object tracking.

    For each image in the current batch, the matching tracker is reset when a new video path is seen
    (unless ``persist`` is set). The tracker is then fed that frame's detections. When the tracker
    returns tracks, the result is narrowed to the tracked detections, and its boxes (or OBBs) are
    replaced with the tracker output.

    Args:
        predictor (object): The predictor object containing the predictions.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
    """
    path, im0s = predictor.batch[:2]

    is_obb = predictor.args.task == "obb"
    is_stream = predictor.dataset.mode == "stream"
    box_attr = "obb" if is_obb else "boxes"  # result attribute holding detections, also the update() key
    for i, im0 in enumerate(im0s):
        slot = i if is_stream else 0  # non-stream modes share the single tracker in slot 0
        tracker = predictor.trackers[slot]
        vid_path = predictor.save_dir / Path(path[i]).name
        if not persist and predictor.vid_path[slot] != vid_path:
            tracker.reset()
            predictor.vid_path[slot] = vid_path

        det = getattr(predictor.results[i], box_attr).cpu().numpy()
        # Call update even when there are no detections: skipping it would hide this frame from the
        # tracker's per-frame bookkeeping. NOTE(review): assumes every tracker type accepts empty input.
        tracks = tracker.update(det, im0)
        if len(tracks) == 0:
            continue
        idx = tracks[:, -1].astype(int)  # last column indexes back into the original detections
        predictor.results[i] = predictor.results[i][idx]
        predictor.results[i].update(**{box_attr: torch.as_tensor(tracks[:, :-1])})
def register_tracker(model: object, persist: bool) -> None:
    """
    Attach the tracking callbacks to ``model`` so object tracking runs during prediction.

    Args:
        model (object): The model object to register tracking callbacks for.
        persist (bool): Whether to persist the trackers if they already exist.
    """
    # Ordered (event, handler) pairs; each handler gets ``persist`` bound as a keyword argument.
    hooks = (
        ("on_predict_start", on_predict_start),
        ("on_predict_postprocess_end", on_predict_postprocess_end),
    )
    for event_name, handler in hooks:
        model.add_callback(event_name, partial(handler, persist=persist))