forked from pinellolab/DNA-Diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
50 lines (37 loc) · 1.06 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path
import hydra
from hydra.core.config_store import ConfigStore
from hydra_zen import MISSING, instantiate, make_config
from omegaconf import DictConfig
from dnadiffusion.configs import LightningTrainer, sample
# Hydra defaults list for the top-level config: "_self_" first, then the
# default option for the "data" and "model" config groups. (The groups
# themselves are registered elsewhere — presumably in dnadiffusion.configs;
# confirm.)
_HYDRA_DEFAULTS = [
    "_self_",
    {"data": "LoadingData"},
    {"model": "Unet"},
]

# Top-level training config. ``data`` and ``model`` are left MISSING here and
# are expected to be supplied through the defaults list above; ``trainer`` and
# ``sample`` reuse the nodes imported from ``dnadiffusion.configs``.
Config = make_config(
    hydra_defaults=_HYDRA_DEFAULTS,
    data=MISSING,
    model=MISSING,
    trainer=LightningTrainer,
    sample=sample,
    # Constants
    data_dir="dna_diffusion/data",
    random_seed=42,
    ckpt_path=None,
)

# Register the config under the name "config", which is the name that
# ``@hydra.main(config_name="config")`` below looks up.
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
def train(config):
    """Instantiate data, sampling callback, model and trainer from ``config``, then fit.

    Args:
        config: Composed Hydra/hydra-zen config with ``data``, ``model``,
            ``trainer`` and ``sample`` nodes, and optionally ``ckpt_path``.

    Returns:
        The model object after ``trainer.fit`` has returned.
    """
    data = instantiate(config.data)
    # Named ``sample_callback`` so it does not shadow the module-level
    # ``sample`` config node imported from ``dnadiffusion.configs``.
    sample_callback = instantiate(config.sample, data_module=data)
    model = instantiate(config.model)
    trainer = instantiate(config.trainer)

    # Adding custom callbacks: the sampling callback is appended to the
    # already-constructed trainer's callback list.
    trainer.callbacks.append(sample_callback)

    # NOTE(review): ``config.random_seed`` is declared in Config but is not
    # applied anywhere visible here; seeding (e.g. Lightning's seed_everything)
    # may be missing — confirm.

    # Honor ``ckpt_path`` from the config so a run can resume from a
    # checkpoint. It was previously declared but ignored. ``None`` (the
    # default) starts a fresh run. ``getattr`` keeps this working for configs
    # that lack the field.
    ckpt_path = getattr(config, "ckpt_path", None)
    trainer.fit(model, data, ckpt_path=ckpt_path)
    return model
@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    """Hydra entry point: run training on the composed "config" node.

    ``config_path=None`` because the config is registered in Hydra's
    ConfigStore by this module rather than read from YAML files.
    """
    trained_model = train(cfg)
    return trained_model


if __name__ == "__main__":
    main()