From 464a610636d760887ff272720fc196140bad031c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 30 Sep 2025 19:47:59 -0700
Subject: [PATCH 1/4] Update

Signed-off-by: Justin Chu
---
 src/onnx_safetensors/_safetensors_io.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/onnx_safetensors/_safetensors_io.py b/src/onnx_safetensors/_safetensors_io.py
index 4044b7f..3a378c7 100644
--- a/src/onnx_safetensors/_safetensors_io.py
+++ b/src/onnx_safetensors/_safetensors_io.py
@@ -11,7 +11,6 @@
 from typing import TYPE_CHECKING, Any, TypeVar
 
 import onnx
-import onnx.helper
 import safetensors
 from onnxscript import ir
 
@@ -111,7 +110,7 @@ def _is_8bit_float(dtype: ir.DataType) -> bool:
 
 
 def replace_tensors(
-    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike
+    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike, cast: bool = False
 ) -> None:
     """Replace all tensors in an ONNX model with external data from a safetensors file.
 
@@ -119,12 +118,13 @@
         model: ONNX model to replace tensors in.
         location: Path to the safetensors file relative to the ONNX model file.
         base_dir: Directory where the ONNX model file is stored.
+        cast: Whether to cast the tensors to the dtype in the model if they differ.
 
     .. versionadded:: 1.0
         Added the function.
     """
-    tensors = _read_safetensors(location, base_dir=base_dir)
-    _apply_tensors(model, tensors, apply_safetensors=True)
+    tensors = read_safetensors(location, base_dir=base_dir)
+    _apply_tensors(model, tensors, apply_safetensors=True, cast=cast)
 
 
 def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
@@ -327,7 +327,7 @@ def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]]
     return json.loads(header.decode("utf-8")), header_size
 
 
-def _read_safetensors(
+def read_safetensors(
     location: str | os.PathLike, base_dir: str | os.PathLike
 ) -> dict[str, ir.ExternalTensor]:
     """Read a safetensors file.
@@ -344,6 +344,8 @@
         header, header_size = _read_safetensors_header(file)
         tensors = {}
         for name, metadata in header.items():
+            if name == "__metadata__":
+                continue
             offset = metadata["data_offsets"][0] + header_size + _HEADER_SIZE_NUMBER_SIZE
             length = metadata["data_offsets"][1] - metadata["data_offsets"][0]
             tensors[name] = ir.ExternalTensor(
@@ -394,7 +396,7 @@ def _check_tensors_match(
         )
     elif model_tensor.dtype != safe_tensor.dtype:
         raise ValueError(
-            f"The tensor from safetensors has dtype: {safe_tensor.dtype}, "
+            f"The tensor '{model_tensor.name}' from safetensors has dtype: {safe_tensor.dtype}, "
             f"which does not match the dtype of the tensor in the model: {model_tensor.dtype}."
         )

From 51bb88372fd4ab4bd0325fee19e53a3cca014eb1 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 30 Sep 2025 20:03:47 -0700
Subject: [PATCH 2/4] Support hf safetensors

Signed-off-by: Justin Chu
---
 src/onnx_safetensors/__init__.py        |  4 ++
 src/onnx_safetensors/_safetensors_io.py | 53 ++++++++++++++++++++-----
 2 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/src/onnx_safetensors/__init__.py b/src/onnx_safetensors/__init__.py
index 58710d9..b4dd825 100644
--- a/src/onnx_safetensors/__init__.py
+++ b/src/onnx_safetensors/__init__.py
@@ -1,18 +1,22 @@
 """Use safetensors with ONNX."""
 
 __all__ = [
+    "apply_tensors",
     "load",
     "load_file",
     "load_file_as_external_data",
+    "read_safetensors",
     "replace_tensors",
     "save",
     "save_file",
 ]
 
 from onnx_safetensors._safetensors_io import (
+    apply_tensors,
     load,
     load_file,
     load_file_as_external_data,
+    read_safetensors,
     replace_tensors,
     save,
     save_file,
diff --git a/src/onnx_safetensors/_safetensors_io.py b/src/onnx_safetensors/_safetensors_io.py
index d80b961..f1f5a95 100644
--- a/src/onnx_safetensors/_safetensors_io.py
+++ b/src/onnx_safetensors/_safetensors_io.py
@@ -61,15 +61,31 @@
     ir.DataType.UINT32: "uint32",
     ir.DataType.UINT64: "uint64",
 }
+_CASTABLE_DTYPES = frozenset({
+    ir.DataType.BFLOAT16,
+    ir.DataType.FLOAT16,
+    ir.DataType.FLOAT,
+    ir.DataType.DOUBLE,
+    ir.DataType.INT8,
+    ir.DataType.INT16,
+    ir.DataType.INT32,
+    ir.DataType.INT64,
+    ir.DataType.UINT8,
+    ir.DataType.UINT16,
+    ir.DataType.UINT32,
+    ir.DataType.UINT64,
+})
 
 TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)
 
 
-def _apply_tensors(
+def apply_tensors(
     model: ir.Model,
     tensors: Mapping[str, ir.TensorProtocol],
-    apply_safetensors: bool = False,
+    *,
+    apply_safetensors: bool = True,
+    cast: bool = False,
 ):
     """Apply tensors to an ONNX model.
 
@@ -77,6 +93,7 @@
         model: ONNX model to apply tensors to.
         tensors: Tensors to apply to the ONNX model.
         apply_safetensors: Whether it is applying safetensors to the ONNX model.
+        cast: Whether to cast the tensors to the dtype in the model if they differ.
     """
     graph = model.graph
     for name, tensor in tensors.items():
@@ -85,8 +102,20 @@
             model_tensor = graph.initializers[name].const_value
         if model_tensor is not None and apply_safetensors:
             assert isinstance(tensor, ir.ExternalTensor)
-            _check_tensors_match(model_tensor, tensor)
-            updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
+            _check_tensors_match(model_tensor, tensor, cast=cast)
+            if model_tensor.dtype != tensor.dtype:
+                if model_tensor.dtype not in _CASTABLE_DTYPES:
+                    raise ValueError(
+                        f"Cannot cast tensor '{name}' from dtype {tensor.dtype} to {model_tensor.dtype}."
+                    )
+                updated_tensor = ir.LazyTensor(
+                    lambda: ir.Tensor(tensor.numpy().astype(model_tensor.dtype.numpy())),
+                    dtype=model_tensor.dtype,
+                    shape=model_tensor.shape,
+                    name=model_tensor.name,
+                )
+            else:
+                updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
         else:
             updated_tensor = tensor
         graph.initializers[name].const_value = updated_tensor
@@ -124,7 +153,7 @@ def replace_tensors(
         Added the function.
     """
     tensors = read_safetensors(location, base_dir=base_dir)
-    _apply_tensors(model, tensors, apply_safetensors=True, cast=cast)
+    apply_tensors(model, tensors, apply_safetensors=True, cast=cast)
 
 
 def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
@@ -176,7 +205,7 @@ def load(model: TModel, /, data: bytes) -> TModel:
         )
         for (name, metadata) in tensors
     }
-    _apply_tensors(model_ir, tensors_dict)
+    apply_tensors(model_ir, tensors_dict, apply_safetensors=False)
 
     if isinstance(model, onnx.ModelProto):
         return ir.serde.serialize_model(model_ir)
@@ -330,7 +359,7 @@ def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]]
 def read_safetensors(
     location: str | os.PathLike, base_dir: str | os.PathLike
 ) -> dict[str, ir.ExternalTensor]:
-    """Read a safetensors file.
+    """Read a safetensors file and return a mapping of tensor names to ExternalTensors.
 
     Args:
         location: The safetensors file to read.
@@ -361,13 +390,14 @@ def read_safetensors(
 
 
 def _check_tensors_match(
-    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
+    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor, cast: bool = False
 ):
     """Check if two tensors match.
 
     Args:
         model_tensor: Tensor from the model.
         safe_tensor: Tensor from the safetensors file.
+        cast: Whether to allow casting of the tensor from safetensors to match the model tensor.
 
     Raises:
         ValueError: If the tensors do not match.
@@ -394,7 +424,7 @@ def _check_tensors_match(
             f"The tensor from safetensors has dtype: {safe_tensor.dtype}, but it must be UINT8 to "
             f"represent the dtype of the tensor in the model: {model_tensor.dtype}."
         )
-    elif model_tensor.dtype != safe_tensor.dtype:
+    elif model_tensor.dtype != safe_tensor.dtype and not cast:
         raise ValueError(
             f"The tensor '{model_tensor.name}' from safetensors has dtype: {safe_tensor.dtype}, "
             f"which does not match the dtype of the tensor in the model: {model_tensor.dtype}."
         )
@@ -408,8 +438,8 @@ def _check_tensors_match(
 
 
 def _migrate_tensor_shape_dtype(
-    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
-) -> ir.ExternalTensor:
+    model_tensor: ir.TensorProtocol, safe_tensor: ir.TensorProtocol
+) -> ir.TensorProtocol:
     """Migrate the shape and dtype of a tensor.
 
     Args:
@@ -425,6 +455,7 @@
         ir.DataType.FLOAT8E4M3FNUZ,
         ir.DataType.FLOAT8E5M2FNUZ,
     } or _is_4bit(model_tensor.dtype):
+        assert isinstance(safe_tensor, ir.ExternalTensor)
         return ir.ExternalTensor(
             location=safe_tensor.location,
             offset=safe_tensor.offset,

From bdff1290c8fea3681bf385265ed10d13347c36d5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 30 Sep 2025 20:05:30 -0700
Subject: [PATCH 3/4] Bind correctly

Signed-off-by: Justin Chu
---
 src/onnx_safetensors/_safetensors_io.py | 40 +++++++++++++++----------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/src/onnx_safetensors/_safetensors_io.py b/src/onnx_safetensors/_safetensors_io.py
index f1f5a95..ed4f41e 100644
--- a/src/onnx_safetensors/_safetensors_io.py
+++ b/src/onnx_safetensors/_safetensors_io.py
@@ -61,20 +61,22 @@
     ir.DataType.UINT32: "uint32",
     ir.DataType.UINT64: "uint64",
 }
-_CASTABLE_DTYPES = frozenset({
-    ir.DataType.BFLOAT16,
-    ir.DataType.FLOAT16,
-    ir.DataType.FLOAT,
-    ir.DataType.DOUBLE,
-    ir.DataType.INT8,
-    ir.DataType.INT16,
-    ir.DataType.INT32,
-    ir.DataType.INT64,
-    ir.DataType.UINT8,
-    ir.DataType.UINT16,
-    ir.DataType.UINT32,
-    ir.DataType.UINT64,
-})
+_CASTABLE_DTYPES = frozenset(
+    {
+        ir.DataType.BFLOAT16,
+        ir.DataType.FLOAT16,
+        ir.DataType.FLOAT,
+        ir.DataType.DOUBLE,
+        ir.DataType.INT8,
+        ir.DataType.INT16,
+        ir.DataType.INT32,
+        ir.DataType.INT64,
+        ir.DataType.UINT8,
+        ir.DataType.UINT16,
+        ir.DataType.UINT32,
+        ir.DataType.UINT64,
+    }
+)
 
 TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)
 
@@ -109,7 +111,9 @@ def apply_tensors(
                     f"Cannot cast tensor '{name}' from dtype {tensor.dtype} to {model_tensor.dtype}."
                 )
             updated_tensor = ir.LazyTensor(
-                lambda: ir.Tensor(tensor.numpy().astype(model_tensor.dtype.numpy())),
+                lambda tensor=tensor, model_tensor=model_tensor: ir.Tensor(
+                    tensor.numpy().astype(model_tensor.dtype.numpy())
+                ),
                 dtype=model_tensor.dtype,
                 shape=model_tensor.shape,
                 name=model_tensor.name,
@@ -139,7 +143,11 @@ def _is_8bit_float(dtype: ir.DataType) -> bool:
 
 
 def replace_tensors(
-    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike, cast: bool = False
+    model: ir.Model,
+    /,
+    location: str | os.PathLike,
+    base_dir: str | os.PathLike,
+    cast: bool = False,
 ) -> None:
     """Replace all tensors in an ONNX model with external data from a safetensors file.
 

From 984205d8eceeb7ed732b703d0c60df74ec162b93 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 30 Sep 2025 20:16:56 -0700
Subject: [PATCH 4/4] wip

Signed-off-by: Justin Chu
---
 src/onnx_safetensors/_safetensors_io.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/onnx_safetensors/_safetensors_io.py b/src/onnx_safetensors/_safetensors_io.py
index ed4f41e..baf0786 100644
--- a/src/onnx_safetensors/_safetensors_io.py
+++ b/src/onnx_safetensors/_safetensors_io.py
@@ -105,7 +105,11 @@ def apply_tensors(
         if model_tensor is not None and apply_safetensors:
             assert isinstance(tensor, ir.ExternalTensor)
             _check_tensors_match(model_tensor, tensor, cast=cast)
-            if model_tensor.dtype != tensor.dtype:
+            if (
+                model_tensor.dtype != tensor.dtype
+                and not _is_4bit(model_tensor.dtype)
+                and not _is_8bit_float(model_tensor.dtype)
+            ):
                 if model_tensor.dtype not in _CASTABLE_DTYPES:
                     raise ValueError(
                         f"Cannot cast tensor '{name}' from dtype {tensor.dtype} to {model_tensor.dtype}."
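
Usage sketch (not part of the patch series): a minimal example of how the public read_safetensors/apply_tensors API and the new cast option introduced above might be used once the series lands. The file names "model.onnx" and "model.safetensors" are hypothetical, and ir.serde.deserialize_model is assumed to be the counterpart of the ir.serde.serialize_model call that already appears in the diff.

    import onnx
    import onnx_safetensors
    from onnxscript import ir

    # Load the model graph without resolving its existing external data.
    proto = onnx.load("model.onnx", load_external_data=False)
    model = ir.serde.deserialize_model(proto)

    # Point each matching initializer at data in the safetensors file.
    # With cast=True, a dtype mismatch (e.g. bfloat16 weights feeding a
    # float32 model) is resolved by lazily casting to the model's dtype
    # instead of raising a ValueError.
    onnx_safetensors.replace_tensors(
        model, location="model.safetensors", base_dir=".", cast=True
    )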