It would be awesome if there were an officially supported function to convert keys that Orbax (sadly) converts from int to str back to int. I saw #4317, but it is not sufficient for my use case, which requires somewhat more involved checkpoint loading. See NovaSky-AI/SkyRL#461 for the use case.
The `restore_int_paths` function in https://github.com/google/flax/blob/main/flax/nnx/statelib.py works great for this. Would it be possible to make it public (or to provide a public version of it)?
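Here is a minimal sketch of the kind of conversion I mean. It assumes the restored checkpoint is a plain nested dict. `restore_int_keys` is a hypothetical name used only for illustration: it is not the Flax `restore_int_paths` implementation and not part of any Flax or Orbax API.

```python
from typing import Any


def restore_int_keys(tree: Any) -> Any:
    """Hypothetical helper: recursively turn digit-string dict keys back into ints.

    Assumes the checkpoint is a nested dict. Returns a new structure and
    leaves the input untouched. Negative numbers such as "-1" are not
    converted, because str.isdigit() is False for them.
    """
    if isinstance(tree, dict):
        restored = {}
        for key, value in tree.items():
            # Only keys made purely of ASCII digits are treated as former ints.
            if isinstance(key, str) and key.isascii() and key.isdigit():
                new_key = int(key)
            else:
                new_key = key
            restored[new_key] = restore_int_keys(value)
        return restored
    # Leaves (arrays, scalars, etc.) are returned unchanged.
    return tree


# Example: layer indices that came back from the checkpoint as strings.
ckpt = {"layers": {"0": {"kernel": 1.0}, "1": {"kernel": 2.0}}, "head": {"bias": 0.5}}
print(restore_int_keys(ckpt))
```

A heuristic like this cannot tell a key that was originally the string `"0"` from one that was the int `0`, which is one reason an official helper (or automatic conversion on restore) would be preferable.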
Since the int-to-str conversion is part of the public interface of `flax.training.checkpoints`, it would make sense to also have a public interface for converting back. An even better option would be to do the conversion automatically in `checkpoints.restore_checkpoint`.
Thanks for your help :)