Reference for ultralytics/solutions/streamlit_inference.py
Improvements
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/streamlit_inference.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
class ultralytics.solutions.streamlit_inference.Inference
Inference(self, **kwargs: Any) -> None
A class to perform object detection, image classification, image segmentation and pose estimation inference.
This class provides functionalities for loading models, configuring settings, uploading video files, and performing real-time inference using Streamlit and Ultralytics YOLO models.
Args
| Name | Type | Description | Default |
|---|---|---|---|
**kwargs | Any | Additional keyword arguments for model configuration. | optional |
Attributes
| Name | Type | Description |
|---|---|---|
st | module | Streamlit module for UI creation. |
temp_dict | dict | Temporary dictionary to store the model path and other configuration. |
model_path | str | Path to the loaded model. |
model | YOLO | The YOLO model instance. |
source | str | Selected input source (webcam, video file, or image). |
enable_trk | bool | Enable tracking option. |
conf | float | Confidence threshold for detection. |
iou | float | IoU threshold for non-maximum suppression. |
org_frame | Any | Container for the original frame to be displayed. |
ann_frame | Any | Container for the annotated frame to be displayed. |
vid_file_name | str \| int | Name of the uploaded video file or webcam index. |
selected_ind | list[int] | List of selected class indices for detection. |
Methods
| Name | Description |
|---|---|
configure | Configure the model and load selected classes for inference. |
image_inference | Perform inference on uploaded images. |
inference | Perform real-time object detection inference on video or webcam feed. |
sidebar | Configure the Streamlit sidebar for model and inference settings. |
source_upload | Handle video and image file uploads through the Streamlit interface. |
web_ui | Set up the Streamlit web interface with custom HTML elements. |
Examples
Create an Inference instance with a custom model
>>> inf = Inference(model="path/to/model.pt")
>>> inf.inference()
Create an Inference instance with default settings
>>> inf = Inference()
>>> inf.inference()
Source code in ultralytics/solutions/streamlit_inference.py
class Inference:
    """Streamlit application wrapper for interactive Ultralytics YOLO inference.

    Supports object detection, image classification, image segmentation and pose estimation models. The class loads
    a model, builds the sidebar settings, accepts uploaded video or image files (or a webcam), and displays the
    original and annotated results in the Streamlit page.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Temporary dictionary holding the model path and any other keyword configuration.
        model_path (str): Path to the model supplied by the user, or None when none was given.
        model (YOLO): The YOLO model instance.
        source (str): Selected input source (webcam, video or image).
        enable_trk (bool): Whether object tracking is enabled.
        conf (float): Confidence threshold for detection.
        iou (float): IoU threshold for non-maximum suppression.
        org_frame (Any): Container in which the original frame is displayed.
        ann_frame (Any): Container in which the annotated frame is displayed.
        vid_file_name (str | int): Name of the uploaded video file or webcam index.
        img_file_names (list[dict]): One dict per uploaded image with its temporary `path` and original `name`.
        selected_ind (list[int]): Class indices selected for detection.

    Methods:
        web_ui: Set up the Streamlit web interface with custom HTML elements.
        sidebar: Configure the Streamlit sidebar for model and inference settings.
        source_upload: Handle file uploads through the Streamlit interface.
        configure: Configure the model and load selected classes for inference.
        image_inference: Perform inference on uploaded images.
        inference: Perform real-time object detection inference.

    Examples:
        Create an Inference instance with a custom model
        >>> inf = Inference(model="path/to/model.pt")
        >>> inf.inference()

        Create an Inference instance with default settings
        >>> inf = Inference()
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any) -> None:
        """Check the Streamlit requirement and set the default state of the application.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration. A `model` entry, when present,
                is used as the initial model path.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Handle to the Streamlit module used by every other method

        # Inference options (the sidebar and configure steps overwrite these)
        self.source = None  # "webcam", "video" or "image"
        self.enable_trk = False  # Object tracking toggle
        self.conf = 0.25  # Detection confidence threshold
        self.iou = 0.45  # IoU threshold for non-maximum suppression
        self.selected_ind: list[int] = []  # Class indices to keep

        # Input sources
        self.vid_file_name = None  # Video file name or webcam index
        self.img_file_names = []  # Uploaded images as {"path": ..., "name": ...} dicts

        # Display containers
        self.org_frame = None  # Original frame
        self.ann_frame = None  # Annotated frame

        # Model state: "model" is always present in temp_dict, defaulting to None
        self.model = None
        self.temp_dict = {"model": None, **kwargs}
        self.model_path = self.temp_dict["model"]

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
method ultralytics.solutions.streamlit_inference.Inference.configure
def configure(self) -> None
Configure the model and load selected classes for inference.
Source code in ultralytics/solutions/streamlit_inference.py
def configure(self) -> None:
    """Offer a model choice in the sidebar, load the chosen model, and store the class IDs to keep.

    The dropdown lists the YOLO11 checkpoints found in `GITHUB_ASSETS_STEMS`, ordered by model size and then by task
    suffix. A model path supplied at construction time is inserted first. After the model is loaded, a multiselect
    of its class names (first three preselected) determines `self.selected_ind`.
    """
    size_order = ("yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x")
    task_order = ("", "-seg", "-pose", "-obb", "-cls")
    prefix_len = len(size_order[0])  # all size prefixes share the same length

    def sort_key(stem: str) -> tuple[int, int]:
        """Rank a stem by model size first and task suffix second."""
        return size_order.index(stem[:prefix_len]), task_order.index(stem[prefix_len:].lower())

    # Only stems with a known task suffix have a sort position; anything else (e.g. grayscale variants) is skipped
    # instead of making the sort key raise ValueError.
    known_stems = sorted(
        (
            stem
            for stem in GITHUB_ASSETS_STEMS
            if stem.startswith(size_order) and stem[prefix_len:].lower() in task_order
        ),
        key=sort_key,
    )
    available_models = [stem.replace("yolo", "YOLO") for stem in known_stems]
    if self.model_path:  # Insert user provided custom model in available_models
        available_models.insert(0, self.model_path)
    selected_model = self.st.sidebar.selectbox("Model", available_models)

    with self.st.spinner("Model is downloading..."):
        # A recognised weight extension or export directory name is used as given; a bare name gets ".pt"
        is_explicit_path = selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine")) or any(
            fmt in selected_model for fmt in ("openvino_model", "rknn_model")
        )
        model_path = selected_model if is_explicit_path else f"{selected_model.lower()}.pt"
        self.model = YOLO(model_path)  # Load the YOLO model
        names = self.model.names  # Mapping of class ID -> class name
    self.st.success("Model loaded successfully!")

    # Map names back to the model's own class IDs rather than list positions, so the result is correct even when
    # the IDs are not 0..n-1. The first ID wins for a repeated name, as list.index would.
    class_names = list(names.values())
    name_to_id = {}
    for class_id, class_name in names.items():
        name_to_id.setdefault(class_name, class_id)

    selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
    self.selected_ind = [name_to_id[class_name] for class_name in selected_classes]
method ultralytics.solutions.streamlit_inference.Inference.image_inference
def image_inference(self) -> None
Perform inference on uploaded images.
Source code in ultralytics/solutions/streamlit_inference.py
def image_inference(self) -> None:
    """Run the model on each uploaded image and show the original next to the annotated result.

    Every entry of `self.img_file_names` is a dict with the temporary file `path` and the uploaded file `name`.
    Entries are removed from the list as they are handled, and the temporary file is always deleted afterwards,
    including when the image cannot be decoded or inference raises. Entries not yet reached stay in the list.
    """
    while self.img_file_names:
        img_info = self.img_file_names.pop(0)  # consume in upload order; handled entries never linger
        img_path = img_info["path"]
        try:
            image = cv2.imread(img_path)  # returns None when the file cannot be decoded
            if image is None:
                self.st.error("Could not load the uploaded image.")
                continue

            self.st.markdown(f"#### Processed: {img_info['name']}")
            col1, col2 = self.st.columns(2)
            with col1:
                self.st.image(image, channels="BGR", caption="Original Image")
            results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
            annotated_image = results[0].plot()
            with col2:
                self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
        finally:
            try:  # Clean up temporary file on every path, not only after a successful prediction
                os.unlink(img_path)
            except FileNotFoundError:
                pass  # File doesn't exist, ignore
method ultralytics.solutions.streamlit_inference.Inference.inference
def inference(self) -> None
Perform real-time object detection inference on video or webcam feed.
Source code in ultralytics/solutions/streamlit_inference.py
def inference(self) -> None:
    """Build the page and, once "Start" is pressed, run inference on the selected source.

    For the image source the uploaded images are processed and the method returns. For webcam and video sources,
    frames are read in a loop, passed through the model (with tracking when enabled), and shown in the original and
    annotated frame containers until reading fails or the stream closes. The capture is released on every exit
    path, including errors raised during inference.
    """
    self.web_ui()  # Initialize the web interface
    self.sidebar()  # Create the sidebar
    self.source_upload()  # Upload the video source
    self.configure()  # Configure the app

    if self.st.sidebar.button("Start"):
        if self.source == "image":
            if self.img_file_names:
                self.image_inference()
            else:
                self.st.info("Please upload an image file to perform inference.")
            return

        stop_button = self.st.sidebar.button("Stop")  # Button to stop the inference
        cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
        try:
            if not cap.isOpened():
                self.st.error("Could not open webcam or video source.")
                return

            # stop_button is read once above, so its value is fixed for the whole loop: check it before doing
            # any frame work. The finally clause below releases the capture.
            if stop_button:
                self.st.stop()  # Stop streamlit app

            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    if self.source == "webcam":
                        self.st.warning(
                            "Failed to read frame from webcam. Please verify the webcam is connected properly."
                        )
                    else:  # a video file normally ends here, so do not blame the webcam
                        self.st.info("Reached the end of the video or could not read the next frame.")
                    break

                # Process frame with model
                if self.enable_trk:
                    results = self.model.track(
                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                    )
                else:
                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                annotated_frame = results[0].plot()  # Add annotations on frame

                self.org_frame.image(frame, channels="BGR", caption="Original Frame")  # Display original frame
                self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame")  # Display processed
        finally:
            cap.release()  # Release the capture
    cv2.destroyAllWindows()  # Destroy all OpenCV windows
method ultralytics.solutions.streamlit_inference.Inference.sidebar
def sidebar(self) -> None
Configure the Streamlit sidebar for model and inference settings.
Source code in ultralytics/solutions/streamlit_inference.py
def sidebar(self) -> None:
    """Fill the Streamlit sidebar with the logo, the source selector and the threshold sliders.

    Updates `self.source`, `self.conf` and `self.iou` from the widgets. For webcam and video sources it also sets
    `self.enable_trk` from the tracking radio button. Unless the source is "image", two page columns are created
    and their empty placeholders are stored in `self.org_frame` and `self.ann_frame`.
    """
    st = self.st
    panel = st.sidebar

    with panel:  # Ultralytics logo at the top of the sidebar
        st.image(
            "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
            width=250,
        )

    panel.title("User Configuration")
    self.source = panel.selectbox("Source", ("webcam", "video", "image"))

    if self.source in ("webcam", "video"):  # tracking is only offered for frame streams
        tracking_choice = panel.radio("Enable Tracking", ("Yes", "No"))
        self.enable_trk = tracking_choice == "Yes"

    # Sliders start from the current thresholds and move in steps of 0.01
    conf_choice = panel.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
    self.conf = float(conf_choice)
    iou_choice = panel.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)
    self.iou = float(iou_choice)

    if self.source == "image":  # images are rendered per file, so no frame placeholders are needed
        return

    left, right = st.columns(2)
    self.org_frame = left.empty()  # placeholder for the original frame
    self.ann_frame = right.empty()  # placeholder for the annotated frame
method ultralytics.solutions.streamlit_inference.Inference.source_upload
def source_upload(self) -> None
Handle video file uploads through the Streamlit interface.
Source code in ultralytics/solutions/streamlit_inference.py
def source_upload(self) -> None:
    """Collect the input for the selected source through the Streamlit sidebar.

    Resets `self.vid_file_name` to an empty string, then:
      - "video": writes the uploaded file to `ultralytics.mp4` in the working directory and stores that name.
      - "webcam": stores webcam index 0.
      - "image": saves each uploaded image to its own temporary file and appends a `{"path", "name"}` dict to
        `self.img_file_names`.
    """
    from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS  # scope import

    self.vid_file_name = ""
    if self.source == "video":
        vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
        if vid_file is not None:
            with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                out.write(vid_file.read())  # write the upload directly; no intermediate in-memory copy
            self.vid_file_name = "ultralytics.mp4"
    elif self.source == "webcam":
        self.vid_file_name = 0  # Use webcam index 0
    elif self.source == "image":
        import tempfile  # scope import
        from pathlib import Path  # scope import

        if imgfiles := self.st.sidebar.file_uploader(
            "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
        ):
            for imgfile in imgfiles:  # Save each uploaded image to a temporary file
                # Path.suffix is "" for a name without an extension, where splitting on "." would have turned
                # the whole file name into the suffix.
                suffix = Path(imgfile.name).suffix
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
                    tf.write(imgfile.read())
                # Register the file only after it has been closed, so its contents are fully on disk
                self.img_file_names.append({"path": tf.name, "name": imgfile.name})
method ultralytics.solutions.streamlit_inference.Inference.web_ui
def web_ui(self) -> None
Set up the Streamlit web interface with custom HTML elements.
Source code in ultralytics/solutions/streamlit_inference.py
def web_ui(self) -> None:
    """Set the page configuration and render the custom HTML header of the application.

    Calls `set_page_config` first, then emits three HTML snippets through `st.markdown` with
    `unsafe_allow_html=True`: a style rule hiding the main menu, the main title, and the subtitle.
    """
    # Hide the main menu. The element is addressed by id, so the selector needs the leading "#"; a bare
    # "MainMenu" is a tag selector and matches nothing.
    menu_style_cfg = """<style>#MainMenu {visibility: hidden;}</style>"""

    # Main title of streamlit application
    main_title_cfg = (
        """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;\n"""
        """font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
    )

    # Subtitle of streamlit application
    sub_title_cfg = (
        """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;\n"""
        """margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images\n"""
        """with the power of Ultralytics YOLO! 🚀</h5></div>"""
    )

    # Set html page configuration and append custom HTML
    self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
    for html in (menu_style_cfg, main_title_cfg, sub_title_cfg):
        self.st.markdown(html, unsafe_allow_html=True)