from onnx.onnx_pb import (
    AttributeProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TypeProto,
)

class InferenceError(Exception):
    """Raised when shape/type inference fails in the native implementation."""
    ...

class GraphInferencer:
    """Callback object that runs inference over a subgraph.

    Obtained from ``InferenceContext.get_graph_attribute_inferencer`` for a
    graph-valued attribute (e.g. the body of an If/Loop/Scan node).
    """

    def do_inferencing(
        self, input_types: list[TypeProto], input_data: list[TensorProto | None]
    ) -> list[TypeProto]:
        """Infer the subgraph's output types.

        Args:
            input_types: Types of the subgraph inputs, positionally aligned.
            input_data: Statically-known values for each input, or ``None``
                where no constant value is available.

        Returns:
            The inferred types of the subgraph outputs.
        """
        ...

class InferenceContext:
    """Per-node view handed to a type/shape inference function.

    Exposes the node's attributes, input types/data, and mutable output
    types. Indices (``idx``) are positional within the node's input or
    output list.
    """

    def get_attribute(self, name: str) -> AttributeProto:
        """Return the node attribute named *name*."""
        ...
    def get_num_inputs(self) -> int:
        """Return the number of inputs on the node."""
        ...
    def get_input_type(self, idx: int) -> TypeProto:
        """Return the type of input *idx*."""
        ...
    def has_input(self, idx: int) -> bool:
        """Return whether input *idx* is present (not an omitted optional)."""
        ...
    def get_input_data(self, idx: int) -> TensorProto:
        """Return the statically-known (constant/initializer) value of input *idx*."""
        ...
    def get_num_outputs(self) -> int:
        """Return the number of outputs on the node."""
        ...
    def get_output_type(self, idx: int) -> TypeProto:
        """Return the currently recorded type of output *idx*."""
        ...
    def set_output_type(self, idx: int, type_proto: TypeProto) -> bool:
        """Record the inferred type for output *idx*; returns a success flag."""
        ...
    def has_output(self, idx: int) -> bool:
        """Return whether output *idx* is present."""
        ...
    def get_graph_attribute_inferencer(self, attr_name: str) -> GraphInferencer:
        """Return a :class:`GraphInferencer` for the graph-valued attribute *attr_name*."""
        ...
    def get_input_sparse_data(self, idx: int) -> SparseTensorProto:
        """Return the statically-known sparse-tensor value of input *idx*."""
        ...
    def get_symbolic_input(self, idx: int) -> TensorShapeProto:
        """Return the symbolic shape associated with input *idx*."""
        ...
    def get_display_name(self) -> str:
        """Return a human-readable name for the node (for error messages)."""
        ...

def infer_shapes(
    b: bytes, check_type: bool, strict_mode: bool, data_prop: bool
) -> bytes:
    """Run shape inference on a serialized model and return the result.

    Args:
        b: Serialized model (ModelProto bytes).
        check_type: Also validate that input/inferred types are consistent.
        strict_mode: Raise :class:`InferenceError` on inference failures
            instead of leaving types undetermined.
        data_prop: Enable data propagation (use constant values, not just
            shapes, during inference).

    Returns:
        The serialized model with inferred types/shapes filled in.
    """
    ...
def infer_shapes_path(
    model_path: str,
    output_path: str,
    check_type: bool,
    strict_mode: bool,
    data_prop: bool,
) -> None:
    """Run shape inference on a model file, writing the result to disk.

    File-based variant of :func:`infer_shapes` — useful for models too
    large to hold fully in memory as bytes.

    Args:
        model_path: Path of the serialized input model.
        output_path: Path where the inferred model is written.
        check_type: Also validate that input/inferred types are consistent.
        strict_mode: Raise :class:`InferenceError` on inference failures
            instead of leaving types undetermined.
        data_prop: Enable data propagation during inference.
    """
    ...
def infer_function_output_types(
    b: bytes,
    input_types: list[bytes],
    attributes: list[bytes],
) -> list[bytes]:
    """Infer the output types of a serialized function.

    Args:
        b: Serialized function (FunctionProto bytes).
        input_types: Serialized TypeProto bytes for each function input,
            positionally aligned.
        attributes: Serialized AttributeProto bytes providing values for the
            function's attribute parameters.

    Returns:
        Serialized TypeProto bytes, one per function output.
    """
    ...
