nnx.PathContains should support regex #5075

@thijs-vanweezel

The current implementation of nnx.PathContains only looks for exact matches in the path. This makes it difficult to select, e.g., a set of layers that share a naming pattern. To illustrate, my use case is to use nnx.split to separate convolutional layers from linear layers, which I named "conv1", "conv2", etc. and "fc1", "fc2", etc., respectively. To my knowledge, no other filter supports this either.

My proposal is a very simple modification of the __call__ method, namely:

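import dataclasses
import re
import typing as tp

from flax.typing import Key, PathParts

@dataclasses.dataclass(frozen=True)
class PathContains:
  key: Key  # interpreted as a regular expression, so it must be a str

  def __call__(self, path: PathParts, x: tp.Any):
    # re.compile keeps an internal cache, so repeated calls stay cheap.
    pattern = re.compile(self.key)
    # Path segments can be ints (e.g. list indices), so convert them to str.
    # Note that match() anchors the pattern at the start of each segment.
    return any(pattern.match(str(segment)) for segment in path)

  def __repr__(self):
    return f'PathContains({self.key!r})'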
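
To illustrate the matching behaviour I have in mind, here is a minimal standalone sketch that does not depend on Flax. The helper name `path_matches` and the example paths are made up for illustration; they only mimic the tuple-shaped paths that filters receive.

```python
import re

def path_matches(key: str, path: tuple) -> bool:
  """Return True if any path segment matches the regex `key` from its start."""
  pattern = re.compile(key)
  # Segments are converted to str because a path may contain ints.
  return any(pattern.match(str(segment)) for segment in path)

# Hypothetical paths, shaped like the tuples passed to a filter.
paths = [
  ('conv1', 'kernel'),
  ('conv2', 'bias'),
  ('fc1', 'kernel'),
  ('layers', 0, 'kernel'),
]

# Select the convolutional and the linear layers by naming pattern.
conv_paths = [p for p in paths if path_matches(r'conv\d+', p)]
fc_paths = [p for p in paths if path_matches(r'fc\d+', p)]
```

With these made-up paths, `conv_paths` should contain only the two "conv" entries and `fc_paths` only the "fc1" entry; the path with the integer segment matches neither pattern.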

Let me know if it is worth opening a pull request for this, and whether any efficiency constraints should be taken into account.
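
On the efficiency point, one option might be to compile the pattern once at construction time instead of on every call. The following is only a sketch under my own assumptions: the class name `PathRegex` and the private `_compiled` field are hypothetical and not part of nnx, and the class is written without Flax imports so it can be run on its own.

```python
import dataclasses
import re
import typing as tp

@dataclasses.dataclass(frozen=True)
class PathRegex:  # hypothetical name, not an existing nnx filter
  pattern: str
  # Excluded from __init__, repr, equality and hashing.
  _compiled: re.Pattern = dataclasses.field(
    init=False, repr=False, compare=False
  )

  def __post_init__(self):
    # Frozen dataclasses block normal assignment, so bypass it once here.
    object.__setattr__(self, '_compiled', re.compile(self.pattern))

  def __call__(self, path: tuple, x: tp.Any) -> bool:
    # Reuse the precompiled pattern; convert segments to str to handle ints.
    return any(self._compiled.match(str(part)) for part in path)

conv_filter = PathRegex(r'conv\d+')
```

Because the compiled pattern is excluded from comparison, two instances built from the same pattern string still compare equal and hash identically, which seems desirable if filters are ever used as dictionary keys or cached.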
