
Reference for ultralytics/solutions/streamlit_inference.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/streamlit_inference.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.solutions.streamlit_inference.Inference

Inference(**kwargs: Any)

A class to perform object detection, image classification, image segmentation and pose estimation inference.

This class provides functionalities for loading models, configuring settings, uploading video files, and performing real-time inference using Streamlit and Ultralytics YOLO models.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| st | module | Streamlit module for UI creation. |
| temp_dict | dict | Temporary dictionary to store the model path and other configuration. |
| model_path | str | Path to the loaded model. |
| model | YOLO | The YOLO model instance. |
| source | str | Selected video source (webcam or video file). |
| enable_trk | str | Enable tracking option ("Yes" or "No"). |
| conf | float | Confidence threshold for detection. |
| iou | float | IoU threshold for non-maximum suppression. |
| org_frame | Any | Container for the original frame to be displayed. |
| ann_frame | Any | Container for the annotated frame to be displayed. |
| vid_file_name | str \| int | Name of the uploaded video file or webcam index. |
| selected_ind | List[int] | List of selected class indices for detection. |

Methods:

| Name | Description |
| --- | --- |
| web_ui | Sets up the Streamlit web interface with custom HTML elements. |
| sidebar | Configures the Streamlit sidebar for model and inference settings. |
| source_upload | Handles video file uploads through the Streamlit interface. |
| configure | Configures the model and loads selected classes for inference. |
| inference | Performs real-time object detection inference. |

Examples:

>>> inf = Inference(model="path/to/model.pt")  # Model is an optional argument
>>> inf.inference()

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **kwargs | Any | Additional keyword arguments for model configuration. | {} |

Source code in ultralytics/solutions/streamlit_inference.py
def __init__(self, **kwargs: Any):
    """
    Initialize the Inference class, checking Streamlit requirements and setting up the model path.

    Args:
        **kwargs (Any): Additional keyword arguments for model configuration.
    """
    check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
    import streamlit as st

    self.st = st  # Reference to the Streamlit module
    self.source = None  # Video source selection (webcam or video file)
    self.enable_trk = False  # Flag to toggle object tracking
    self.conf = 0.25  # Confidence threshold for detection
    self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
    self.org_frame = None  # Container for the original frame display
    self.ann_frame = None  # Container for the annotated frame display
    self.vid_file_name = None  # Video file name or webcam index
    self.selected_ind = []  # List of selected class indices for detection
    self.model = None  # YOLO model instance

    self.temp_dict = {"model": None, **kwargs}
    self.model_path = None  # Model file path
    if self.temp_dict["model"] is not None:
        self.model_path = self.temp_dict["model"]

    LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

configure

configure()

Configure the model and load selected classes for inference.

Source code in ultralytics/solutions/streamlit_inference.py
def configure(self):
    """Configure the model and load selected classes for inference."""
    # Add dropdown menu for model selection
    available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
    if self.model_path:  # If user provided the custom model, insert model without suffix as *.pt is added later
        available_models.insert(0, self.model_path.split(".pt")[0])
    selected_model = self.st.sidebar.selectbox("Model", available_models)

    with self.st.spinner("Model is downloading..."):
        self.model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
        class_names = list(self.model.names.values())  # Convert dictionary to list of class names
    self.st.success("Model loaded successfully!")

    # Multiselect box with class names and get indices of selected classes
    selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
    self.selected_ind = [class_names.index(option) for option in selected_classes]

    if not isinstance(self.selected_ind, list):  # Ensure selected_ind is a list
        self.selected_ind = list(self.selected_ind)
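
The classes argument used later at predict time expects integer indices, which is why the selected names are mapped back to their positions in model.names. A small standalone illustration of that mapping, using a made-up three-class names dictionary rather than a real checkpoint:

# Hypothetical names dictionary standing in for model.names (assumption, not a real model)
names = {0: "person", 1: "bicycle", 2: "car"}

class_names = list(names.values())  # ["person", "bicycle", "car"]
selected_classes = ["person", "car"]  # what the user picked in the multiselect
selected_ind = [class_names.index(option) for option in selected_classes]
print(selected_ind)  # [0, 2] -- later passed to the model as classes=[0, 2]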

inference

inference()

Perform real-time object detection inference on video or webcam feed.

Source code in ultralytics/solutions/streamlit_inference.py
def inference(self):
    """Perform real-time object detection inference on video or webcam feed."""
    self.web_ui()  # Initialize the web interface
    self.sidebar()  # Create the sidebar
    self.source_upload()  # Upload the video source
    self.configure()  # Configure the app

    if self.st.sidebar.button("Start"):
        stop_button = self.st.button("Stop")  # Button to stop the inference
        cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
        if not cap.isOpened():
            self.st.error("Could not open webcam or video source.")
            return

        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
                break

            # Process frame with model
            if self.enable_trk == "Yes":
                results = self.model.track(
                    frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                )
            else:
                results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)

            annotated_frame = results[0].plot()  # Add annotations on frame

            if stop_button:
                cap.release()  # Release the capture
                self.st.stop()  # Stop streamlit app

            self.org_frame.image(frame, channels="BGR")  # Display original frame
            self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame

        cap.release()  # Release the capture
    cv2.destroyAllWindows()  # Destroy all OpenCV windows
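
The same capture-predict-plot loop can be reproduced outside Streamlit with plain OpenCV windows, which is handy for debugging model settings in isolation. A minimal sketch, assuming a yolo11n.pt checkpoint and webcam index 0:

import cv2

from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # assumed checkpoint
cap = cv2.VideoCapture(0)  # webcam index 0, mirroring the app's webcam source

while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = model.track(frame, conf=0.25, iou=0.45, persist=True)  # persist=True keeps track IDs across frames
    cv2.imshow("Ultralytics YOLO", results[0].plot())  # plot() draws boxes, labels, and track IDs
    if cv2.waitKey(1) & 0xFF == ord("q"):  # press q to quit
        break

cap.release()
cv2.destroyAllWindows()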

sidebar

sidebar()

Configure the Streamlit sidebar for model and inference settings.

Source code in ultralytics/solutions/streamlit_inference.py
def sidebar(self):
    """Configure the Streamlit sidebar for model and inference settings."""
    with self.st.sidebar:  # Add Ultralytics LOGO
        logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
        self.st.image(logo, width=250)

    self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
    self.source = self.st.sidebar.selectbox(
        "Video",
        ("webcam", "video"),
    )  # Add source selection dropdown
    self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
    self.conf = float(
        self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
    )  # Slider for confidence
    self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

    col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
    self.org_frame = col1.empty()  # Container for original frame
    self.ann_frame = col2.empty()  # Container for annotated frame
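
The two st.empty() placeholders created here are what make per-frame updates possible: calling .image() on a placeholder replaces its previous contents instead of appending a new element to the page. A minimal standalone sketch of that pattern, using a dummy NumPy frame instead of real video:

import numpy as np
import streamlit as st

placeholder = st.empty()  # reserves a slot in the page layout
for i in range(3):
    frame = np.full((120, 160, 3), i * 80, dtype=np.uint8)  # dummy BGR frame
    placeholder.image(frame, channels="BGR")  # each call overwrites the slot rather than stacking images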

source_upload

source_upload()

Handle video file uploads through the Streamlit interface.

Source code in ultralytics/solutions/streamlit_inference.py
def source_upload(self):
    """Handle video file uploads through the Streamlit interface."""
    self.vid_file_name = ""
    if self.source == "video":
        vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
        if vid_file is not None:
            g = io.BytesIO(vid_file.read())  # BytesIO Object
            with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                out.write(g.read())  # Read bytes into file
            self.vid_file_name = "ultralytics.mp4"
    elif self.source == "webcam":
        self.vid_file_name = 0  # Use webcam index 0
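
Note that uploaded bytes are always written to a fixed ultralytics.mp4 in the working directory, so concurrent sessions would overwrite each other's uploads. If you adapt this for multi-user use, writing each upload to its own temporary file is one alternative; a hedged sketch of that approach (not what the solution itself does):

import tempfile


def save_upload(vid_file):
    """Write a Streamlit upload to a unique temporary .mp4 file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(vid_file.read())  # vid_file is the object returned by st.file_uploader
        return tmp.name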

web_ui

web_ui()

Sets up the Streamlit web interface with custom HTML elements.

Source code in ultralytics/solutions/streamlit_inference.py
def web_ui(self):
    """Sets up the Streamlit web interface with custom HTML elements."""
    menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

    # Main title of streamlit application
    main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
    font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

    # Subtitle of streamlit application
    sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif; 
    margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power 
    of Ultralytics YOLO! 🚀</h4></div>"""

    # Set html page configuration and append custom HTML
    self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
    self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
    self.st.markdown(main_title_cfg, unsafe_allow_html=True)
    self.st.markdown(sub_title_cfg, unsafe_allow_html=True)


