
Reference for ultralytics/trackers/bot_sort.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/bot_sort.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.trackers.bot_sort.BOTrack

Bases: STrack

An extended version of the STrack class for YOLOv8 that adds object tracking features.

Attributes:

| Name | Type | Description |
|---|---|---|
| shared_kalman | KalmanFilterXYWH | A shared Kalman filter for all instances of BOTrack. |
| smooth_feat | ndarray | Smoothed feature vector. |
| curr_feat | ndarray | Current feature vector. |
| features | deque | A deque that stores feature vectors, with a maximum length defined by feat_history. |
| alpha | float | Smoothing factor for the exponential moving average of the features. |
| mean | ndarray | The mean state of the Kalman filter. |
| covariance | ndarray | The covariance matrix of the Kalman filter. |

Methods:

| Name | Description |
|---|---|
| update_features | Updates the feature vector and smooths it using an exponential moving average. |
| predict | Predicts the mean and covariance using the Kalman filter. |
| re_activate | Reactivates a track with updated features and, optionally, a new ID. |
| update | Updates this track with a newly matched track (including its features) and the current frame ID. |
| tlwh | Property that returns the current position in tlwh format (top left x, top left y, width, height). |
| multi_predict | Predicts the mean and covariance of multiple object tracks using a shared Kalman filter. |
| convert_coords | Converts tlwh bounding box coordinates to xywh format. |
| tlwh_to_xywh | Converts a bounding box to xywh format (center x, center y, width, height). |

Usage

bo_track = BOTrack(tlwh, score, cls, feat)
bo_track.predict()
bo_track.update(new_track, frame_id)

Source code in ultralytics/trackers/bot_sort.py
class BOTrack(STrack):
    """
    An extended version of the STrack class for YOLOv8, adding object tracking features.

    Attributes:
        shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
        smooth_feat (np.ndarray): Smoothed feature vector.
        curr_feat (np.ndarray): Current feature vector.
        features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
        alpha (float): Smoothing factor for the exponential moving average of features.
        mean (np.ndarray): The mean state of the Kalman filter.
        covariance (np.ndarray): The covariance matrix of the Kalman filter.

    Methods:
        update_features(feat): Update features vector and smooth it using exponential moving average.
        predict(): Predicts the mean and covariance using Kalman filter.
        re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
        update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID.
        tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
        multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
        convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
        tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.

    Usage:
        bo_track = BOTrack(tlwh, score, cls, feat)
        bo_track.predict()
        bo_track.update(new_track, frame_id)
    """

    shared_kalman = KalmanFilterXYWH()

    def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
        """Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
        super().__init__(tlwh, score, cls)

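        self.smooth_feat = None
        self.curr_feat = None
        # update_features() appends to self.features (and uses self.alpha), so create them before calling it
        self.features = deque([], maxlen=feat_history)
        self.alpha = 0.9
        if feat is not None:
            self.update_features(feat)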

    def update_features(self, feat):
        """Update features vector and smooth it using exponential moving average."""
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def predict(self):
        """Predicts the mean and covariance using Kalman filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[6] = 0
            mean_state[7] = 0

        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    def re_activate(self, new_track, frame_id, new_id=False):
        """Reactivates a track with updated features and optionally assigns a new ID."""
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        super().re_activate(new_track, frame_id, new_id)

    def update(self, new_track, frame_id):
        """Update the YOLOv8 instance with new track and frame ID."""
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        super().update(new_track, frame_id)

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y, width, height)`."""
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[:2] -= ret[2:] / 2
        return ret

    @staticmethod
    def multi_predict(stracks):
        """Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
        if len(stracks) <= 0:
            return
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        for i, st in enumerate(stracks):
            if st.state != TrackState.Tracked:
                multi_mean[i][6] = 0
                multi_mean[i][7] = 0
        multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
        for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
            stracks[i].mean = mean
            stracks[i].covariance = cov

    def convert_coords(self, tlwh):
        """Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
        return self.tlwh_to_xywh(tlwh)

    @staticmethod
    def tlwh_to_xywh(tlwh):
        """Convert bounding box to format `(center x, center y, width, height)`."""
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        return ret

tlwh property

Gets the current position in bounding box format (top left x, top left y, width, height).
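In BOTrack the first four entries of the Kalman mean hold the box as (center x, center y, width, height), so the property only has to shift the center by half the box size. A minimal NumPy sketch of that arithmetic follows; `xywh_to_tlwh` is a hypothetical helper name, not an Ultralytics function, and the sample state vector is made up.

```python
import numpy as np


def xywh_to_tlwh(xywh):
    """Convert (center x, center y, width, height) to (top left x, top left y, width, height)."""
    ret = np.asarray(xywh, dtype=float).copy()  # copy so the caller's array is not modified
    ret[:2] -= ret[2:] / 2  # move the center back by half the width and half the height
    return ret


# Hypothetical Kalman mean in the layout [x, y, w, h, vx, vy, vw, vh]
mean = np.array([50.0, 40.0, 20.0, 10.0, 1.0, 0.5, 0.0, 0.0])
print(xywh_to_tlwh(mean[:4]))  # top-left corner at (40, 35), size unchanged
```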

__init__(tlwh, score, cls, feat=None, feat_history=50)

Initializes the track with temporal parameters, such as the feature history, alpha, and the current features.

Source code in ultralytics/trackers/bot_sort.py
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
    """Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
    super().__init__(tlwh, score, cls)

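    self.smooth_feat = None
    self.curr_feat = None
    # update_features() appends to self.features (and uses self.alpha), so create them before calling it
    self.features = deque([], maxlen=feat_history)
    self.alpha = 0.9
    if feat is not None:
        self.update_features(feat)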

convert_coords(tlwh)

Converts Top-Left-Width-Height bounding box coordinates to (center) X-Y-Width-Height format.

Source code in ultralytics/trackers/bot_sort.py
def convert_coords(self, tlwh):
    """Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
    return self.tlwh_to_xywh(tlwh)

multi_predict(stracks) staticmethod

Predicts the mean and covariance of multiple object tracks using a shared Kalman filter.

Source code in ultralytics/trackers/bot_sort.py
@staticmethod
def multi_predict(stracks):
    """Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
    if len(stracks) <= 0:
        return
    multi_mean = np.asarray([st.mean.copy() for st in stracks])
    multi_covariance = np.asarray([st.covariance for st in stracks])
    for i, st in enumerate(stracks):
        if st.state != TrackState.Tracked:
            multi_mean[i][6] = 0
            multi_mean[i][7] = 0
    multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
    for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
        stracks[i].mean = mean
        stracks[i].covariance = cov
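Because every BOTrack shares one KalmanFilterXYWH, the prediction step can be stacked over all tracks instead of looped. The sketch below assumes a constant-velocity model over the 8-dimensional state [x, y, w, h, vx, vy, vw, vh] with one frame per step and a constant process noise `q`. The library's filter derives its process noise from the box size in the current state and first zeroes the size velocities of non-tracked tracks; both steps are simplified away here, and `batched_predict` is a hypothetical name.

```python
import numpy as np

# Constant-velocity transition for [x, y, w, h, vx, vy, vw, vh] with dt = 1 frame
F = np.eye(8)
F[:4, 4:] = np.eye(4)  # position/size += velocity


def batched_predict(means, covs, q=0.01):
    """Predict N track states in one vectorized step.

    means: (N, 8) array, covs: (N, 8, 8) array.
    q: assumed constant process-noise scale (a simplification).
    """
    Q = q * np.eye(8)
    new_means = means @ F.T  # applies F to every row (row @ F.T == F @ row)
    new_covs = F @ covs @ F.T + Q  # F broadcasts across the batch axis
    return new_means, new_covs


means = np.array([[10, 10, 4, 4, 1, 0, 0.5, 0.5],
                  [20, 20, 6, 6, 0, 1, 0.2, 0.2]], dtype=float)
covs = np.stack([np.eye(8), 2 * np.eye(8)])
new_means, new_covs = batched_predict(means, covs)
print(new_means[:, :4])  # rows: [11, 10, 4.5, 4.5] and [20, 21, 6.2, 6.2]
```

Stacking the states lets NumPy run a single matrix product per step, which is the point of sharing one filter across tracks.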

predict()

Predicts the mean and covariance using the Kalman filter.

Source code in ultralytics/trackers/bot_sort.py
def predict(self):
    """Predicts the mean and covariance using Kalman filter."""
    mean_state = self.mean.copy()
    if self.state != TrackState.Tracked:
        mean_state[6] = 0
        mean_state[7] = 0

    self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
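The masking in predict() only touches the width and height velocities (indices 6 and 7 of the [x, y, w, h, vx, vy, vw, vh] state), so a track that is not currently in the Tracked state keeps moving but stops growing or shrinking. The sketch below isolates that effect on the mean under an assumed constant-velocity step of one frame; the covariance update is left out, and `predict_mean` is a hypothetical helper.

```python
import numpy as np


def predict_mean(mean, is_tracked):
    """Advance the mean of an [x, y, w, h, vx, vy, vw, vh] state by one frame.

    Mirrors the masking in BOTrack.predict: when the track is not Tracked,
    the width and height velocities (indices 6 and 7) are zeroed first.
    """
    mean = np.asarray(mean, dtype=float).copy()  # never modify the caller's state
    if not is_tracked:
        mean[6] = 0  # freeze width velocity
        mean[7] = 0  # freeze height velocity
    mean[:4] += mean[4:]  # constant velocity, dt = 1: each quantity moves by its velocity
    return mean


state = [100, 50, 20, 40, 2, -1, 1, 1]
print(predict_mean(state, is_tracked=True)[:4])   # x=102, y=49, w=21, h=41
print(predict_mean(state, is_tracked=False)[:4])  # x=102, y=49, w=20, h=40
```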

re_activate(new_track, frame_id, new_id=False)

Reactivates a track with updated features and optionally assigns a new ID.

Source code in ultralytics/trackers/bot_sort.py
def re_activate(self, new_track, frame_id, new_id=False):
    """Reactivates a track with updated features and optionally assigns a new ID."""
    if new_track.curr_feat is not None:
        self.update_features(new_track.curr_feat)
    super().re_activate(new_track, frame_id, new_id)

tlwh_to_xywh(tlwh) staticmethod

Converts a bounding box to the format (center x, center y, width, height).

Source code in ultralytics/trackers/bot_sort.py
@staticmethod
def tlwh_to_xywh(tlwh):
    """Convert bounding box to format `(center x, center y, width, height)`."""
    ret = np.asarray(tlwh).copy()
    ret[:2] += ret[2:] / 2
    return ret
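A standalone version of the same conversion. One detail worth knowing: np.asarray keeps an integer dtype, and NumPy refuses to add a float half-size in place to an integer array (it raises a casting error), so this sketch casts to float first. That cast is an addition for robustness, not part of the original method.

```python
import numpy as np


def tlwh_to_xywh(tlwh):
    """Convert (top left x, top left y, width, height) to (center x, center y, width, height).

    Same arithmetic as BOTrack.tlwh_to_xywh, but the input is cast to float so
    integer boxes do not trigger a NumPy casting error on the in-place add.
    """
    ret = np.asarray(tlwh, dtype=float).copy()
    ret[:2] += ret[2:] / 2  # move the top-left corner to the box center
    return ret


print(tlwh_to_xywh([10, 20, 30, 40]))  # center at (25, 40), size unchanged
```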

update(new_track, frame_id)

Updates this track with the newly matched track (including its features, if present) and the current frame ID.

Source code in ultralytics/trackers/bot_sort.py
def update(self, new_track, frame_id):
    """Update the YOLOv8 instance with new track and frame ID."""
    if new_track.curr_feat is not None:
        self.update_features(new_track.curr_feat)
    super().update(new_track, frame_id)

update_features(feat)

Updates the feature vector and smooths it using an exponential moving average.

Source code in ultralytics/trackers/bot_sort.py
def update_features(self, feat):
    """Update features vector and smooth it using exponential moving average."""
    feat /= np.linalg.norm(feat)
    self.curr_feat = feat
    if self.smooth_feat is None:
        self.smooth_feat = feat
    else:
        self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
    self.features.append(feat)
    self.smooth_feat /= np.linalg.norm(self.smooth_feat)
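The effect of alpha = 0.9 is easiest to see with two orthogonal unit features: after the second update, the smoothed vector still points almost entirely along the first one. The class below is a hypothetical stand-in that reproduces the update rule above. The main difference is that it normalizes a copy of the input, whereas `feat /= np.linalg.norm(feat)` in the original rescales the caller's array in place.

```python
from collections import deque

import numpy as np


class FeatureSmoother:
    """Hypothetical stand-in for BOTrack's appearance-feature bookkeeping."""

    def __init__(self, alpha=0.9, feat_history=50):
        self.alpha = alpha  # weight kept by the running average
        self.features = deque([], maxlen=feat_history)  # bounded history of unit-length features
        self.curr_feat = None
        self.smooth_feat = None

    def update_features(self, feat):
        feat = np.asarray(feat, dtype=float)
        feat = feat / np.linalg.norm(feat)  # L2-normalize a copy, not the caller's array
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat  # the first observation seeds the average
        else:
            # Exponential moving average: alpha of the history plus (1 - alpha) of the new feature
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat = self.smooth_feat / np.linalg.norm(self.smooth_feat)  # back to unit length


s = FeatureSmoother()
s.update_features([1.0, 0.0])
s.update_features([0.0, 1.0])
print(s.smooth_feat)  # still close to [1, 0]: one new observation only nudges the average
```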



ultralytics.trackers.bot_sort.BOTSORT

Bases: BYTETracker

An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and the GMC algorithm.

Attributes:

| Name | Type | Description |
|---|---|---|
| proximity_thresh | float | Spatial proximity (IoU) threshold between tracks and detections. |
| appearance_thresh | float | Appearance similarity (ReID embedding) threshold between tracks and detections. |
| encoder | object | Object that handles ReID embeddings. ReID is not yet supported in this version, so it is None when with_reid is enabled. |
| gmc | GMC | An instance of the GMC (motion compensation) algorithm, used to compensate for camera motion between frames. |
| args | object | Parsed command-line arguments containing the tracking parameters. |

Methods:

| Name | Description |
|---|---|
| get_kalmanfilter | Returns a KalmanFilterXYWH instance for object tracking. |
| init_track | Initializes tracks from the detections, scores, and classes. |
| get_dists | Computes the distances between tracks and detections using IoU and, optionally, ReID embeddings. |
| multi_predict | Predicts the states of multiple tracks by delegating to BOTrack.multi_predict. |

Usage

bot_sort = BOTSORT(args, frame_rate)
bot_sort.init_track(dets, scores, cls, img)
bot_sort.multi_predict(tracks)

Note

The class is designed to work with the YOLOv8 object detection model and supports ReID only if it is enabled via args.

Source code in ultralytics/trackers/bot_sort.py
class BOTSORT(BYTETracker):
    """
    An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.

    Attributes:
        proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
        appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
        encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
        gmc (GMC): An instance of the GMC algorithm for data association.
        args (object): Parsed command-line arguments containing tracking parameters.

    Methods:
        get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
        init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
        get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
        multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.

    Usage:
        bot_sort = BOTSORT(args, frame_rate)
        bot_sort.init_track(dets, scores, cls, img)
        bot_sort.multi_predict(tracks)

    Note:
        The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
    """

    def __init__(self, args, frame_rate=30):
        """Initialize YOLOv8 object with ReID module and GMC algorithm."""
        super().__init__(args, frame_rate)
        # ReID module
        self.proximity_thresh = args.proximity_thresh
        self.appearance_thresh = args.appearance_thresh

        if args.with_reid:
            # Haven't supported BoT-SORT(reid) yet
            self.encoder = None
        self.gmc = GMC(method=args.gmc_method)

    def get_kalmanfilter(self):
        """Returns an instance of KalmanFilterXYWH for object tracking."""
        return KalmanFilterXYWH()

    def init_track(self, dets, scores, cls, img=None):
        """Initialize track with detections, scores, and classes."""
        if len(dets) == 0:
            return []
        if self.args.with_reid and self.encoder is not None:
            features_keep = self.encoder.inference(img, dets)
            return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)]  # detections
        else:
            return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)]  # detections

    def get_dists(self, tracks, detections):
        """Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
        dists = matching.iou_distance(tracks, detections)
        dists_mask = dists > self.proximity_thresh

        # TODO: mot20
        # if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)

        if self.args.with_reid and self.encoder is not None:
            emb_dists = matching.embedding_distance(tracks, detections) / 2.0
            emb_dists[emb_dists > self.appearance_thresh] = 1.0
            emb_dists[dists_mask] = 1.0
            dists = np.minimum(dists, emb_dists)
        return dists

    def multi_predict(self, tracks):
        """Predict and track multiple objects with YOLOv8 model."""
        BOTrack.multi_predict(tracks)

    def reset(self):
        """Reset tracker."""
        super().reset()
        self.gmc.reset_params()

__init__(args, frame_rate=30)

Initializes the tracker with the ReID module and the GMC algorithm.

Source code in ultralytics/trackers/bot_sort.py
def __init__(self, args, frame_rate=30):
    """Initialize YOLOv8 object with ReID module and GMC algorithm."""
    super().__init__(args, frame_rate)
    # ReID module
    self.proximity_thresh = args.proximity_thresh
    self.appearance_thresh = args.appearance_thresh

    if args.with_reid:
        # Haven't supported BoT-SORT(reid) yet
        self.encoder = None
    self.gmc = GMC(method=args.gmc_method)

get_dists(tracks, detections)

Gets the distances between tracks and detections using IoU and (optionally) ReID embeddings.

Source code in ultralytics/trackers/bot_sort.py
def get_dists(self, tracks, detections):
    """Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
    dists = matching.iou_distance(tracks, detections)
    dists_mask = dists > self.proximity_thresh

    # TODO: mot20
    # if not self.args.mot20:
    dists = matching.fuse_score(dists, detections)

    if self.args.with_reid and self.encoder is not None:
        emb_dists = matching.embedding_distance(tracks, detections) / 2.0
        emb_dists[emb_dists > self.appearance_thresh] = 1.0
        emb_dists[dists_mask] = 1.0
        dists = np.minimum(dists, emb_dists)
    return dists
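The fusion logic can be reproduced on plain arrays to see how the two thresholds interact. This sketch assumes the distance definitions as we read them from Ultralytics' `matching` helpers: boxes in (x1, y1, x2, y2) format, IoU distance = 1 - IoU, score fusion = 1 - IoU * score, and cosine distance between track and detection embeddings, clipped at 0. These should be checked against the library; `iou_matrix`, `cosine_distance`, and `fused_dists` are hypothetical names.

```python
import numpy as np


def iou_matrix(a, b):
    """Pairwise IoU between (x1, y1, x2, y2) boxes: a (M, 4), b (N, 4) -> (M, N)."""
    a = np.asarray(a, dtype=float)[:, None, :]
    b = np.asarray(b, dtype=float)[None, :, :]
    iw = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    ih = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = iw * ih
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return inter / (area_a + area_b - inter)


def cosine_distance(track_feats, det_feats):
    """1 - cosine similarity between feature rows, clipped at 0 (range 0..2)."""
    t = track_feats / np.linalg.norm(track_feats, axis=1, keepdims=True)
    d = det_feats / np.linalg.norm(det_feats, axis=1, keepdims=True)
    return np.maximum(0.0, 1.0 - t @ d.T)


def fused_dists(track_boxes, det_boxes, det_scores, track_feats=None, det_feats=None,
                proximity_thresh=0.5, appearance_thresh=0.25):
    """Sketch of the BOTSORT.get_dists logic on raw arrays (hypothetical helper)."""
    dists = 1.0 - iou_matrix(track_boxes, det_boxes)  # IoU distance
    far = dists > proximity_thresh  # pairs that overlap too little
    dists = 1.0 - (1.0 - dists) * np.asarray(det_scores)[None]  # weight IoU similarity by detection score
    if track_feats is not None and det_feats is not None:
        emb = cosine_distance(track_feats, det_feats) / 2.0  # rescale to 0..1
        emb[emb > appearance_thresh] = 1.0  # appearance too different
        emb[far] = 1.0  # ignore appearance for spatially distant pairs
        dists = np.minimum(dists, emb)
    return dists


tracks = [[0, 0, 10, 10], [20, 20, 30, 30]]
dets = [[1, 1, 11, 11], [50, 50, 60, 60]]
print(fused_dists(tracks, dets, [0.9, 0.8]).round(3))  # only track 0 / detection 0 overlap; the rest are 1.0
```

Taking the element-wise minimum means a strong appearance match can lower the cost of a pair, but only for pairs that already passed the IoU proximity check.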

get_kalmanfilter()

Returns a KalmanFilterXYWH instance for object tracking.

Source code in ultralytics/trackers/bot_sort.py
def get_kalmanfilter(self):
    """Returns an instance of KalmanFilterXYWH for object tracking."""
    return KalmanFilterXYWH()

init_track(dets, scores, cls, img=None)

Initializes tracks from the detections, scores, and classes.

Source code in ultralytics/trackers/bot_sort.py
def init_track(self, dets, scores, cls, img=None):
    """Initialize track with detections, scores, and classes."""
    if len(dets) == 0:
        return []
    if self.args.with_reid and self.encoder is not None:
        features_keep = self.encoder.inference(img, dets)
        return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)]  # detections
    else:
        return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)]  # detections

multi_predict(tracks)

Predicts the states of multiple tracks with the shared Kalman filter by delegating to BOTrack.multi_predict.

Source code in ultralytics/trackers/bot_sort.py
def multi_predict(self, tracks):
    """Predict and track multiple objects with YOLOv8 model."""
    BOTrack.multi_predict(tracks)

reset()

Resets the tracker.

Source code in ultralytics/trackers/bot_sort.py
def reset(self):
    """Reset tracker."""
    super().reset()
    self.gmc.reset_params()





Created 2023-11-12, Updated 2024-05-08
Authors: Burhan-Q (1), glenn-jocher (3), Laughing-q (1)