跳至内容

参考资料 ultralytics/solutions/distance_calculation.py

备注

该文件可在 https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/distance_calculation.py 找到。如果您发现问题,请通过提交 Pull Request 🛠️ 帮助修复。谢谢 🙏!



ultralytics.solutions.distance_calculation.DistanceCalculation

根据实时视频流中两个物体的轨迹计算它们之间距离的类。

源代码 ultralytics/solutions/distance_calculation.py
class DistanceCalculation:
    """
    Calculate the distance between two tracked objects in a real-time video stream.

    Two bounding boxes are picked with left mouse clicks in the display window; on every processed frame the
    centroids of the picked boxes are computed and the distance between them is drawn. A right click clears
    the selection.
    """

    def __init__(self):
        """Initializes the distance calculation class with default values for Visual, Image, track and distance
        parameters.
        """

        # Visual & im0 information
        self.im0 = None
        self.annotator = None
        self.view_img = False
        self.line_color = (255, 255, 0)
        self.centroid_color = (255, 0, 255)

        # Predict/track information
        self.clss = None
        self.names = None
        self.boxes = None
        self.line_thickness = 2
        self.trk_ids = None

        # Distance calculation information
        self.centroids = []
        self.pixel_per_meter = 10

        # Mouse event
        self.left_mouse_count = 0
        self.selected_boxes = {}  # track id -> latest bounding box of that track

        # Check if environment support imshow
        self.env_check = check_imshow(warn=True)

    def set_args(
        self,
        names,
        pixels_per_meter=10,
        view_img=False,
        line_thickness=2,
        line_color=(255, 255, 0),
        centroid_color=(255, 0, 255),
    ):
        """
        Configures the distance calculation and display parameters.

        Args:
            names (dict): object detection classes names
            pixels_per_meter (int): Number of pixels in meter
            view_img (bool): Flag indicating frame display
            line_thickness (int): Line thickness for bounding boxes.
            line_color (RGB): color of centroids line
            centroid_color (RGB): colors of bbox centroids
        """
        self.names = names
        self.pixel_per_meter = pixels_per_meter
        self.view_img = view_img
        self.line_thickness = line_thickness
        self.line_color = line_color
        self.centroid_color = centroid_color

    def mouse_event_for_distance(self, event, x, y, flags, param):
        """
        Select the bounding boxes to measure with mouse clicks on the display window.

        A left click stores every tracked box that contains the pointer (at most two boxes in total); a right
        click clears the selection and the click counter.

        Args:
            event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
            x (int): The x-coordinate of the mouse pointer.
            y (int): The y-coordinate of the mouse pointer.
            flags (int): Any flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY,
                cv2.EVENT_FLAG_SHIFTKEY, etc.). Unused.
            param (dict): Additional parameters you may want to pass to the function. Unused.
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            self.left_mouse_count += 1
            # boxes/trk_ids stay None until the first frame with track ids has been processed
            tracks_ready = self.boxes is not None and self.trk_ids is not None
            if self.left_mouse_count <= 2 and tracks_ready:
                for box, track_id in zip(self.boxes, self.trk_ids):
                    # The distance is only drawn for exactly two boxes, so never store a third one
                    # (a single click on overlapping boxes can match several tracks).
                    if len(self.selected_boxes) >= 2:
                        break
                    if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
                        self.selected_boxes[track_id] = box
        elif event == cv2.EVENT_RBUTTONDOWN:
            self.selected_boxes = {}
            self.left_mouse_count = 0

    def extract_tracks(self, tracks):
        """
        Extracts boxes, classes and track ids from the first element of the tracking results.

        Args:
            tracks (list): List of tracks obtained from the object tracking process.
        """
        self.boxes = tracks[0].boxes.xyxy.cpu()
        self.clss = tracks[0].boxes.cls.cpu().tolist()
        self.trk_ids = tracks[0].boxes.id.int().cpu().tolist()

    def calculate_centroid(self, box):
        """
        Calculate the centroid of bounding box.

        Args:
            box (list): Bounding box data, indexed as (x1, y1, x2, y2).

        Returns:
            (tuple): Integer (x, y) midpoint of the box.
        """
        return int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)

    def calculate_distance(self, centroid1, centroid2):
        """
        Calculate distance between two centroids.

        Args:
            centroid1 (point): (x, y) of the first centroid
            centroid2 (point): (x, y) of the second centroid

        Returns:
            (tuple): The pixel distance divided by `pixel_per_meter`, and that same value multiplied by 1000.
        """
        pixel_distance = math.sqrt((centroid1[0] - centroid2[0]) ** 2 + (centroid1[1] - centroid2[1]) ** 2)
        return pixel_distance / self.pixel_per_meter, (pixel_distance / self.pixel_per_meter) * 1000

    def start_process(self, im0, tracks):
        """
        Calculate distance between two bounding boxes based on tracking data.

        Args:
            im0 (nd array): Image
            tracks (list): List of tracks obtained from the object tracking process.

        Returns:
            (nd array): The `im0` frame that was passed in, also when the frame carries no track ids.
        """
        self.im0 = im0
        if tracks[0].boxes.id is None:
            # Nothing to measure on this frame; still show it (if possible) and hand the frame back
            if self.view_img and self.env_check:
                self.display_frames()
            return im0
        self.extract_tracks(tracks)

        self.annotator = Annotator(self.im0, line_width=self.line_thickness)

        for box, cls, track_id in zip(self.boxes, self.clss, self.trk_ids):
            self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)])

            # Once two boxes are picked, keep their stored coordinates in sync with the current frame
            if len(self.selected_boxes) == 2 and track_id in self.selected_boxes:
                self.selected_boxes[track_id] = box

        if len(self.selected_boxes) == 2:
            self.centroids = [self.calculate_centroid(box) for box in self.selected_boxes.values()]

            distance_m, distance_mm = self.calculate_distance(self.centroids[0], self.centroids[1])
            self.annotator.plot_distance_and_line(
                distance_m, distance_mm, self.centroids, self.line_color, self.centroid_color
            )

        # Centroids are per-frame scratch data
        self.centroids = []

        if self.view_img and self.env_check:
            self.display_frames()

        return im0

    def display_frames(self):
        """Display frame in the "Ultralytics Distance Estimation" window and register the mouse callback."""
        cv2.namedWindow("Ultralytics Distance Estimation")
        cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance)
        cv2.imshow("Ultralytics Distance Estimation", self.im0)

        # The "q" branch only returns early; no statement follows it
        if cv2.waitKey(1) & 0xFF == ord("q"):
            return

__init__()

使用视觉、图像、轨迹和距离参数的默认值初始化距离计算类。

源代码 ultralytics/solutions/distance_calculation.py
def __init__(self):
    """Set every display, tracking, distance and mouse-selection attribute to its default value."""

    # Frame currently being processed and the helper used to draw on it
    self.im0 = None
    self.annotator = None

    # Display options
    self.view_img = False
    self.line_thickness = 2
    self.line_color, self.centroid_color = (255, 255, 0), (255, 0, 255)

    # Results of the latest tracking step
    self.boxes = None
    self.clss = None
    self.trk_ids = None
    self.names = None

    # Distance calculation state
    self.pixel_per_meter = 10
    self.centroids = []

    # Mouse selection state
    self.selected_boxes = {}
    self.left_mouse_count = 0

    # Result of the imshow support check for this environment
    self.env_check = check_imshow(warn=True)

calculate_centroid(box)

计算包围盒的中心点。

参数

名称 类型 说明 默认值
box list

边界框数据

必填
源代码 ultralytics/solutions/distance_calculation.py
def calculate_centroid(self, box):
    """
    Return the integer midpoint of a bounding box.

    Args:
        box (list): Bounding box data, indexed as (x1, y1, x2, y2).

    Returns:
        (tuple): (x, y) midpoint, each floor-divided by two and cast to int.
    """
    center_x = (box[0] + box[2]) // 2
    center_y = (box[1] + box[3]) // 2
    return int(center_x), int(center_y)

calculate_distance(centroid1, centroid2)

计算两个中心点之间的距离。

参数

名称 类型 说明 默认值
centroid1 point

第一个边界框数据

必填
centroid2 point

第二个边界框数据

必填
源代码 ultralytics/solutions/distance_calculation.py
def calculate_distance(self, centroid1, centroid2):
    """
    Measure the straight-line distance between two centroids.

    Args:
        centroid1 (point): (x, y) of the first centroid
        centroid2 (point): (x, y) of the second centroid

    Returns:
        (tuple): The pixel distance divided by `pixel_per_meter`, and that same value multiplied by 1000.
    """
    delta_x = centroid1[0] - centroid2[0]
    delta_y = centroid1[1] - centroid2[1]
    pixels = math.sqrt(delta_x ** 2 + delta_y ** 2)

    scaled = pixels / self.pixel_per_meter
    return scaled, scaled * 1000

display_frames()

显示帧。

源代码 ultralytics/solutions/distance_calculation.py
def display_frames(self):
    """Show the current frame in the display window and hook up the mouse callback."""
    window_name = "Ultralytics Distance Estimation"

    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, self.mouse_event_for_distance)
    cv2.imshow(window_name, self.im0)

    # The "q" branch only returns early; no statement follows it
    pressed = cv2.waitKey(1) & 0xFF
    if pressed == ord("q"):
        return

extract_tracks(tracks)

从提供的数据中提取结果。

参数

名称 类型 说明 默认值
tracks list

物体追踪过程中获得的轨迹列表。

必填
源代码 ultralytics/solutions/distance_calculation.py
def extract_tracks(self, tracks):
    """
    Pull boxes, classes and track ids out of the first element of the tracking results.

    Args:
        tracks (list): List of tracks obtained from the object tracking process.
    """
    detections = tracks[0].boxes

    self.boxes = detections.xyxy.cpu()
    self.clss = detections.cls.cpu().tolist()
    # Track ids are cast to integers before being converted to a list
    self.trk_ids = detections.id.int().cpu().tolist()

mouse_event_for_distance(event, x, y, flags, param)

该函数用于在实时视频流中通过鼠标事件选择要测量距离的边界框:左键点击选中指针所在的边界框,右键点击清空已选中的边界框。

参数

名称 类型 说明 默认值
event int

鼠标事件的类型(如 cv2.EVENT_MOUSEMOVE、cv2.EVENT_LBUTTONDOWN 等)。

必填
x int

鼠标指针的 x 坐标。

必填
y int

鼠标指针的 y 坐标。

必填
flags int

与事件相关的任何标记(例如,cv2.EVENT_FLAG_CTRLKEY、 cv2.EVENT_FLAG_SHIFTKEY 等)。

必填
param dict

您可能希望传递给函数的其他参数。

必填
源代码 ultralytics/solutions/distance_calculation.py
def mouse_event_for_distance(self, event, x, y, flags, param):
    """
    Select the bounding boxes to measure with mouse clicks on the display window.

    A left click stores every tracked box that contains the pointer (at most two boxes in total); a right
    click clears the selection and the click counter.

    Args:
        event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
        x (int): The x-coordinate of the mouse pointer.
        y (int): The y-coordinate of the mouse pointer.
        flags (int): Any flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY,
            cv2.EVENT_FLAG_SHIFTKEY, etc.). Unused.
        param (dict): Additional parameters you may want to pass to the function. Unused.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        self.left_mouse_count += 1
        # boxes/trk_ids stay None until the first frame with track ids has been processed
        tracks_ready = self.boxes is not None and self.trk_ids is not None
        if self.left_mouse_count <= 2 and tracks_ready:
            for box, track_id in zip(self.boxes, self.trk_ids):
                # The distance is only drawn for exactly two boxes, so never store a third one
                # (a single click on overlapping boxes can match several tracks).
                if len(self.selected_boxes) >= 2:
                    break
                if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
                    self.selected_boxes[track_id] = box
    elif event == cv2.EVENT_RBUTTONDOWN:
        self.selected_boxes = {}
        self.left_mouse_count = 0

set_args(names, pixels_per_meter=10, view_img=False, line_thickness=2, line_color=(255, 255, 0), centroid_color=(255, 0, 255))

配置距离计算和显示参数。

参数

名称 类型 说明 默认值
names dict

对象检测类名称

必填
pixels_per_meter int

每米对应的像素数

10
view_img bool

显示帧的标志

False
line_thickness int

边界框的线条粗细

2
line_color RGB

中心点连线的颜色

(255, 255, 0)
centroid_color RGB

方框中心点的颜色

(255, 0, 255)
源代码 ultralytics/solutions/distance_calculation.py
def set_args(
    self,
    names,
    pixels_per_meter=10,
    view_img=False,
    line_thickness=2,
    line_color=(255, 255, 0),
    centroid_color=(255, 0, 255),
):
    """
    Store the distance calculation and display settings on the instance.

    Args:
        names (dict): object detection classes names
        pixels_per_meter (int): Number of pixels in meter
        view_img (bool): Flag indicating frame display
        line_thickness (int): Line thickness for bounding boxes.
        line_color (RGB): color of centroids line
        centroid_color (RGB): colors of bbox centroids
    """
    # Class names and scale; note the attribute is spelled `pixel_per_meter` (no "s")
    self.names, self.pixel_per_meter = names, pixels_per_meter

    # Display options
    self.view_img, self.line_thickness = view_img, line_thickness
    self.line_color, self.centroid_color = line_color, centroid_color

start_process(im0, tracks)

根据跟踪数据计算两个边界框之间的距离。

参数

名称 类型 说明 默认值
im0 nd array

图片

必填
tracks list

物体追踪过程中获得的轨迹列表。

必填
源代码 ultralytics/solutions/distance_calculation.py
def start_process(self, im0, tracks):
    """
    Calculate distance between two bounding boxes based on tracking data.

    Args:
        im0 (nd array): Image
        tracks (list): List of tracks obtained from the object tracking process.

    Returns:
        (nd array): The `im0` frame that was passed in, also when the frame carries no track ids.
    """
    self.im0 = im0
    if tracks[0].boxes.id is None:
        # Nothing to measure on this frame; still show it (if possible) and hand the frame back
        if self.view_img and self.env_check:
            self.display_frames()
        return im0
    self.extract_tracks(tracks)

    self.annotator = Annotator(self.im0, line_width=self.line_thickness)

    for box, cls, track_id in zip(self.boxes, self.clss, self.trk_ids):
        self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)])

        # Once two boxes are picked, keep their stored coordinates in sync with the current frame
        if len(self.selected_boxes) == 2 and track_id in self.selected_boxes:
            self.selected_boxes[track_id] = box

    if len(self.selected_boxes) == 2:
        self.centroids = [self.calculate_centroid(box) for box in self.selected_boxes.values()]

        distance_m, distance_mm = self.calculate_distance(self.centroids[0], self.centroids[1])
        self.annotator.plot_distance_and_line(
            distance_m, distance_mm, self.centroids, self.line_color, self.centroid_color
        )

    # Centroids are per-frame scratch data
    self.centroids = []

    if self.view_img and self.env_check:
        self.display_frames()

    return im0





创建于 2024-01-05,更新于 2024-01-10
作者:AyushExel(1),chr043416@gmail.com(1)