Provide a supported way to convert str keys from orbax checkpoints back to int keys #5030

@pcmoritz

Description

It would be awesome to have an officially supported function that converts keys which orbax sadly converted from int to str back to int. I saw #4317, but it is not sufficient for my use case, which needs somewhat more involved checkpoint loading. See NovaSky-AI/SkyRL#461 for the use case.

The restore_int_paths function from https://github.com/google/flax/blob/main/flax/nnx/statelib.py works great. Would it be possible to make it public (or to provide a public version of that function)?

Since the conversion from int to str is part of the public interface of flax.training.checkpoints, it would make sense to also have a public interface to convert back. An even better option would be to do the conversion automatically in checkpoints.restore_checkpoint.
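
To illustrate the behavior being requested, here is a minimal sketch in plain Python. The helper name `str_keys_to_int` is hypothetical: it is not part of flax or orbax, and it is not the actual implementation of restore_int_paths. It assumes the restored checkpoint is a nested dict whose integer keys came back as decimal strings.

```python
from typing import Any


def str_keys_to_int(tree: Any) -> Any:
    """Recursively convert decimal-string keys of nested dicts to int.

    Hypothetical helper for illustration only; not a flax or orbax API.
    """
    if not isinstance(tree, dict):
        return tree  # leaf (array, scalar, ...): returned unchanged
    converted = {}
    for key, value in tree.items():
        # Convert only ASCII decimal strings such as "0" or "12".
        if isinstance(key, str) and key.isascii() and key.isdigit():
            key = int(key)
        converted[key] = str_keys_to_int(value)
    return converted


# Example: a state dict as it might look after a restore (values are placeholders).
restored = {"layers": {"0": {"kernel": 1.0}, "1": {"kernel": 2.0}}}
fixed = str_keys_to_int(restored)
```

Note that a sketch like this cannot tell a key that was originally the string "0" from one that was originally the int 0, which is part of why an officially supported conversion would be preferable.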

Thanks for your help :)
