Skip to content

Reference for ultralytics/data/explorer/explorer.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/explorer.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.data.explorer.explorer.ExplorerDataset

ExplorerDataset(*args, data: dict = None, **kwargs)

Bases: YOLODataset

Extends YOLODataset for advanced data exploration and manipulation in model training workflows.

Source code in ultralytics/data/explorer/explorer.py
def __init__(self, *args, data: dict = None, **kwargs) -> None:
    """
    Set up the explorer dataset by delegating all construction work to the YOLODataset parent.

    Args:
        *args: Positional arguments passed through unchanged to the parent constructor.
        data (dict, optional): Dataset configuration, forwarded to the parent as the ``data`` keyword.
        **kwargs: Remaining keyword arguments passed through unchanged to the parent constructor.
    """
    # ``data`` is captured as a named parameter, so re-attach it before handing everything to the parent.
    kwargs["data"] = data
    super().__init__(*args, **kwargs)

build_transforms

build_transforms(hyp: IterableSimpleNamespace = None)

Creates transforms for dataset images without resizing.

Source code in ultralytics/data/explorer/explorer.py
def build_transforms(self, hyp: IterableSimpleNamespace = None):
    """
    Build the explorer's transform: a single Format step, with no resizing.

    Args:
        hyp (IterableSimpleNamespace): Hyperparameters; ``mask_ratio`` and ``overlap_mask`` are read from it.
            NOTE(review): despite the ``None`` default, ``hyp`` must be supplied because its attributes are
            accessed unconditionally.

    Returns:
        (Format): Formatter producing non-normalized ``xyxy`` boxes plus batch indices, and masks/keypoints
            when the dataset uses segments/keypoints.
    """
    wants_masks = self.use_segments
    wants_keypoints = self.use_keypoints
    return Format(
        bbox_format="xyxy",
        normalize=False,
        return_mask=wants_masks,
        return_keypoint=wants_keypoints,
        batch_idx=True,
        mask_ratio=hyp.mask_ratio,
        mask_overlap=hyp.overlap_mask,
    )

load_image

load_image(
    i: int,
) -> Union[
    Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]
]

Loads 1 image from dataset index 'i' without any resize ops.

Source code in ultralytics/data/explorer/explorer.py
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
    """
    Load dataset image ``i`` at its original resolution, without any resize ops.

    If the image is cached in RAM (``self.ims[i]``), the cached array and its stored sizes are returned. Otherwise
    the ``.npy`` file is loaded when it exists, else the image file is decoded with OpenCV (BGR).

    Args:
        i (int): Index of the image in the dataset.

    Returns:
        (tuple): ``(image, (h0, w0), (h, w))``. For freshly loaded images both sizes are the array's own height
            and width.

    Raises:
        FileNotFoundError: If the image has to be read from disk and OpenCV cannot decode it.
    """
    cached, image_path, npy_path = self.ims[i], self.im_files[i], self.npy_files[i]
    if cached is not None:
        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    if npy_path.exists():
        image = np.load(npy_path)
    else:
        image = cv2.imread(image_path)  # BGR
        if image is None:
            raise FileNotFoundError(f"Image Not Found {image_path}")
    height, width = image.shape[:2]
    return image, (height, width), image.shape[:2]





ultralytics.data.explorer.explorer.Explorer

Explorer(
    data: Union[str, Path] = "coco128.yaml",
    model: str = "yolov8n.pt",
    uri: str = USER_CONFIG_DIR / "explorer",
)

Utility class for image embedding, table creation, and similarity querying using LanceDB and YOLO models.

Source code in ultralytics/data/explorer/explorer.py
def __init__(
    self,
    data: Union[str, Path] = "coco128.yaml",
    model: str = "yolov8n.pt",
    uri: str = USER_CONFIG_DIR / "explorer",
) -> None:
    """
    Prepare an Explorer: verify the lancedb/duckdb requirements, connect to LanceDB and load the YOLO model.

    Args:
        data (str | Path): Dataset YAML path or name; its file name is part of the embeddings table name.
        model (str): YOLO weights used to compute embeddings; also part of the embeddings table name.
        uri (str | Path): LanceDB connection URI. Defaults to ``USER_CONFIG_DIR / "explorer"``.
    """
    # Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
    checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
    import lancedb

    self.connection = lancedb.connect(uri)

    # The table name is derived from dataset + model so an existing table can be found and reused later.
    dataset_name = Path(data).name.lower()
    self.table_name = f"{dataset_name}_{model.lower()}"
    # Base name for similarity-index tables; the threshold and top_k values are appended to it per index.
    self.sim_idx_base_name = f"{self.table_name}_sim_idx".lower()

    self.model = YOLO(model)
    self.data = data
    self.choice_set = None  # filled with the split's image paths by create_embeddings_table

    self.table = None
    self.progress = 0

ask_ai

ask_ai(query)

Ask AI a question.

Parameters:

Name Type Description Default
query str

Question to ask.

required

Returns:

Type Description
DataFrame

A dataframe containing filtered results to the SQL query.

Example
exp = Explorer()
exp.create_embeddings_table()
answer = exp.ask_ai("Show images with 1 person and 2 dogs")
Source code in ultralytics/data/explorer/explorer.py
def ask_ai(self, query):
    """
    Turn a natural-language question into SQL via ``prompt_sql_query`` and run it against the embeddings table.

    Args:
        query (str): Question to ask.

    Returns:
        (pandas.DataFrame | None): Rows matching the generated SQL query, or None if running that query raised.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        answer = exp.ask_ai("Show images with 1 person and 2 dogs")
        ```
    """
    generated_sql = prompt_sql_query(query)
    try:
        df = self.sql_query(generated_sql)
    except Exception as err:  # a bad AI-generated query is reported to the user instead of raised
        LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
        LOGGER.error(err)
        return None
    return df

create_embeddings_table

create_embeddings_table(force: bool = False, split: str = 'train') -> None

Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it already exists. Pass force=True to overwrite the existing table.

Parameters:

Name Type Description Default
force bool

Whether to overwrite the existing table or not. Defaults to False.

False
split str

Split of the dataset to use. Defaults to 'train'.

'train'
Example
exp = Explorer()
exp.create_embeddings_table()
Source code in ultralytics/data/explorer/explorer.py
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
    """
    Build (or reuse) the LanceDB table that stores an embedding for every image of a dataset split.

    Without ``force``, an already-opened table (``self.table``) or an existing table named ``self.table_name``
    is reused. Otherwise the table is (re)created with ``mode="overwrite"``.

    Args:
        force (bool): Whether to overwrite the existing table or not. Defaults to False.
        split (str): Split of the dataset to use. Defaults to 'train'.

    Raises:
        ValueError: If no dataset was provided, or ``split`` is not a key of the dataset info.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        ```
    """
    reuse = not force
    if reuse and self.table is not None:
        LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
        return
    table_exists = self.table_name in self.connection.table_names()
    if table_exists and reuse:
        LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
        self.table = self.connection.open_table(self.table_name)
        self.progress = 1
        return
    if self.data is None:
        raise ValueError("Data must be provided to create embeddings table")

    data_info = check_det_dataset(self.data)
    if split not in data_info:
        raise ValueError(
            f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
        )

    split_paths = data_info[split]
    if not isinstance(split_paths, list):
        split_paths = [split_paths]
    self.choice_set = split_paths
    dataset = ExplorerDataset(img_path=split_paths, data=data_info, augment=False, cache=False, task=self.model.task)

    # Embed the first image once, only to learn the embedding dimensionality for the table schema.
    first_item = dataset[0]
    vector_size = self.model.embed(first_item["im_file"], verbose=False)[0].shape[0]
    schema = get_table_schema(vector_size)
    table = self.connection.create_table(self.table_name, schema=schema, mode="overwrite")

    # Per-sample keys that are not stored in the table.
    skipped_keys = ["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"]
    table.add(self._yield_batches(dataset, data_info, self.model, exclude_keys=skipped_keys))

    self.table = table

get_similar

get_similar(
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    return_type: str = "pandas",
) -> Any

Query the table for similar images. Accepts a single image or a list of images.

Parameters:

Name Type Description Default
img str or list

Path to the image or a list of paths to the images.

None
idx int or list

Index of the image in the table or a list of indexes.

None
limit int

Number of results to return. Defaults to 25.

25
return_type str

Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

'pandas'

Returns:

Type Description
DataFrame

A dataframe containing the results.

Example
exp = Explorer()
exp.create_embeddings_table()
similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg")
Source code in ultralytics/data/explorer/explorer.py
def get_similar(
    self,
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    return_type: str = "pandas",
) -> Any:  # pandas.DataFrame or pyarrow.Table
    """
    Find the table rows most similar to the given image(s) or to the image(s) at the given table index(es).

    Args:
        img (str or list): Path to the image or a list of paths to the images.
        idx (int or list): Index of the image in the table or a list of indexes.
        limit (int): Number of results to return. Defaults to 25.
        return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

    Returns:
        (pandas.DataFrame | pyarrow.Table): The results, as a DataFrame for 'pandas' or an arrow table for 'arrow'.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg")
        ```
    """
    assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
    resolved = self._check_imgs_or_idxs(img, idx)
    result = self.query(resolved, limit=limit)

    if return_type == "pandas":
        return result.to_pandas()
    if return_type == "arrow":
        return result

plot_similar

plot_similar(
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    labels: bool = True,
) -> Image.Image

Plot the similar images. Accepts images or indexes.

Parameters:

Name Type Description Default
img str or list

Path to the image or a list of paths to the images.

None
idx int or list

Index of the image in the table or a list of indexes.

None
labels bool

Whether to plot the labels or not.

True
limit int

Number of results to return. Defaults to 25.

25

Returns:

Type Description
Image

Image containing the plot.

Example
exp = Explorer()
exp.create_embeddings_table()
similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg")
Source code in ultralytics/data/explorer/explorer.py
def plot_similar(
    self,
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    labels: bool = True,
) -> Image.Image:
    """
    Render the images most similar to the given image(s) or table index(es) as a single plot.

    Args:
        img (str or list): Path to the image or a list of paths to the images.
        idx (int or list): Index of the image in the table or a list of indexes.
        labels (bool): Whether to plot the labels or not.
        limit (int): Number of results to return. Defaults to 25.

    Returns:
        (PIL.Image | None): Image containing the plot, or None when no similar images are found.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg")
        ```
    """
    matches = self.get_similar(img, idx, limit, return_type="arrow")
    if not len(matches):
        LOGGER.info("No results found.")
        return None
    return Image.fromarray(plot_query_result(matches, plot_labels=labels))

plot_similarity_index

plot_similarity_index(
    max_dist: float = 0.2, top_k: float = None, force: bool = False
) -> Image

Plot the similarity index of all the images in the table. Here, the index will contain the data points that are max_dist or closer to the image in the embedding space at a given index.

Parameters:

Name Type Description Default
max_dist float

maximum L2 distance between the embeddings to consider. Defaults to 0.2.

0.2
top_k float

Fraction (0.0-1.0) of the table's data points to consider when counting; used to set the limit of each vector search. Defaults to None (search the whole table).

None
force bool

Whether to overwrite the existing similarity index or not. Defaults to False.

False

Returns:

Type Description
Image

Image containing the plot.

Example
exp = Explorer()
exp.create_embeddings_table()

similarity_idx_plot = exp.plot_similarity_index()
similarity_idx_plot.show()  # view image preview
similarity_idx_plot.save("path/to/save/similarity_index_plot.png")  # save contents to file
Source code in ultralytics/data/explorer/explorer.py
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image.Image:
    """
    Plot the similarity index of all the images in the table as a bar chart of per-image similarity counts.

    For each image, the count is the number of data points that are ``max_dist`` or closer to it in the
    embedding space, as computed by ``similarity_index``.

    Args:
        max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
        top_k (float, optional): Fraction of the table's data points used as the limit of each vector search.
            Defaults to None.
        force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

    Returns:
        (PIL.Image.Image): Image containing the plot.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()

        similarity_idx_plot = exp.plot_similarity_index()
        similarity_idx_plot.show()  # view image preview
        similarity_idx_plot.save("path/to/save/similarity_index_plot.png")  # save contents to file
        ```
    """
    sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
    sim_count = np.asarray(sim_idx["count"].tolist())
    indices = np.arange(len(sim_count))

    # Draw on a dedicated figure. Plotting through pyplot's implicit "current" figure would stack these bars onto
    # whatever was drawn before (e.g. a previous call) and keep that figure alive indefinitely.
    fig, ax = plt.subplots()
    try:
        ax.bar(indices, sim_count)
        ax.set_xlabel("data idx")
        ax.set_ylabel("Count")
        ax.set_title("Similarity Count")
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
    buffer.seek(0)

    # Decode into a NumPy array so the returned image no longer depends on the in-memory buffer.
    return Image.fromarray(np.array(Image.open(buffer)))

plot_sql_query

plot_sql_query(query: str, labels: bool = True) -> Image.Image

Plot the results of a SQL-Like query on the table.

Parameters:

Name Type Description Default
query str

SQL query to run.

required
labels bool

Whether to plot the labels or not.

True

Returns:

Type Description
Image

Image containing the plot.

Example
exp = Explorer()
exp.create_embeddings_table()
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.plot_sql_query(query)
Source code in ultralytics/data/explorer/explorer.py
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
    """
    Run a SQL-like query on the table and render the matching images as a single plot.

    Args:
        query (str): SQL query to run.
        labels (bool): Whether to plot the labels or not.

    Returns:
        (PIL.Image | None): Image containing the plot, or None when the query matches no rows.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
        result = exp.plot_sql_query(query)
        ```
    """
    rows = self.sql_query(query, return_type="arrow")
    if not len(rows):
        LOGGER.info("No results found.")
        return None
    plotted = plot_query_result(rows, plot_labels=labels)
    return Image.fromarray(plotted)

query

query(
    imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    limit: int = 25,
) -> Any

Query the table for similar images. Accepts a single image or a list of images.

Parameters:

Name Type Description Default
imgs str or list

Path to the image or a list of paths to the images.

None
limit int

Number of results to return.

25

Returns:

Type Description
Table

An arrow table containing the results. Supports converting to: - pandas dataframe: result.to_pandas() - dict of lists: result.to_pydict()

Example
exp = Explorer()
exp.create_embeddings_table()
similar = exp.query(imgs="https://ultralytics.com/images/zidane.jpg")
Source code in ultralytics/data/explorer/explorer.py
def query(
    self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
) -> Any:  # pyarrow.Table
    """
    Query the table for similar images. Accepts a single image or a list of images.

    When several images are given, their embeddings are averaged and the mean vector is used for the search.

    Args:
        imgs (str | np.ndarray | List[str] | List[np.ndarray]): An image path or array, or a list of them.
        limit (int): Number of results to return.

    Returns:
        (pyarrow.Table): An arrow table containing the results. Supports converting to:
            - pandas dataframe: `result.to_pandas()`
            - dict of lists: `result.to_pydict()`

    Raises:
        ValueError: If the embeddings table has not been created yet.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.query(imgs="https://ultralytics.com/images/zidane.jpg")
        ```
    """
    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")
    # A single path or a single image array is wrapped so the model always receives a list of sources.
    if isinstance(imgs, (str, np.ndarray)):
        imgs = [imgs]
    assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
    embeds = self.model.embed(imgs)
    # Get avg if multiple images are passed (len > 1) so that a single vector is searched.
    if len(embeds) > 1:
        query_vec = torch.mean(torch.stack(embeds), 0).cpu().numpy()
    else:
        query_vec = embeds[0].cpu().numpy()
    return self.table.search(query_vec).limit(limit).to_arrow()

similarity_index

similarity_index(
    max_dist: float = 0.2, top_k: float = None, force: bool = False
) -> Any

Calculate the similarity index of all the images in the table. Here, the index will contain the data points that are max_dist or closer to the image in the embedding space at a given index.

Parameters:

Name Type Description Default
max_dist float

maximum L2 distance between the embeddings to consider. Defaults to 0.2.

0.2
top_k float

Fraction (0.0-1.0) of the table's data points to consider when counting; used to set the limit of each vector search. Defaults to None (search the whole table).

None
force bool

Whether to overwrite the existing similarity index or not. Defaults to False.

False

Returns:

Type Description
DataFrame

A dataframe containing the similarity index. Each row corresponds to an image, with columns for its position (idx), its file path (im_file), the number of similar images found (count), and their file paths (sim_im_files).

Example
exp = Explorer()
exp.create_embeddings_table()
sim_idx = exp.similarity_index()
Source code in ultralytics/data/explorer/explorer.py
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
    """
    Calculate the similarity index of all the images in the table. For each image, the index counts the data points
    that are ``max_dist`` or closer to it in the embedding space.

    The result is stored in a LanceDB table named after ``max_dist`` and ``top_k`` and reused on later calls
    unless ``force=True``.

    Args:
        max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
        top_k (float, optional): Fraction (0.0-1.0) of the table's rows used as the limit of each vector search.
            ``None`` searches the whole table. At least one row is always searched. Defaults to None.
        force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

    Returns:
        (pandas.DataFrame): One row per image with its position (``idx``), its ``im_file``, the number of similar
            images found (``count``) and their file paths (``sim_im_files``).

    Raises:
        ValueError: If the embeddings table has not been created, ``top_k`` is outside [0.0, 1.0], or
            ``max_dist`` is negative.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        sim_idx = exp.similarity_index()
        ```
    """
    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")
    sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
    if sim_idx_table_name in self.connection.table_names() and not force:
        LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
        return self.connection.open_table(sim_idx_table_name).to_pandas()

    # Compare against None explicitly: top_k=0.0 is a valid fraction and must not fall back to "whole table".
    if top_k is not None and not (0.0 <= top_k <= 1.0):
        raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
    if max_dist < 0.0:
        raise ValueError(f"max_dist must be non-negative. Got {max_dist}")

    num_rows = len(self.table)
    # Turn the fraction into a row count for the search limit, never below one row.
    search_limit = max(int(top_k * num_rows) if top_k is not None else num_rows, 1)
    features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
    im_files = features["im_file"]
    embeddings = features["vector"]

    sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

    def _yield_sim_idx():
        """Yield one similarity-index record (wrapped in a list) per image in the table."""
        for i, embedding in enumerate(tqdm(embeddings)):
            neighbours = self.table.search(embedding).limit(search_limit).to_pandas()
            # Keep only neighbours within the distance threshold (boolean mask instead of a string-built query).
            neighbours = neighbours[neighbours["_distance"] <= max_dist]
            yield [
                {
                    "idx": i,
                    "im_file": im_files[i],
                    "count": len(neighbours),
                    "sim_im_files": neighbours["im_file"].tolist(),
                }
            ]

    sim_table.add(_yield_sim_idx())
    self.sim_index = sim_table
    return sim_table.to_pandas()

sql_query

sql_query(query: str, return_type: str = 'pandas') -> Union[Any, None]

Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

Parameters:

Name Type Description Default
query str

SQL query to run.

required
return_type str

Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

'pandas'

Returns:

Type Description
DataFrame or Table

A pandas DataFrame (return_type='pandas', the default) or an arrow table (return_type='arrow') containing the results.

Example
exp = Explorer()
exp.create_embeddings_table()
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.sql_query(query)
Source code in ultralytics/data/explorer/explorer.py
def sql_query(
    self, query: str, return_type: str = "pandas"
) -> Union[Any, None]:  # pandas.DataFrame or pyarrow.Table
    """
    Run a SQL-like query on the table using DuckDB.

    The query may be a full ``SELECT`` statement or just a ``WHERE`` clause. A bare ``WHERE`` clause is expanded to
    ``SELECT * FROM 'table' <clause>``. The keyword check ignores case and surrounding whitespace.

    Args:
        query (str): SQL query to run.
        return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

    Returns:
        (pandas.DataFrame | pyarrow.Table): The query results, as a DataFrame for 'pandas' or an arrow table for
            'arrow'.

    Raises:
        ValueError: If the table has not been created or the query does not start with SELECT or WHERE.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
        result = exp.sql_query(query)
        ```
    """
    assert return_type in {
        "pandas",
        "arrow",
    }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"

    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")

    # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
    # The local name `table` is what `FROM 'table'` refers to when DuckDB runs the query from this frame.
    table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
    query = query.strip()
    keyword_check = query.upper()  # SQL keywords are case-insensitive
    if not keyword_check.startswith(("SELECT", "WHERE")):
        raise ValueError(
            f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
            f"clause. found {query}"
        )
    if keyword_check.startswith("WHERE"):
        query = f"SELECT * FROM 'table' {query}"
    LOGGER.info(f"Running query: {query}")

    import duckdb  # imported only once the query is known to be valid

    rs = duckdb.sql(query)
    if return_type == "arrow":
        return rs.arrow()
    elif return_type == "pandas":
        return rs.df()




📅 Created 9 months ago ✏️ Updated 1 month ago

Comments