Description
Say you're on an M3 Mac with Python 3.11, trying to export a simple model like this:
```python
import tensorflow as tf
import jax
import jax.numpy as jnp
from orbax.export import JaxModule, ExportManager, ServingConfig

# Model "parameters": a single random weight vector of length 256.
yy = jax.random.normal(jax.random.key(9), (256,))

def my_pred(params, x):
    # Dot product of the serving input with the weight vector.
    return jnp.dot(x, params)

xx = jnp.ones(256)
jax_module = JaxModule(yy, my_pred)

def to_tfspec(a):
    # Build a TF input spec from a JAX abstract value's shape and dtype.
    return tf.TensorSpec(shape=a.shape, dtype=a.dtype)

sig = jax.tree.map(lambda a: to_tfspec(jax.typeof(a)), [xx])

export_mgr = ExportManager(jax_module, [
    ServingConfig('serving_default', input_signature=sig)
])
output_dir = '/tmp/blahblah'
export_mgr.save(output_dir)
```
We get the following error: "XlaCallModule from your TensorFlow installation supports up to serialization version 9 but the serialized module needs version 10. You should upgrade TensorFlow, e.g., to tf_nightly."
I guess this means we need to depend on tf_nightly if we want the docs to demonstrate Orbax functionality.
Upon installing tf_nightly, however, we get another issue:
"Detected mismatched Protobuf Gencode/Runtime major versions when loading tensorflow/core/framework/attr_value.proto: gencode 6.31.1 runtime 5.29.5. "
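The message compares only the major components of the two versions: the generated code (from the TensorFlow build) is 6.x, while the imported runtime is 5.x. A minimal sketch of that comparison, assuming plain `X.Y.Z` version strings; it mirrors the wording of the error, not protobuf's own version-check code, which also has rules for minor versions:

```python
def major(version: str) -> int:
    """Return the leading (major) component of an 'X.Y.Z' version string."""
    return int(version.split(".")[0])


def majors_match(gencode: str, runtime: str) -> bool:
    """True when the generated code and the runtime share a major version."""
    return major(gencode) == major(runtime)


# The versions from the error message: gencode 6.31.1, runtime 5.29.5.
print(majors_match("6.31.1", "5.29.5"))  # False
```

So an upgrade only helps if the interpreter that loads TensorFlow actually imports a 6.x runtime.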
I tried upgrading protobuf, but the issue doesn't go away. How do I get the protobuf runtime to use version 6.31.1 too?
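One way to check what the interpreter actually sees is to compare the installed distribution versions with the location the `google.protobuf` package would be imported from. This is a minimal diagnostic sketch using only the standard library; the distribution names in the list (`protobuf`, `tensorflow`, `tf-nightly`, `jax`, `orbax-export`) are assumptions about this environment:

```python
import importlib.metadata
import importlib.util
import sys


def installed_versions(dists):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in dists:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_origin(module_name):
    """Return the file a module would be imported from, or None if not found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:  # a parent package (e.g. `google`) is missing
        return None
    return spec.origin if spec else None


if __name__ == "__main__":
    print("python:", sys.executable)
    dists = ["protobuf", "tensorflow", "tf-nightly", "jax", "orbax-export"]
    for dist, ver in installed_versions(dists).items():
        print(f"{dist}: {ver}")
    print("google.protobuf from:", module_origin("google.protobuf"))
```

If both `tensorflow` and `tf-nightly` are listed, or the reported `protobuf` version differs from the one `pip` upgraded to, the upgrade may have gone into a different environment than the one running the export.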