跳至内容

参考资料 ultralytics/utils/triton.py

备注

该文件可从 https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/triton.py 获取。如果您发现问题，请通过提交 Pull Request 🛠️ 帮助修复。谢谢 🙏！



ultralytics.utils.triton.TritonRemoteModel

与远程Triton Inference Server 模型交互的客户端。

属性

名称 类型 说明
endpoint str

Triton 服务器上的模型名称。

url str

Triton 服务器的 URL。

triton_client

Triton 客户端(HTTP 或 gRPC)。

InferInput

Triton 客户端的输入类。

InferRequestedOutput

Triton 客户端的输出请求类。

input_formats List[str]

模型输入的数据类型。

np_input_formats List[type]

模型输入的 numpy 数据类型。

input_names List[str]

模型输入的名称。

output_names List[str]

模型输出的名称。

源代码 ultralytics/utils/triton.py
class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The Triton server address handed to the client (the network location, e.g. 'host:port', when
            parsed from a full URL).
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The Triton data types of the model inputs, e.g. 'TYPE_FP32'.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs, sorted alphabetically.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel and fetch the model configuration from the server.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>
        The URL is only parsed when both 'endpoint' and 'scheme' are empty.

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme: 'http', 'https' (HTTP with TLS enabled) or 'grpc'. Any other
                value selects the gRPC client.

        Raises:
            KeyError: If a model input uses a Triton data type that has no entry in the numpy type map.
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme; 'https' is the HTTP protocol over TLS.
        if scheme in {"http", "https"}:
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=scheme == "https")
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            # Request a JSON (dict) response so 'config' has the same layout as in the HTTP branch
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Triton config data types -> numpy dtypes. Only fixed-width numeric/boolean types are mapped, because
        # __call__ derives the wire datatype by stripping the "TYPE_" prefix; other types raise KeyError below.
        type_map = {
            "TYPE_BOOL": np.bool_,
            "TYPE_UINT8": np.uint8,
            "TYPE_UINT16": np.uint16,
            "TYPE_UINT32": np.uint32,
            "TYPE_UINT64": np.uint64,
            "TYPE_INT8": np.int8,
            "TYPE_INT16": np.int16,
            "TYPE_INT32": np.int32,
            "TYPE_INT64": np.int64,
            "TYPE_FP16": np.float16,
            "TYPE_FP32": np.float32,
            "TYPE_FP64": np.float64,
        }

        # Define model attributes
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Each input is cast to the dtype the model declares for that position before being sent. Every output is
        cast to the dtype of the FIRST input as passed by the caller (before any casting).

        Args:
            *inputs (np.ndarray): Input data to the model, one array per model input, in model input order.

        Returns:
            (List[np.ndarray]): Model outputs, ordered like `self.output_names`.
        """
        infer_inputs = []
        input_format = inputs[0].dtype  # outputs are returned in the caller's original dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            # The wire datatype is the config type without its "TYPE_" prefix, e.g. 'TYPE_FP32' -> 'FP32'
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

__call__(*inputs)

使用给定的输入调用模型。

参数

名称 类型 说明 默认值
*inputs List[ndarray]

模型的输入数据。

()

返回:

类型 说明
List[ndarray]

模型输出。

源代码 ultralytics/utils/triton.py
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
    """
    Run one inference request against the remote Triton model.

    Every array is converted to the numpy dtype the model declares for its position, wrapped in the client's
    input class, and sent in a single request. The returned arrays are cast to the dtype of the first array
    the caller supplied.

    Args:
        *inputs (List[np.ndarray]): Input data to the model.

    Returns:
        (List[np.ndarray]): Model outputs, in the order of `self.output_names`.
    """
    result_dtype = inputs[0].dtype  # remember the caller's dtype before any conversion

    requests = []
    for idx, arr in enumerate(inputs):
        wanted = self.np_input_formats[idx]
        arr = arr if arr.dtype == wanted else arr.astype(wanted)
        wire_type = self.input_formats[idx].replace("TYPE_", "")  # e.g. 'TYPE_FP32' -> 'FP32'
        request = self.InferInput(self.input_names[idx], list(arr.shape), wire_type)
        request.set_data_from_numpy(arr)
        requests.append(request)

    wanted_outputs = [self.InferRequestedOutput(name) for name in self.output_names]
    response = self.triton_client.infer(model_name=self.endpoint, inputs=requests, outputs=wanted_outputs)
    return [response.as_numpy(name).astype(result_dtype) for name in self.output_names]

__init__(url, endpoint='', scheme='')

初始化 TritonRemoteModel。

参数可以单独提供，也可以从形式为 `<scheme>://<netloc>/<endpoint>/<task_name>` 的组合参数 `url` 中解析出来。

参数

名称 类型 说明 默认值
url str

Triton 服务器的 URL。

必需
endpoint str

Triton 服务器上的模型名称。

''
scheme str

通信方案（"http" 或 "grpc"）。

''
源代码 ultralytics/utils/triton.py
def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
    """
    Initialize the TritonRemoteModel and fetch the model configuration from the server.

    Arguments may be provided individually or parsed from a collective 'url' argument of the form
        <scheme>://<netloc>/<endpoint>/<task_name>
    The URL is only parsed when both 'endpoint' and 'scheme' are empty.

    Args:
        url (str): The URL of the Triton server.
        endpoint (str): The name of the model on the Triton server.
        scheme (str): The communication scheme: 'http', 'https' (HTTP with TLS enabled) or 'grpc'. Any other
            value selects the gRPC client.

    Raises:
        KeyError: If a model input uses a Triton data type that has no entry in the numpy type map.
    """
    if not endpoint and not scheme:  # Parse all args from URL string
        splits = urlsplit(url)
        endpoint = splits.path.strip("/").split("/")[0]
        scheme = splits.scheme
        url = splits.netloc

    self.endpoint = endpoint
    self.url = url

    # Choose the Triton client based on the communication scheme; 'https' is the HTTP protocol over TLS.
    if scheme in {"http", "https"}:
        import tritonclient.http as client  # noqa

        self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=scheme == "https")
        config = self.triton_client.get_model_config(endpoint)
    else:
        import tritonclient.grpc as client  # noqa

        self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
        # Request a JSON (dict) response so 'config' has the same layout as in the HTTP branch
        config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

    # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
    config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

    # Triton config data types -> numpy dtypes. Only fixed-width numeric/boolean types are mapped, because the
    # wire datatype is derived by stripping the "TYPE_" prefix; other types raise KeyError below.
    type_map = {
        "TYPE_BOOL": np.bool_,
        "TYPE_UINT8": np.uint8,
        "TYPE_UINT16": np.uint16,
        "TYPE_UINT32": np.uint32,
        "TYPE_UINT64": np.uint64,
        "TYPE_INT8": np.int8,
        "TYPE_INT16": np.int16,
        "TYPE_INT32": np.int32,
        "TYPE_INT64": np.int64,
        "TYPE_FP16": np.float16,
        "TYPE_FP32": np.float32,
        "TYPE_FP64": np.float64,
    }

    # Define model attributes
    self.InferRequestedOutput = client.InferRequestedOutput
    self.InferInput = client.InferInput
    self.input_formats = [x["data_type"] for x in config["input"]]
    self.np_input_formats = [type_map[x] for x in self.input_formats]
    self.input_names = [x["name"] for x in config["input"]]
    self.output_names = [x["name"] for x in config["output"]]





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (6), Burhan-Q (1)