Skip to content

Commit 54903cd

Browse files
committed
add create index test
1 parent 506d34b commit 54903cd

File tree

3 files changed

+42
-30
lines changed

3 files changed

+42
-30
lines changed

bergson/data.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ def create_index(
366366
{
367367
"num_grads": num_grads,
368368
"dtype": struct_dtype,
369-
"unstructured_dtype": np.dtype(dtype).str,
370-
"grad_dimension": sum(grad_sizes.values()),
369+
"grad_sizes": grad_sizes,
370+
"base_dtype": np.dtype(dtype).str,
371371
},
372372
f,
373373
indent=2,
@@ -433,9 +433,9 @@ def load_gradients(root_dir: Path, with_structure: bool = True) -> np.memmap:
433433
dtype = info["dtype"]
434434
shape = (num_grads,)
435435
else:
436-
dtype = info["unstructured_dtype"]
437-
grad_dimension = info["grad_dimension"]
438-
shape = (num_grads, grad_dimension)
436+
dtype = info["base_dtype"]
437+
grad_sizes = info["grad_sizes"]
438+
shape = (num_grads, sum(grad_sizes.values()))
439439

440440
return np.memmap(
441441
root_dir / "gradients.bin",

tests/test_build.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66
import torch
7-
from transformers import AutoConfig, AutoModelForCausalLM
7+
from transformers import AutoModelForCausalLM
88

99
from bergson import (
1010
AttentionConfig,
@@ -37,30 +37,6 @@ def test_build_e2e(tmp_path: Path):
3737
assert result.returncode == 0
3838

3939

40-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
41-
def test_large_gradients_build(tmp_path: Path, dataset):
42-
config = AutoConfig.from_pretrained(
43-
"EleutherAI/pythia-1.4b", trust_remote_code=True
44-
)
45-
model = AutoModelForCausalLM.from_config(config)
46-
model.cuda()
47-
48-
collect_gradients(
49-
model=model,
50-
data=dataset,
51-
processor=GradientProcessor(),
52-
path=tmp_path,
53-
skip_preconditioners=True,
54-
)
55-
56-
# Load a large gradient index without structure.
57-
load_gradients(tmp_path, with_structure=False)
58-
59-
with pytest.raises(ValueError):
60-
# Max item size exceeded.
61-
load_gradients(tmp_path, with_structure=True)
62-
63-
6440
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
6541
def test_build_consistency(tmp_path: Path, model, dataset):
6642
collect_gradients(

tests/test_data.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import math
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import pytest
6+
import torch
7+
from transformers import AutoConfig, AutoModelForCausalLM
8+
9+
from bergson.data import create_index, load_gradients
10+
from bergson.gradients import GradientCollector
11+
12+
13+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
14+
def test_large_gradients_build(tmp_path: Path, dataset):
15+
# Create index for uncompressed gradients from a large model.
16+
config = AutoConfig.from_pretrained(
17+
"EleutherAI/pythia-1.4b", trust_remote_code=True
18+
)
19+
model = AutoModelForCausalLM.from_config(config)
20+
collector = GradientCollector(model, lambda x: x)
21+
grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
22+
23+
create_index(
24+
tmp_path,
25+
num_grads=len(dataset),
26+
grad_sizes=grad_sizes,
27+
dtype=np.float32,
28+
with_structure=False,
29+
)
30+
31+
# Load a large gradient index without structure.
32+
load_gradients(tmp_path, with_structure=False)
33+
34+
with pytest.raises(ValueError):
35+
# Max item size exceeded.
36+
load_gradients(tmp_path, with_structure=True)

0 commit comments

Comments
 (0)