
TensorFlow test dependency needs an upgrade to support Orbax #5054

@samanklesaria

Description


Say you're on an M3 Mac with Python 3.11, trying to export a simple model like this:

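```python
import tensorflow as tf
import jax
import jax.numpy as jnp
from orbax.export import JaxModule, ExportManager, ServingConfig

# Model parameters: a fixed random weight vector (shape must be a tuple).
yy = jax.random.normal(jax.random.key(9), (256,))

# apply_fn receives (params, inputs); use the input `x`, not the global `xx`.
def my_pred(yy, x):
    return jnp.dot(x, yy)

# Example input, used only to derive the serving signature.
xx = jnp.ones(256)

jax_module = JaxModule(yy, my_pred)

# Build a tf.TensorSpec from a JAX array's abstract shape and dtype.
def to_tfspec(a):
    return tf.TensorSpec(shape=a.shape, dtype=a.dtype)

sig = jax.tree.map(lambda a: to_tfspec(jax.typeof(a)), [xx])

export_mgr = ExportManager(jax_module, [
    ServingConfig('serving_default', input_signature=sig)
])

output_dir = '/tmp/blahblah'
export_mgr.save(output_dir)
```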

We get the following error: "XlaCallModule from your TensorFlow installation supports up to serialization version 9 but the serialized module needs version 10. You should upgrade TensorFlow, e.g., to tf_nightly."
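The message compares two numbers: the highest serialization version that the installed TensorFlow's `XlaCallModule` op can deserialize (9), and the version the JAX-serialized module was written with (10). Loading fails whenever the second exceeds the first. A minimal sketch that pulls both numbers out of the message; `parse_xla_call_module_error` and `tf_can_load` are hypothetical helpers, and the regex assumes the exact wording quoted above.

```python
import re

# Assumes the wording quoted above; other TensorFlow releases may phrase it differently.
_PATTERN = re.compile(
    r"supports up to serialization version (\d+) "
    r"but the serialized module needs version (\d+)"
)

def parse_xla_call_module_error(message):
    """Return (max_supported, required) from the error text, or None if it doesn't match."""
    m = _PATTERN.search(message)
    if m is None:
        return None
    return int(m.group(1)), int(m.group(2))

def tf_can_load(max_supported, required):
    """The installed XlaCallModule can only deserialize versions up to max_supported."""
    return required <= max_supported

msg = ("XlaCallModule from your TensorFlow installation supports up to "
       "serialization version 9 but the serialized module needs version 10. "
       "You should upgrade TensorFlow, e.g., to tf_nightly.")
supported, required = parse_xla_call_module_error(msg)
print(supported, required, tf_can_load(supported, required))  # 9 10 False
```

Either side can close the gap: a TensorFlow build whose `XlaCallModule` accepts version 10, or a module serialized at a version the installed TensorFlow already supports.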

I guess this means the docs need to depend on tf_nightly if we want them to demonstrate Orbax functionality.

Upon installing tf_nightly, however, we get another issue:

"Detected mismatched Protobuf Gencode/Runtime major versions when loading tensorflow/core/framework/attr_value.proto: gencode 6.31.1 runtime 5.29.5. "

I tried upgrading protobuf, but the issue doesn't go away. How do I get the protobuf runtime to use version 6.31.1 too?
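The error compares the gencode version (the protobuf compiler release that generated TensorFlow's `_pb2` modules, here 6.31.1) with the runtime version (the `protobuf` package that actually gets imported, here 5.29.5), and refuses to load when their major versions differ. Since upgrading protobuf didn't change the message, a reasonable first check is which interpreter is running and which distributions it actually sees. A minimal sketch using only the standard library; `installed_version` and `protobuf_major_mismatch` are hypothetical helpers, and the distribution names listed are assumptions about this environment. `tensorflow` and `tf-nightly` are both listed because they provide the same `tensorflow` import package, so having both installed may be worth ruling out.

```python
import importlib.metadata
import sys

def installed_version(dist_name):
    """Installed version of a distribution in *this* interpreter, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None

def major(version):
    """Major component of a dotted version string, e.g. '6.31.1' -> 6."""
    return int(version.split(".")[0])

def protobuf_major_mismatch(gencode, runtime):
    """True when gencode and runtime major versions differ, the condition named in the error."""
    return major(gencode) != major(runtime)

# Which interpreter is running, and what it sees installed.
print("python:", sys.executable)
for dist in ("protobuf", "tensorflow", "tf-nightly", "jax", "jaxlib", "orbax-export"):
    print(f"{dist}: {installed_version(dist)}")

# The versions reported in the error above.
print(protobuf_major_mismatch("6.31.1", "5.29.5"))  # True
```

If `protobuf` reports a 5.x version here even after an upgrade, the upgrade most likely went into a different environment than the one running the export (or another installed package is holding protobuf below 6).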
