
Reference for ultralytics/trackers/byte_tracker.py

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/byte_tracker.py.


class ultralytics.trackers.byte_tracker.STrack

STrack(self, xywh: list[float], score: float, cls: Any)

Bases: BaseTrack

Single object tracking representation that uses Kalman filtering for state estimation.

This class stores all the information about an individual tracklet and performs state updates and predictions based on a Kalman filter.

Args

xywh (list[float], required): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where (x, y) is the center, (w, h) are width and height, [a] is an optional rotation angle, and idx is the detection index.
score (float, required): Confidence score of the detection.
cls (Any, required): Class label for the detected object.

Attributes

shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.
_tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
mean (np.ndarray): Mean state estimate vector.
covariance (np.ndarray): Covariance of state estimate.
is_activated (bool): Boolean flag indicating if the track has been activated.
score (float): Confidence score of the track.
tracklet_len (int): Length of the tracklet.
cls (Any): Class label for the object.
idx (int): Index or identifier for the object.
frame_id (int): Current frame ID.
start_frame (int): Frame where the object was first detected.
angle (float | None): Optional angle information for oriented bounding boxes.

Methods

tlwh: Get the bounding box in top-left-width-height format from the current state estimate.
xyxy: Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.
xywh: Get the current position of the bounding box in (center x, center y, width, height) format.
xywha: Get position in (center x, center y, width, height, angle) format, warning if angle is missing.
result: Get the current tracking result: the box (xyxy, or xywha for oriented boxes) followed by track ID, score, class, and detection index.
__repr__: Return a string representation of the STrack object including start frame, end frame, and track ID.
activate: Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.
convert_coords: Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.
multi_gmc: Update the state means and covariances of multiple tracks using a 2x3 global motion compensation (homography) matrix.
multi_predict: Predict the next state (mean and covariance) of every track in the list with one batched Kalman filter step.
predict: Predict the next state (mean and covariance) of the object using the Kalman filter.
re_activate: Reactivate a previously lost track using new detection data and update its state and attributes.
tlwh_to_xyah: Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.
update: Update the state of a matched track.

Examples

Initialize and activate a new track
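>>> from ultralytics.trackers.byte_tracker import STrack
>>> from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)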
Source code in ultralytics/trackers/byte_tracker.py
class STrack(BaseTrack):
    """Single object tracking representation that uses Kalman filtering for state estimation.

    This class is responsible for storing all the information regarding individual tracklets and performs state updates
    and predictions based on Kalman filter.

    Attributes:
        shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.
        _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
        kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
        mean (np.ndarray): Mean state estimate vector.
        covariance (np.ndarray): Covariance of state estimate.
        is_activated (bool): Boolean flag indicating if the track has been activated.
        score (float): Confidence score of the track.
        tracklet_len (int): Length of the tracklet.
        cls (Any): Class label for the object.
        idx (int): Index or identifier for the object.
        frame_id (int): Current frame ID.
        start_frame (int): Frame where the object was first detected.
        angle (float | None): Optional angle information for oriented bounding boxes.

    Methods:
        predict: Predict the next state of the object using Kalman filter.
        multi_predict: Predict the next states for multiple tracks.
        multi_gmc: Update multiple track states using a homography matrix.
        activate: Activate a new tracklet.
        re_activate: Reactivate a previously lost tracklet.
        update: Update the state of a matched track.
        convert_coords: Convert bounding box to x-y-aspect-height format.
        tlwh_to_xyah: Convert tlwh bounding box to xyah format.

    Examples:
        Initialize and activate a new track
        >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
        >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
    """

    shared_kalman = KalmanFilterXYAH()

    def __init__(self, xywh: list[float], score: float, cls: Any):
        """Initialize a new STrack instance.

        Args:
            xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
                (x, y) is the center, (w, h) are width and height, [a] is an optional rotation angle, and idx is the
                detection index.
            score (float): Confidence score of the detection.
            cls (Any): Class label for the detected object.
        """
        super().__init__()
        # xywh+idx or xywha+idx
        assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
        self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls
        self.idx = xywh[-1]
        self.angle = xywh[4] if len(xywh) == 6 else None
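
As a rough illustration of the input handling in `__init__`, the sketch below splits a 5-value row (x, y, w, h, idx) or a 6-value row (x, y, w, h, angle, idx) the same way. `parse_detection` is a hypothetical helper: it raises `ValueError` where the original uses `assert`, and the center-to-top-left shift is written inline rather than calling `xywh2ltwh`.

```python
import numpy as np


def parse_detection(xywh):
    """Split a detection row into (tlwh, angle, idx), mirroring STrack.__init__ (hypothetical helper)."""
    if len(xywh) not in {5, 6}:
        raise ValueError(f"expected 5 or 6 values but got {len(xywh)}")
    x, y, w, h = (float(v) for v in xywh[:4])
    # Center (x, y) -> top-left corner; width and height are unchanged
    tlwh = np.array([x - w / 2, y - h / 2, w, h], dtype=np.float32)
    angle = xywh[4] if len(xywh) == 6 else None  # only oriented boxes carry an angle
    idx = xywh[-1]  # the last value is always the detection index
    return tlwh, angle, idx


tlwh, angle, idx = parse_detection([100, 200, 50, 80, 0])
print(tlwh.tolist(), angle, idx)  # [75.0, 160.0, 50.0, 80.0] None 0
```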


property ultralytics.trackers.byte_tracker.STrack.tlwh

def tlwh(self) -> np.ndarray

Get the bounding box in top-left-width-height format from the current state estimate.

Source code in ultralytics/trackers/byte_tracker.py
@property
def tlwh(self) -> np.ndarray:
    """Get the bounding box in top-left-width-height format from the current state estimate."""
    if self.mean is None:
        return self._tlwh.copy()
    ret = self.mean[:4].copy()
    ret[2] *= ret[3]
    ret[:2] -= ret[2:] / 2
    return ret
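
The Kalman state stores the box as (center x, center y, aspect ratio w/h, height) followed by four velocities, so `tlwh` has to undo that parameterization. A minimal standalone version of the same arithmetic (the function name is illustrative):

```python
import numpy as np


def mean_to_tlwh(mean):
    """Convert the first four Kalman state entries (cx, cy, a, h) to (top-left x, top-left y, w, h)."""
    ret = np.asarray(mean[:4], dtype=np.float64).copy()
    ret[2] *= ret[3]  # width = aspect ratio * height
    ret[:2] -= ret[2:] / 2  # shift the center to the top-left corner
    return ret


state = np.array([125.0, 240.0, 0.625, 80.0, 1.0, 0.0, 0.0, 0.0])  # velocities are ignored
print(mean_to_tlwh(state).tolist())  # [100.0, 200.0, 50.0, 80.0]
```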


property ultralytics.trackers.byte_tracker.STrack.xyxy

def xyxy(self) -> np.ndarray

Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.

Source code in ultralytics/trackers/byte_tracker.py
@property
def xyxy(self) -> np.ndarray:
    """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
    ret = self.tlwh.copy()
    ret[2:] += ret[:2]
    return ret


property ultralytics.trackers.byte_tracker.STrack.xywh

def xywh(self) -> np.ndarray

Get the current position of the bounding box in (center x, center y, width, height) format.

Source code in ultralytics/trackers/byte_tracker.py
@property
def xywh(self) -> np.ndarray:
    """Get the current position of the bounding box in (center x, center y, width, height) format."""
    ret = np.asarray(self.tlwh).copy()
    ret[:2] += ret[2:] / 2
    return ret
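
Both `xyxy` and `xywh` are derived from `tlwh`, so once a Kalman state exists they report the filtered estimate rather than the raw detection. The conversions themselves are simple shifts; a standalone sketch with illustrative function names:

```python
import numpy as np


def tlwh_to_xyxy(tlwh):
    ret = np.asarray(tlwh, dtype=np.float64).copy()
    ret[2:] += ret[:2]  # (w, h) -> (x2, y2)
    return ret


def tlwh_to_xywh(tlwh):
    ret = np.asarray(tlwh, dtype=np.float64).copy()
    ret[:2] += ret[2:] / 2  # top-left corner -> center
    return ret


box = [100.0, 200.0, 50.0, 80.0]
print(tlwh_to_xyxy(box).tolist())  # [100.0, 200.0, 150.0, 280.0]
print(tlwh_to_xywh(box).tolist())  # [125.0, 240.0, 50.0, 80.0]
```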


property ultralytics.trackers.byte_tracker.STrack.xywha

def xywha(self) -> np.ndarray

Get position in (center x, center y, width, height, angle) format, warning if angle is missing.

Source code in ultralytics/trackers/byte_tracker.py
@property
def xywha(self) -> np.ndarray:
    """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
    if self.angle is None:
        LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
        return self.xywh
    return np.concatenate([self.xywh, self.angle[None]])
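
`xywha` appends the angle with `self.angle[None]`, which works because the angle is sliced out of a NumPy array and is therefore a NumPy scalar; a plain Python float would raise `TypeError`. The sketch below uses `np.atleast_1d` instead so it accepts either kind of scalar; the function name is illustrative.

```python
import numpy as np


def xywh_to_xywha(xywh, angle):
    """Append the angle when present; otherwise fall back to plain xywh, as STrack.xywha does."""
    xywh = np.asarray(xywh, dtype=np.float32)
    if angle is None:
        return xywh
    # np.atleast_1d turns a scalar angle (NumPy or Python float) into a length-1 array
    return np.concatenate([xywh, np.atleast_1d(np.asarray(angle, dtype=np.float32))])


print(xywh_to_xywha([125, 240, 50, 80], 0.5).tolist())  # [125.0, 240.0, 50.0, 80.0, 0.5]
```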


property ultralytics.trackers.byte_tracker.STrack.result

def result(self) -> list[float]

Get the current tracking results in the appropriate bounding box format.

Source code in ultralytics/trackers/byte_tracker.py
@property
def result(self) -> list[float]:
    """Get the current tracking results in the appropriate bounding box format."""
    coords = self.xyxy if self.angle is None else self.xywha
    return [*coords.tolist(), self.track_id, self.score, self.cls, self.idx]
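
Each `result` row, and therefore each row of the array returned by `BYTETracker.update`, is the box followed by four bookkeeping values: 8 numbers for axis-aligned boxes (xyxy) and 9 for oriented boxes (xywha). A hypothetical helper that unpacks such a row, assuming numeric class labels (the tracker casts its output to float32):

```python
import numpy as np


def parse_track_row(row):
    """Split one tracker output row into named fields (hypothetical helper)."""
    row = np.asarray(row)
    n_box = len(row) - 4  # the last four entries are track_id, score, cls, idx
    track_id, score, cls, idx = row[n_box:]
    return {
        "box": row[:n_box].tolist(),
        "track_id": int(track_id),
        "score": float(score),
        "cls": int(cls),
        "idx": int(idx),
    }


row = np.array([100, 200, 150, 280, 3, 0.9, 0, 1], dtype=np.float32)
print(parse_track_row(row))
```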


method ultralytics.trackers.byte_tracker.STrack.__repr__

def __repr__(self) -> str

Return a string representation of the STrack object including start frame, end frame, and track ID.

Source code in ultralytics/trackers/byte_tracker.py
def __repr__(self) -> str:
    """Return a string representation of the STrack object including start frame, end frame, and track ID."""
    return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"


method ultralytics.trackers.byte_tracker.STrack.activate

def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int)

Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.

Args

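kalman_filter (KalmanFilterXYAH, required): Kalman filter used to initialize and later propagate the track state.
frame_id (int, required): Current frame ID; it also becomes the track's start frame.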
Source code in ultralytics/trackers/byte_tracker.py
def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
    """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
    self.kalman_filter = kalman_filter
    self.track_id = self.next_id()
    self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))

    self.tracklet_len = 0
    self.state = TrackState.Tracked
    if frame_id == 1:
        self.is_activated = True
    self.frame_id = frame_id
    self.start_frame = frame_id
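
`activate` seeds the track with `kalman_filter.initiate(...)` on the xyah measurement. Below is a hedged sketch of a DeepSORT-style initiation: the position part of the mean is the measurement, velocities start at zero, and the covariance is diagonal with standard deviations scaled by box height. The weights 1/20 and 1/160 are assumptions here, not values read from `KalmanFilterXYAH`.

```python
import numpy as np

STD_WEIGHT_POSITION = 1.0 / 20  # assumed DeepSORT-style weight
STD_WEIGHT_VELOCITY = 1.0 / 160  # assumed DeepSORT-style weight


def initiate(measurement):
    """Create an 8-dim Kalman state from an (x, y, a, h) measurement with zero initial velocity."""
    measurement = np.asarray(measurement, dtype=np.float64)
    h = measurement[3]
    mean = np.concatenate([measurement, np.zeros(4)])
    std = [
        2 * STD_WEIGHT_POSITION * h, 2 * STD_WEIGHT_POSITION * h, 1e-2, 2 * STD_WEIGHT_POSITION * h,
        10 * STD_WEIGHT_VELOCITY * h, 10 * STD_WEIGHT_VELOCITY * h, 1e-5, 10 * STD_WEIGHT_VELOCITY * h,
    ]
    covariance = np.diag(np.square(std))  # independent uncertainty per state entry
    return mean, covariance


mean, cov = initiate([125.0, 240.0, 0.625, 80.0])
print(mean, np.diag(cov))
```

Note from the source above that `is_activated` is set immediately only when `frame_id == 1`; tracks started later stay unconfirmed until `update` matches them again.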


method ultralytics.trackers.byte_tracker.STrack.convert_coords

def convert_coords(self, tlwh: np.ndarray) -> np.ndarray

Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.

Args

tlwh (np.ndarray, required): Bounding box in (top-left x, top-left y, width, height) format.
Source code in ultralytics/trackers/byte_tracker.py
def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
    """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
    return self.tlwh_to_xyah(tlwh)


method ultralytics.trackers.byte_tracker.STrack.multi_gmc

def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3))

Update the state means and covariances of multiple tracks using a 2x3 global motion compensation (homography) matrix.

Args

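stracks (list[STrack], required): Tracks whose mean and covariance are updated in place.
H (np.ndarray, default: np.eye(2, 3)): 2x3 warp matrix; the left 2x2 block is applied to the state and the last column is added to the box center.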
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3)):
    """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
    if stracks:
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])

        R = H[:2, :2]
        R8x8 = np.kron(np.eye(4, dtype=float), R)
        t = H[:2, 2]

        for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
            mean = R8x8.dot(mean)
            mean[:2] += t
            cov = R8x8.dot(cov).dot(R8x8.transpose())

            stracks[i].mean = mean
            stracks[i].covariance = cov
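
`multi_gmc` builds an 8x8 block-diagonal matrix with `np.kron(np.eye(4), R)`, so the 2x2 part R of the warp acts on each consecutive pair of state entries, (x, y), (a, h), (vx, vy) and (va, vh), while the translation t only shifts the center. The standalone sketch below reproduces that arithmetic for a single track; `apply_warp` is an illustrative name, not part of the library.

```python
import numpy as np


def apply_warp(mean, covariance, H):
    """Apply a 2x3 warp H = [R | t] to an 8-dim (x, y, a, h, vx, vy, va, vh) Kalman state."""
    R = H[:2, :2]
    R8x8 = np.kron(np.eye(4), R)  # block-diagonal: R acts on each consecutive pair of entries
    t = H[:2, 2]
    new_mean = R8x8 @ mean  # new array, so the input mean is not modified
    new_mean[:2] += t  # translation only moves the box center
    new_cov = R8x8 @ covariance @ R8x8.T
    return new_mean, new_cov


mean = np.array([100.0, 50.0, 0.5, 40.0, 1.0, 2.0, 0.0, 0.0])
H = np.array([[1.0, 0.0, 5.0], [0.0, 1.0, -3.0]])  # pure translation by (5, -3)
new_mean, new_cov = apply_warp(mean, np.eye(8), H)
print(new_mean[:2])  # [105.  47.]
```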


method ultralytics.trackers.byte_tracker.STrack.multi_predict

def multi_predict(stracks: list[STrack])

Predict the next state (mean and covariance) of every track in the list with one batched Kalman filter step.

Args

stracks (list[STrack], required): Tracks whose mean and covariance are predicted in place.
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def multi_predict(stracks: list[STrack]):
    """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
    if len(stracks) <= 0:
        return
    multi_mean = np.asarray([st.mean.copy() for st in stracks])
    multi_covariance = np.asarray([st.covariance for st in stracks])
    for i, st in enumerate(stracks):
        if st.state != TrackState.Tracked:
            multi_mean[i][7] = 0
    multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
    for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
        stracks[i].mean = mean
        stracks[i].covariance = cov
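
`STrack.multi_predict` stacks all means and covariances, zeroes the height velocity (index 7) of tracks that are not in the Tracked state, and runs one shared prediction. The sketch below shows a vectorized constant-velocity step of that shape; the process-noise form (standard deviations proportional to box height with weights 1/20 and 1/160) is an assumed DeepSORT-style choice, not copied from `KalmanFilterXYAH`.

```python
import numpy as np


def batch_predict(means, covariances, tracked, w_pos=1 / 20, w_vel=1 / 160):
    """One constant-velocity Kalman prediction for N tracks: means (N, 8), covariances (N, 8, 8)."""
    means = np.array(means, dtype=np.float64)  # copy so the caller's array is untouched
    tracked = np.asarray(tracked, dtype=bool)
    means[~tracked, 7] = 0.0  # freeze height velocity for non-tracked (e.g. lost) tracks
    F = np.eye(8)
    F[:4, 4:] = np.eye(4)  # position += velocity (dt = 1 frame)
    h = means[:, 3]
    std = np.stack([w_pos * h, w_pos * h, np.full_like(h, 1e-2), w_pos * h,
                    w_vel * h, w_vel * h, np.full_like(h, 1e-5), w_vel * h], axis=1)
    Q = np.stack([np.diag(s) for s in np.square(std)])  # per-track process noise (assumed form)
    means = means @ F.T
    covariances = F @ np.asarray(covariances, dtype=np.float64) @ F.T + Q
    return means, covariances


means = [[10, 20, 0.5, 40, 1, 2, 0, 3], [0, 0, 1, 10, 1, 1, 0, 5]]
pred_means, pred_covs = batch_predict(means, np.stack([np.eye(8), np.eye(8)]), tracked=[True, False])
print(pred_means[:, :4])
```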


method ultralytics.trackers.byte_tracker.STrack.predict

def predict(self)

Predict the next state (mean and covariance) of the object using the Kalman filter.

Source code in ultralytics/trackers/byte_tracker.py
def predict(self):
    """Predict the next state (mean and covariance) of the object using the Kalman filter."""
    mean_state = self.mean.copy()
    if self.state != TrackState.Tracked:
        mean_state[7] = 0
    self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)


method ultralytics.trackers.byte_tracker.STrack.re_activate

def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False)

Reactivate a previously lost track using new detection data and update its state and attributes.

Args

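new_track (STrack, required): Detection-based track providing the new box, score, class, angle, and index.
frame_id (int, required): Current frame ID.
new_id (bool, default: False): If True, assign the reactivated track a new track ID.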
Source code in ultralytics/trackers/byte_tracker.py
def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False):
    """Reactivate a previously lost track using new detection data and update its state and attributes."""
    self.mean, self.covariance = self.kalman_filter.update(
        self.mean, self.covariance, self.convert_coords(new_track.tlwh)
    )
    self.tracklet_len = 0
    self.state = TrackState.Tracked
    self.is_activated = True
    self.frame_id = frame_id
    if new_id:
        self.track_id = self.next_id()
    self.score = new_track.score
    self.cls = new_track.cls
    self.angle = new_track.angle
    self.idx = new_track.idx


method ultralytics.trackers.byte_tracker.STrack.tlwh_to_xyah

def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray

Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.

Args

tlwh (np.ndarray, required): Bounding box in (top-left x, top-left y, width, height) format.
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
    """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
    ret = np.asarray(tlwh).copy()
    ret[:2] += ret[2:] / 2
    ret[2] /= ret[3]
    return ret
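
`tlwh_to_xyah` is the inverse direction of the `tlwh` property: it produces the (center x, center y, aspect ratio, height) measurement that the Kalman filter consumes. The sketch casts to float64 first; the original relies on its input already being floating point (STrack stores `_tlwh` as float32), since its in-place updates would raise an error on an integer array.

```python
import numpy as np


def tlwh_to_xyah(tlwh):
    ret = np.asarray(tlwh, dtype=np.float64).copy()
    ret[:2] += ret[2:] / 2  # top-left corner -> center
    ret[2] /= ret[3]  # width -> aspect ratio w / h
    return ret


print(tlwh_to_xyah([100, 200, 50, 80]).tolist())  # [125.0, 240.0, 0.625, 80.0]
```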


method ultralytics.trackers.byte_tracker.STrack.update

def update(self, new_track: STrack, frame_id: int)

Update the state of a matched track.

Args

new_track (STrack, required): The new track containing updated information.
frame_id (int, required): The ID of the current frame.

Examples

Update the state of a track with new detection information
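>>> from ultralytics.trackers.byte_tracker import STrack
>>> from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH
>>> track = STrack([100, 200, 50, 80, 0], score=0.9, cls=0)
>>> track.activate(KalmanFilterXYAH(), frame_id=1)
>>> new_track = STrack([105, 205, 55, 85, 0], score=0.95, cls=0)
>>> track.update(new_track, 2)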
Source code in ultralytics/trackers/byte_tracker.py
def update(self, new_track: STrack, frame_id: int):
    """Update the state of a matched track.

    Args:
        new_track (STrack): The new track containing updated information.
        frame_id (int): The ID of the current frame.

    Examples:
        Update the state of a track with new detection information
        >>> track = STrack([100, 200, 50, 80, 0], score=0.9, cls=0)
        >>> track.activate(KalmanFilterXYAH(), frame_id=1)
        >>> new_track = STrack([105, 205, 55, 85, 0], score=0.95, cls=0)
        >>> track.update(new_track, 2)
    """
    self.frame_id = frame_id
    self.tracklet_len += 1

    new_tlwh = new_track.tlwh
    self.mean, self.covariance = self.kalman_filter.update(
        self.mean, self.covariance, self.convert_coords(new_tlwh)
    )
    self.state = TrackState.Tracked
    self.is_activated = True

    self.score = new_track.score
    self.cls = new_track.cls
    self.angle = new_track.angle
    self.idx = new_track.idx
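
Both `update` and `re_activate` correct the predicted state with the matched detection, converted to (x, y, a, h). The sketch below is a textbook Kalman measurement update for that layout; the measurement-noise form (standard deviations proportional to box height with an assumed weight of 1/20, and 0.1 for the aspect ratio) is a DeepSORT-style assumption, not copied from `KalmanFilterXYAH`.

```python
import numpy as np


def kalman_update(mean, covariance, measurement, w_pos=1 / 20):
    """Correct an 8-dim (x, y, a, h, vx, vy, va, vh) state with an (x, y, a, h) measurement."""
    mean = np.asarray(mean, dtype=np.float64)
    H = np.hstack([np.eye(4), np.zeros((4, 4))])  # only the position part is observed
    h = mean[3]
    R = np.diag(np.square([w_pos * h, w_pos * h, 1e-1, w_pos * h]))  # assumed measurement noise
    S = H @ covariance @ H.T + R  # innovation covariance
    K = np.linalg.solve(S, H @ covariance).T  # Kalman gain P H^T S^-1 (P and S are symmetric)
    innovation = np.asarray(measurement, dtype=np.float64) - H @ mean
    new_mean = mean + K @ innovation
    new_cov = covariance - K @ S @ K.T
    return new_mean, new_cov


mean = np.array([125.0, 240.0, 0.625, 80.0, 0, 0, 0, 0])
new_mean, new_cov = kalman_update(mean, 10 * np.eye(8), [130.0, 245.0, 0.625, 80.0])
print(new_mean[:4])  # moved part of the way toward the measurement
```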





class ultralytics.trackers.byte_tracker.BYTETracker

BYTETracker(self, args, frame_rate: int = 30)

BYTETracker: An implementation of the ByteTrack multi-object tracking algorithm, used with Ultralytics YOLO models for object detection and tracking.

This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting the new object locations, and performs data association.
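
Data association in `update` reduces to a gated minimum-cost assignment between tracks (rows) and detections (columns), as in `matching.linear_assignment(dists, thresh=self.args.match_thresh)`. The sketch below solves it exactly by brute force over permutations, which is only practical for a handful of boxes; the Ultralytics helper relies on a dedicated solver, and the post-hoc gating used here is an assumption that may differ from it in edge cases.

```python
from itertools import permutations


def linear_assignment(cost, thresh):
    """Exact minimum-cost assignment by brute force, then drop pairs costing more than thresh.

    Returns (matches, unmatched_rows, unmatched_cols). Only practical for tiny matrices.
    """
    n_rows = len(cost)
    n_cols = len(cost[0]) if n_rows else 0
    best_total, best_pairs = float("inf"), []
    if n_rows and n_cols:
        if n_rows <= n_cols:
            # Try every way of giving each row a distinct column
            candidates = (list(zip(range(n_rows), cols)) for cols in permutations(range(n_cols), n_rows))
        else:
            # Try every way of giving each column a distinct row
            candidates = (list(zip(rows, range(n_cols))) for rows in permutations(range(n_rows), n_cols))
        for pairs in candidates:
            total = sum(cost[r][c] for r, c in pairs)
            if total < best_total:
                best_total, best_pairs = total, pairs
    # Gate: an assigned pair only counts as a match if its own cost is within the threshold
    matches = sorted((r, c) for r, c in best_pairs if cost[r][c] <= thresh)
    matched_rows = {r for r, _ in matches}
    matched_cols = {c for _, c in matches}
    unmatched_rows = [r for r in range(n_rows) if r not in matched_rows]
    unmatched_cols = [c for c in range(n_cols) if c not in matched_cols]
    return matches, unmatched_rows, unmatched_cols


print(linear_assignment([[0.1, 0.9], [0.8, 0.7]], thresh=0.5))  # ([(0, 0)], [1], [1])
```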

Args

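args (Namespace, required): Command-line arguments containing tracking parameters (read in the code below: track_high_thresh, track_low_thresh, new_track_thresh, track_buffer, match_thresh, and fuse_score).
frame_rate (int, default: 30): Frame rate of the video sequence.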

Attributes

tracked_stracks (list[STrack]): Tracks in the Tracked state, including unconfirmed tracks that have not been activated yet.
lost_stracks (list[STrack]): List of lost tracks.
removed_stracks (list[STrack]): List of removed tracks.
frame_id (int): The current frame ID.
args (Namespace): Command-line arguments.
max_time_lost (int): Maximum number of frames a lost track is kept before it is removed, computed as int(frame_rate / 30 * args.track_buffer).
kalman_filter (KalmanFilterXYAH): Kalman filter object.

Methods

get_dists: Calculate the distance between tracks and detections using IoU and optionally fuse scores.
get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.
init_track: Create STrack objects from the given detections, using their boxes, confidence scores, and class labels.
joint_stracks: Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.
multi_predict: Predict the next states for multiple tracks using the Kalman filter.
remove_duplicate_stracks: Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance.
reset: Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.
reset_id: Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions.
sub_stracks: Filter out the stracks present in the second list from the first list.
update: Update the tracker with new detections and return the current list of tracked objects.

Examples

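Initialize BYTETracker and update with detection results
>>> from types import SimpleNamespace
>>> from ultralytics import YOLO
>>> from ultralytics.trackers.byte_tracker import BYTETracker
>>> args = SimpleNamespace(track_high_thresh=0.25, track_low_thresh=0.1, new_track_thresh=0.25, track_buffer=30, match_thresh=0.8, fuse_score=True)
>>> tracker = BYTETracker(args, frame_rate=30)
>>> result = YOLO("yolo11n.pt")("image.jpg")[0]
>>> tracked_objects = tracker.update(result.boxes.cpu().numpy(), result.orig_img)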
Source code in ultralytics/trackers/byte_tracker.py
class BYTETracker:
    """BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.

    This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
    in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
    filtering for predicting the new object locations, and performs data association.

    Attributes:
        tracked_stracks (list[STrack]): List of successfully activated tracks.
        lost_stracks (list[STrack]): List of lost tracks.
        removed_stracks (list[STrack]): List of removed tracks.
        frame_id (int): The current frame ID.
        args (Namespace): Command-line arguments.
        max_time_lost (int): Maximum number of frames a lost track is kept before it is removed.
        kalman_filter (KalmanFilterXYAH): Kalman Filter object.

    Methods:
        update: Update object tracker with new detections.
        get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
        init_track: Initialize object tracking with detections.
        get_dists: Calculate the distance between tracks and detections.
        multi_predict: Predict the location of tracks.
        reset_id: Reset the ID counter of STrack.
        reset: Reset the tracker by clearing all tracks.
        joint_stracks: Combine two lists of stracks.
        sub_stracks: Filter out the stracks present in the second list from the first list.
        remove_duplicate_stracks: Remove duplicate stracks based on IoU.

    Examples:
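        Initialize BYTETracker and update with detection results
        >>> from types import SimpleNamespace
        >>> from ultralytics import YOLO
        >>> args = SimpleNamespace(track_high_thresh=0.25, track_low_thresh=0.1, new_track_thresh=0.25,
        ...                        track_buffer=30, match_thresh=0.8, fuse_score=True)
        >>> tracker = BYTETracker(args, frame_rate=30)
        >>> result = YOLO("yolo11n.pt")("image.jpg")[0]
        >>> tracked_objects = tracker.update(result.boxes.cpu().numpy(), result.orig_img)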
    """

    def __init__(self, args, frame_rate: int = 30):
        """Initialize a BYTETracker instance for object tracking.

        Args:
            args (Namespace): Command-line arguments containing tracking parameters.
            frame_rate (int): Frame rate of the video sequence.
        """
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.args = args
        self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)
        self.kalman_filter = self.get_kalmanfilter()
        self.reset_id()
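
`track_buffer` is expressed in frames at a 30 FPS reference and rescaled to the actual frame rate. In step 5 of `update`, a lost track is removed once `frame_id - end_frame` exceeds this value. A one-line sketch of the scaling (the function name is illustrative):

```python
def max_time_lost(frame_rate, track_buffer):
    """Frames a lost track is kept before removal, scaled from a 30 FPS reference."""
    return int(frame_rate / 30.0 * track_buffer)


print(max_time_lost(30, 30), max_time_lost(60, 30), max_time_lost(15, 30))  # 30 60 15
```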


method ultralytics.trackers.byte_tracker.BYTETracker.get_dists

def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray

Calculate the distance between tracks and detections using IoU and optionally fuse scores.

Args

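tracks (list[STrack], required): Existing tracks; they form the rows of the distance matrix.
detections (list[STrack], required): Current detections wrapped as STrack objects; they form the columns.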
Source code in ultralytics/trackers/byte_tracker.py
def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray:
    """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
    dists = matching.iou_distance(tracks, detections)
    if self.args.fuse_score:
        dists = matching.fuse_score(dists, detections)
    return dists
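
`get_dists` turns IoU into a cost (1 - IoU) and, when `args.fuse_score` is set, folds in detection confidence. The sketch assumes axis-aligned (x1, y1, x2, y2) boxes and that fusion multiplies the IoU similarity by the detection score before converting back to a distance; treat both functions as illustrations rather than the `matching` module itself.

```python
import numpy as np


def iou_distance(boxes_a, boxes_b):
    """1 - IoU between two sets of (x1, y1, x2, y2) boxes; result has shape (len(a), len(b))."""
    a = np.asarray(boxes_a, dtype=np.float64).reshape(-1, 4)
    b = np.asarray(boxes_b, dtype=np.float64).reshape(-1, 4)
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return 1.0 - inter / np.maximum(union, 1e-9)


def fuse_score(dists, det_scores):
    """Scale IoU similarity by detection confidence (per column), then convert back to a distance."""
    return 1.0 - (1.0 - dists) * np.asarray(det_scores)[None, :]


d = iou_distance([[0, 0, 2, 2]], [[0, 0, 2, 2], [1, 0, 3, 2], [5, 5, 6, 6]])
print(d)  # identical box, half overlap, disjoint box
print(fuse_score(d, [0.9, 0.9, 0.9]))
```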


method ultralytics.trackers.byte_tracker.BYTETracker.get_kalmanfilter

def get_kalmanfilter(self) -> KalmanFilterXYAH

Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.

Source code in ultralytics/trackers/byte_tracker.py
def get_kalmanfilter(self) -> KalmanFilterXYAH:
    """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
    return KalmanFilterXYAH()


method ultralytics.trackers.byte_tracker.BYTETracker.init_track

def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]

Create STrack objects from the given detections, using their boxes, confidence scores, and class labels.

Args

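results (required): Detection results exposing conf, cls, and xywh (or xywhr for oriented boxes), and supporting len().
img (np.ndarray | None, default: None): Image or feature array; not used by this base implementation.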
Source code in ultralytics/trackers/byte_tracker.py
def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]:
    """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
    if len(results) == 0:
        return []
    bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
    bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
    return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
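
`init_track` appends each detection's position as an extra column, so the last entry of every `STrack` input (and of each returned `result` row) identifies the source detection within the `results` passed in. A sketch of just that step, with an illustrative function name:

```python
import numpy as np


def add_index_column(bboxes):
    """Append each detection's row index as a last column, as init_track does before building STracks."""
    bboxes = np.asarray(bboxes, dtype=np.float32)
    return np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)


rows = add_index_column([[100, 200, 50, 80], [300, 120, 40, 60]])
print(rows.tolist())  # [[100.0, 200.0, 50.0, 80.0, 0.0], [300.0, 120.0, 40.0, 60.0, 1.0]]
```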


method ultralytics.trackers.byte_tracker.BYTETracker.joint_stracks

def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]

Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.

Args

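tlista (list[STrack], required): First list; all of its tracks are kept.
tlistb (list[STrack], required): Second list; only tracks whose track_id is not already present are appended.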
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
    """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


method ultralytics.trackers.byte_tracker.BYTETracker.multi_predict

def multi_predict(self, tracks: list[STrack])

Predict the next states for multiple tracks using the Kalman filter.

Args

tracks (list[STrack], required): Tracks whose states are predicted in place.
Source code in ultralytics/trackers/byte_tracker.py
def multi_predict(self, tracks: list[STrack]):
    """Predict the next states for multiple tracks using Kalman filter."""
    STrack.multi_predict(tracks)


method ultralytics.trackers.byte_tracker.BYTETracker.remove_duplicate_stracks

def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]

Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance.

Args

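stracksa (list[STrack], required): First list of tracks (the tracked tracks when called from update).
stracksb (list[STrack], required): Second list of tracks (the lost tracks when called from update).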
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]:
    """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = [], []
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb
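
`remove_duplicate_stracks` treats a pair as duplicates when their IoU distance is below 0.15 (IoU above 0.85) and drops the member that has existed for fewer frames; on a tie the entry from the first list is dropped. The sketch below applies the same rule to a precomputed distance matrix and returns the indices to keep; it uses sets for the membership tests, whereas the source scans lists. The function name is illustrative.

```python
import numpy as np


def remove_duplicates(pdist, ages_a, ages_b, thresh=0.15):
    """Return (keep_a, keep_b) index lists; ages are frame_id - start_frame for each track."""
    rows, cols = np.where(np.asarray(pdist) < thresh)  # near-identical boxes
    drop_a, drop_b = set(), set()
    for p, q in zip(rows.tolist(), cols.tolist()):
        if ages_a[p] > ages_b[q]:
            drop_b.add(q)  # the track in list a is older, so keep it
        else:
            drop_a.add(p)  # ties and older b tracks: drop from list a
    keep_a = [i for i in range(len(ages_a)) if i not in drop_a]
    keep_b = [i for i in range(len(ages_b)) if i not in drop_b]
    return keep_a, keep_b


print(remove_duplicates([[0.05, 0.9], [0.9, 0.9]], ages_a=[10, 3], ages_b=[2, 5]))  # ([0, 1], [1])
```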


method ultralytics.trackers.byte_tracker.BYTETracker.reset

def reset(self)

Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.

Source code in ultralytics/trackers/byte_tracker.py
def reset(self):
    """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
    self.tracked_stracks = []  # type: list[STrack]
    self.lost_stracks = []  # type: list[STrack]
    self.removed_stracks = []  # type: list[STrack]
    self.frame_id = 0
    self.kalman_filter = self.get_kalmanfilter()
    self.reset_id()


method ultralytics.trackers.byte_tracker.BYTETracker.reset_id

def reset_id()

Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions.

Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def reset_id():
    """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
    STrack.reset_id()


method ultralytics.trackers.byte_tracker.BYTETracker.sub_stracks

def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]

Filter out the stracks present in the second list from the first list.

Args

tlista (list[STrack], required): List to filter.
tlistb (list[STrack], required): Tracks whose track IDs are removed from tlista.
Source code in ultralytics/trackers/byte_tracker.py
@staticmethod
def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
    """Filter out the stracks present in the second list from the first list."""
    track_ids_b = {t.track_id for t in tlistb}
    return [t for t in tlista if t.track_id not in track_ids_b]
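
`joint_stracks` and `sub_stracks` are order-preserving set operations keyed on `track_id`. A sketch with a hypothetical `Track` stand-in that carries only an ID:

```python
from dataclasses import dataclass


@dataclass
class Track:  # hypothetical stand-in for STrack; only track_id matters here
    track_id: int


def joint_stracks(a, b):
    """All of a, then the items of b whose track_id is not already present."""
    seen = {t.track_id for t in a}
    res = list(a)
    for t in b:
        if t.track_id not in seen:
            seen.add(t.track_id)
            res.append(t)
    return res


def sub_stracks(a, b):
    """Items of a whose track_id does not appear in b."""
    ids_b = {t.track_id for t in b}
    return [t for t in a if t.track_id not in ids_b]


a, b = [Track(1), Track(2)], [Track(2), Track(3)]
print([t.track_id for t in joint_stracks(a, b)], [t.track_id for t in sub_stracks(a, b)])  # [1, 2, 3] [1]
```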


method ultralytics.trackers.byte_tracker.BYTETracker.update

def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray

Update the tracker with new detections and return the current list of tracked objects.

Args

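results (required): Detection results for the current frame, exposing conf, cls, xywh (or xywhr), and xyxy, and supporting boolean-mask indexing.
img (np.ndarray | None, default: None): Current frame, used for global motion compensation when the tracker has a gmc attribute.
feats (np.ndarray | None, default: None): Optional per-detection features; if provided and non-empty, they are split with the same score masks as the detections.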
Source code in ultralytics/trackers/byte_tracker.py
def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray:
    """Update the tracker with new detections and return the current list of tracked objects."""
    self.frame_id += 1
    activated_stracks = []
    refind_stracks = []
    lost_stracks = []
    removed_stracks = []

    scores = results.conf
    remain_inds = scores >= self.args.track_high_thresh
    inds_low = scores > self.args.track_low_thresh
    inds_high = scores < self.args.track_high_thresh

    inds_second = inds_low & inds_high
    results_second = results[inds_second]
    results = results[remain_inds]
    feats_keep = feats_second = img
    if feats is not None and len(feats):
        feats_keep = feats[remain_inds]
        feats_second = feats[inds_second]

    detections = self.init_track(results, feats_keep)
    # Split existing tracks into unconfirmed (not yet activated) and confirmed tracks
    unconfirmed = []
    tracked_stracks = []  # type: list[STrack]
    for track in self.tracked_stracks:
        if not track.is_activated:
            unconfirmed.append(track)
        else:
            tracked_stracks.append(track)
    # Step 2: First association, with high score detection boxes
    strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
    # Predict the current location with KF
    self.multi_predict(strack_pool)
    if hasattr(self, "gmc") and img is not None:
        # use try-except here to bypass errors from gmc module
        try:
            warp = self.gmc.apply(img, results.xyxy)
        except Exception:
            warp = np.eye(2, 3)
        STrack.multi_gmc(strack_pool, warp)
        STrack.multi_gmc(unconfirmed, warp)

    dists = self.get_dists(strack_pool, detections)
    matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)

    for itracked, idet in matches:
        track = strack_pool[itracked]
        det = detections[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_stracks.append(track)
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)
    # Step 3: Second association, matching still-unmatched tracked tracks to the low-score detections
    detections_second = self.init_track(results_second, feats_second)
    r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
    # TODO
    dists = matching.iou_distance(r_tracked_stracks, detections_second)
    matches, u_track, _u_detection_second = matching.linear_assignment(dists, thresh=0.5)
    for itracked, idet in matches:
        track = r_tracked_stracks[itracked]
        det = detections_second[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_stracks.append(track)
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    for it in u_track:
        track = r_tracked_stracks[it]
        if track.state != TrackState.Lost:
            track.mark_lost()
            lost_stracks.append(track)
    # Deal with unconfirmed tracks, usually tracks with only one beginning frame
    detections = [detections[i] for i in u_detection]
    dists = self.get_dists(unconfirmed, detections)
    matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
    for itracked, idet in matches:
        unconfirmed[itracked].update(detections[idet], self.frame_id)
        activated_stracks.append(unconfirmed[itracked])
    for it in u_unconfirmed:
        track = unconfirmed[it]
        track.mark_removed()
        removed_stracks.append(track)
    # Step 4: Init new stracks
    for inew in u_detection:
        track = detections[inew]
        if track.score < self.args.new_track_thresh:
            continue
        track.activate(self.kalman_filter, self.frame_id)
        activated_stracks.append(track)
    # Step 5: Update state
    for track in self.lost_stracks:
        if self.frame_id - track.end_frame > self.max_time_lost:
            track.mark_removed()
            removed_stracks.append(track)

    self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
    self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
    self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
    self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
    self.lost_stracks.extend(lost_stracks)
    self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
    self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
    self.removed_stracks.extend(removed_stracks)
    if len(self.removed_stracks) > 1000:
        self.removed_stracks = self.removed_stracks[-999:]  # keep only the 999 most recent removed stracks

    return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
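
The core idea of `update` is the split of detections by confidence: high-score boxes take part in the first association, while low-score boxes are kept only to rescue tracked tracks that stayed unmatched (second association with a fixed IoU-distance threshold of 0.5 in the code above). Only unmatched high-score detections whose score reaches `new_track_thresh` start new tracks. A sketch of the masking step, with illustrative threshold values:

```python
import numpy as np


def split_by_score(scores, high_thresh, low_thresh):
    """Return masks for first-stage (high) and second-stage (low) detections, as in BYTETracker.update."""
    scores = np.asarray(scores, dtype=np.float64)
    first = scores >= high_thresh  # used in the first association
    second = (scores > low_thresh) & (scores < high_thresh)  # only used to rescue unmatched tracks
    return first, second


first, second = split_by_score([0.9, 0.5, 0.2, 0.05], high_thresh=0.5, low_thresh=0.1)
print(first.tolist(), second.tolist())  # [True, True, False, False] [False, False, True, False]
```

Scores at or below `track_low_thresh` fall into neither mask and are discarded for the frame.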





📅 Created 2 years ago ✏️ Updated 2 days ago
glenn-jocher, jk4e, Burhan-Q