Skip to content

Commit dac4cf0

Browse files
authored
Merge pull request #64 from EleutherAI/harden-otf
Refactor on-the-fly query
2 parents defe3eb + 367ffe1 commit dac4cf0

File tree

8 files changed

+360
-374
lines changed

8 files changed

+360
-374
lines changed

bergson/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class Build:
1717

1818
def execute(self):
1919
"""Build the gradient dataset."""
20-
if not self.cfg.save_index and not self.cfg.save_processor:
20+
if not self.cfg.save_index and self.cfg.skip_preconditioners:
2121
raise ValueError(
22-
"At least one of save_index or save_processor must be True"
22+
"Either save_index must be True or skip_preconditioners must be False"
2323
)
2424

2525
build_gradient_dataset(self.cfg)

bergson/build.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def worker(
147147
projection_type=cfg.projection_type,
148148
include_bias=cfg.include_bias,
149149
)
150-
if rank == 0 and cfg.save_processor:
150+
if rank == 0:
151151
processor.save(cfg.partial_run_path)
152152

153153
if cfg.split_attention_modules:
@@ -171,7 +171,6 @@ def worker(
171171
target_modules=target_modules,
172172
attention_cfgs=attention_cfgs,
173173
save_index=cfg.save_index,
174-
save_processor=cfg.save_processor,
175174
drop_columns=cfg.drop_columns,
176175
token_batch_size=cfg.token_batch_size,
177176
module_wise=cfg.module_wise,
@@ -199,7 +198,6 @@ def flush():
199198
attention_cfgs=attention_cfgs,
200199
save_index=cfg.save_index,
201200
# Save a processor state checkpoint after each shard
202-
save_processor=cfg.save_processor,
203201
drop_columns=cfg.drop_columns,
204202
token_batch_size=cfg.token_batch_size,
205203
module_wise=cfg.module_wise,
@@ -213,7 +211,7 @@ def flush():
213211
flush()
214212
flush()
215213

216-
if cfg.save_processor:
214+
if rank == 0:
217215
processor.save(cfg.partial_run_path)
218216

219217

bergson/collection.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .data import create_index, pad_and_tensor
1414
from .gradients import AttentionConfig, GradientCollector, GradientProcessor
1515
from .peft import set_peft_enabled
16-
from .score_writer import ScoreWriter
16+
from .scorer import Scorer
1717

1818

1919
def collect_gradients(
@@ -29,9 +29,8 @@ def collect_gradients(
2929
target_modules: set[str] | None = None,
3030
attention_cfgs: dict[str, AttentionConfig] | None = None,
3131
save_index: bool = True,
32-
save_processor: bool = True,
3332
drop_columns: bool = False,
34-
score_writer: ScoreWriter | None = None,
33+
scorer: Scorer | None = None,
3534
token_batch_size: int | None = None,
3635
module_wise: bool = False,
3736
):
@@ -65,8 +64,8 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
6564
else:
6665
mod_grads[name] = g.to(dtype=dtype)
6766

68-
if score_writer and module_wise:
69-
score_writer(indices, mod_grads, name=name)
67+
if scorer and module_wise:
68+
scorer(indices, mod_grads, name=name)
7069

7170
# Compute the outer product of the flattened gradient
7271
if not skip_preconditioners:
@@ -161,11 +160,11 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
161160
for module_name in mod_grads.keys():
162161
grad_buffer[module_name][indices] = mod_grads[module_name].numpy()
163162

164-
if score_writer is not None:
163+
if scorer is not None:
165164
if module_wise:
166-
score_writer.finalize_module_wise(indices)
165+
scorer.finalize_module_wise(indices)
167166
else:
168-
score_writer(indices, mod_grads)
167+
scorer(indices, mod_grads)
169168

170169
mod_grads.clear()
171170
per_doc_losses[indices] = losses.detach().type_as(per_doc_losses)
@@ -187,8 +186,7 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
187186
)
188187
data.save_to_disk(path / "data.hf")
189188

190-
if save_processor:
191-
processor.save(path)
189+
processor.save(path)
192190

193191
# Make sure the gradients are written to disk
194192
if grad_buffer is not None:

bergson/data.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,6 @@ class IndexConfig:
118118
save_index: bool = True
119119
"""Whether to write the gradient index to disk."""
120120

121-
save_processor: bool = True
122-
"""Whether to write the gradient processor to disk."""
123-
124121
data: DataConfig = field(default_factory=DataConfig)
125122
"""Specification of the data on which to build the index."""
126123

0 commit comments

Comments
 (0)