Skip to content

Commit f4444d4

Browse files
committed
add a function for detaching the neural mem state
1 parent fcc5782 commit f4444d4

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "titans-pytorch"
3-
version = "0.4.3"
3+
version = "0.4.5"
44
description = "Titans"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_titans.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,23 @@ def test_assoc_scan(
405405
assert second_half.shape == inputs2.shape
406406

407407
assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
408+
409+
def test_mem_state_detach():
410+
from titans_pytorch.neural_memory import mem_state_detach
411+
412+
mem = NeuralMemory(
413+
dim = 384,
414+
chunk_size = 2,
415+
qk_rmsnorm = True,
416+
dim_head = 64,
417+
heads = 4,
418+
)
419+
420+
seq = torch.randn(4, 64, 384)
421+
422+
state = None
423+
424+
for _ in range(2):
425+
parallel_retrieved, state = mem(seq, state = state)
426+
state = mem_state_detach(state)
427+
parallel_retrieved.sum().backward()

titans_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from titans_pytorch.neural_memory import (
22
NeuralMemory,
3+
NeuralMemState,
4+
mem_state_detach
35
)
46

57
from titans_pytorch.memory_models import (

titans_pytorch/neural_memory.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from collections import namedtuple
88

99
import torch
10-
from torch import nn, stack, cat, tensor, Tensor
10+
from torch import nn, stack, cat, is_tensor, tensor, Tensor
1111
import torch.nn.functional as F
1212
from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
1313
from torch.func import functional_call, vmap, grad
14+
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
1415

1516
from tensordict import TensorDict
1617

@@ -40,6 +41,8 @@
4041

4142
LinearNoBias = partial(Linear, bias = False)
4243

44+
# neural mem state related
45+
4346
NeuralMemState = namedtuple('NeuralMemState', [
4447
'seq_index',
4548
'weights',
@@ -48,6 +51,13 @@
4851
'updates',
4952
])
5053

54+
def mem_state_detach(
55+
state: NeuralMemState
56+
):
57+
assert isinstance(state, NeuralMemState)
58+
state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
59+
return NeuralMemState(*state)
60+
5161
# functions
5262

5363
def exists(v):
@@ -854,6 +864,7 @@ def forward(
854864
seq,
855865
store_seq = None,
856866
state: NeuralMemState | None = None,
867+
detach_mem_state = False,
857868
prev_weights = None,
858869
store_mask: Tensor | None = None,
859870
return_surprises = False
@@ -1003,6 +1014,11 @@ def accum_updates(past_updates, future_updates):
10031014
updates
10041015
)
10051016

1017+
# maybe detach
1018+
1019+
if detach_mem_state:
1020+
next_neural_mem_state = mem_state_detach(next_neural_mem_state)
1021+
10061022
# returning
10071023

10081024
if not return_surprises:

0 commit comments

Comments
 (0)