4 changes: 4 additions & 0 deletions src/onnx_safetensors/__init__.py
@@ -1,18 +1,22 @@
"""Use safetensors with ONNX."""

__all__ = [
"apply_tensors",
"load",
"load_file",
"load_file_as_external_data",
"read_safetensors",
"replace_tensors",
"save",
"save_file",
]

from onnx_safetensors._safetensors_io import (
apply_tensors,
load,
load_file,
load_file_as_external_data,
read_safetensors,
replace_tensors,
save,
save_file,
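For context, a minimal usage sketch of the re-exported load_file entry point (file names are placeholders and not part of this change; the signature is shown in _safetensors_io.py below):

import onnx
import onnx_safetensors

# Load the graph, then fill its matching initializers from the safetensors file.
model = onnx.load("model.onnx")
model = onnx_safetensors.load_file(model, "model.safetensors")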
77 changes: 61 additions & 16 deletions src/onnx_safetensors/_safetensors_io.py
Expand Up @@ -11,7 +11,6 @@
from typing import TYPE_CHECKING, Any, TypeVar

import onnx
import onnx.helper
import onnx_ir as ir
import safetensors

@@ -62,22 +61,41 @@
ir.DataType.UINT32: "uint32",
ir.DataType.UINT64: "uint64",
}
_CASTABLE_DTYPES = frozenset(
{
ir.DataType.BFLOAT16,
ir.DataType.FLOAT16,
ir.DataType.FLOAT,
ir.DataType.DOUBLE,
ir.DataType.INT8,
ir.DataType.INT16,
ir.DataType.INT32,
ir.DataType.INT64,
ir.DataType.UINT8,
ir.DataType.UINT16,
ir.DataType.UINT32,
ir.DataType.UINT64,
}
)


TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)


def _apply_tensors(
def apply_tensors(
model: ir.Model,
tensors: Mapping[str, ir.TensorProtocol],
apply_safetensors: bool = False,
*,
apply_safetensors: bool = True,
cast: bool = False,
):
"""Apply tensors to an ONNX model.

Args:
model: ONNX model to apply tensors to.
tensors: Tensors to apply to the ONNX model.
apply_safetensors: Whether the tensors being applied come from a safetensors file.
cast: Whether to cast the tensors to the dtype in the model if they differ.
"""
graph = model.graph
for name, tensor in tensors.items():
@@ -86,8 +104,26 @@ def _apply_tensors(
model_tensor = graph.initializers[name].const_value
if model_tensor is not None and apply_safetensors:
assert isinstance(tensor, ir.ExternalTensor)
_check_tensors_match(model_tensor, tensor)
updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
_check_tensors_match(model_tensor, tensor, cast=cast)
if (
model_tensor.dtype != tensor.dtype
and not _is_4bit(model_tensor.dtype)
and not _is_8bit_float(model_tensor.dtype)
):
if model_tensor.dtype not in _CASTABLE_DTYPES:
raise ValueError(
f"Cannot cast tensor '{name}' from dtype {tensor.dtype} to {model_tensor.dtype}."
)
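# Bind the current tensor/model_tensor as lambda defaults so the deferred cast
# uses this iteration's values rather than whatever the loop variables hold later.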
updated_tensor = ir.LazyTensor(
lambda tensor=tensor, model_tensor=model_tensor: ir.Tensor(
tensor.numpy().astype(model_tensor.dtype.numpy())
),
dtype=model_tensor.dtype,
shape=model_tensor.shape,
name=model_tensor.name,
)
else:
updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
else:
updated_tensor = tensor
graph.initializers[name].const_value = updated_tensor
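A hedged sketch of the now-public apply_tensors with plain in-memory tensors; the initializer name "fc.weight", its shape, and the use of onnx_ir's Tensor constructor are illustrative assumptions. Because apply_safetensors now defaults to True, it must be set to False when the tensors are not ExternalTensor views of a safetensors file:

import numpy as np
import onnx_ir as ir
import onnx_safetensors

model = ir.load("model.onnx")
new_weights = {
    "fc.weight": ir.Tensor(np.zeros((16, 16), dtype=np.float32), name="fc.weight"),
}
# apply_safetensors=False skips the safetensors-specific checks and assigns the
# tensors to the matching initializers as-is.
onnx_safetensors.apply_tensors(model, new_weights, apply_safetensors=False)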
@@ -111,20 +147,25 @@ def _is_8bit_float(dtype: ir.DataType) -> bool:


def replace_tensors(
model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike
model: ir.Model,
/,
location: str | os.PathLike,
base_dir: str | os.PathLike,
cast: bool = False,
) -> None:
"""Replace all tensors in an ONNX model with external data from a safetensors file.

Args:
model: ONNX model to replace tensors in.
location: Path to the safetensors file relative to the ONNX model file.
base_dir: Directory where the ONNX model file is stored.
cast: Whether to cast the tensors to the dtype in the model if they differ.

.. versionadded:: 1.0
Added the function.
"""
tensors = _read_safetensors(location, base_dir=base_dir)
_apply_tensors(model, tensors, apply_safetensors=True)
tensors = read_safetensors(location, base_dir=base_dir)
apply_tensors(model, tensors, apply_safetensors=True, cast=cast)
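As a reference for the new cast flag, a hedged sketch of calling replace_tensors directly; the side-by-side layout of model.onnx and model.safetensors in the current directory is an assumption for illustration:

import onnx_ir as ir
import onnx_safetensors

model = ir.load("model.onnx")
# location is resolved relative to base_dir (the directory holding the ONNX file);
# cast=True permits migrating safetensors dtypes to the dtypes declared in the model.
onnx_safetensors.replace_tensors(model, "model.safetensors", base_dir=".", cast=True)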


def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
@@ -176,7 +217,7 @@ def load(model: TModel, /, data: bytes) -> TModel:
)
for (name, metadata) in tensors
}
_apply_tensors(model_ir, tensors_dict)
apply_tensors(model_ir, tensors_dict, apply_safetensors=False)

if isinstance(model, onnx.ModelProto):
return ir.serde.serialize_model(model_ir)
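Correspondingly, a small sketch of load applied to raw safetensors bytes (paths are placeholders):

import onnx
import onnx_safetensors

model = onnx.load("model.onnx")
with open("model.safetensors", "rb") as f:
    model = onnx_safetensors.load(model, f.read())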
@@ -327,10 +368,10 @@ def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]]
return json.loads(header.decode("utf-8")), header_size


def _read_safetensors(
def read_safetensors(
location: str | os.PathLike, base_dir: str | os.PathLike
) -> dict[str, ir.ExternalTensor]:
"""Read a safetensors file.
"""Read a safetensors file and return a mapping of tensor names to ExternalTensors.

Args:
location: The safetensors file to read.
@@ -344,6 +385,8 @@ def _read_safetensors(
header, header_size = _read_safetensors_header(file)
tensors = {}
for name, metadata in header.items():
if name == "__metadata__":
continue
offset = metadata["data_offsets"][0] + header_size + _HEADER_SIZE_NUMBER_SIZE
length = metadata["data_offsets"][1] - metadata["data_offsets"][0]
tensors[name] = ir.ExternalTensor(
@@ -359,13 +402,14 @@


def _check_tensors_match(
model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor, cast: bool = False
):
"""Check if two tensors match.

Args:
model_tensor: Tensor from the model.
safe_tensor: Tensor from the safetensors file.
cast: Whether to allow casting of the tensor from safetensors to match the model tensor.

Raises:
ValueError: If the tensors do not match.
@@ -392,9 +436,9 @@ def _check_tensors_match(
f"The tensor from safetensors has dtype: {safe_tensor.dtype}, but it must be UINT8 to "
f"represent the dtype of the tensor in the model: {model_tensor.dtype}."
)
elif model_tensor.dtype != safe_tensor.dtype:
elif model_tensor.dtype != safe_tensor.dtype and not cast:
raise ValueError(
f"The tensor from safetensors has dtype: {safe_tensor.dtype}, "
f"The tensor '{model_tensor.name}' from safetensors has dtype: {safe_tensor.dtype}, "
f"which does not match the dtype of the tensor in the model: {model_tensor.dtype}."
)

@@ -406,8 +450,8 @@ def _check_tensors_match(


def _migrate_tensor_shape_dtype(
model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
) -> ir.ExternalTensor:
model_tensor: ir.TensorProtocol, safe_tensor: ir.TensorProtocol
) -> ir.TensorProtocol:
"""Migrate the shape and dtype of a tensor.

Args:
@@ -423,6 +467,7 @@
ir.DataType.FLOAT8E4M3FNUZ,
ir.DataType.FLOAT8E5M2FNUZ,
} or _is_4bit(model_tensor.dtype):
assert isinstance(safe_tensor, ir.ExternalTensor)
return ir.ExternalTensor(
location=safe_tensor.location,
offset=safe_tensor.offset,