# Patch notes (free text ahead of the "diff --git" header; patch tools normally
# ignore anything before the first header line).
#
# Changes made to dvc.yaml by this patch:
#   * The last `outs` entry of the preceding stage is removed and re-added with
#     identical text; the only difference is the terminating newline (the old
#     file ended without one, see the "\ No newline at end of file" marker).
#   * A new stage `train_retnet` is appended:
#       - wdir: ./RetNet is the stage's working directory.
#       - cmd:  written as a `>-` folded scalar, so the continuation lines are
#               joined with single spaces into one command line and the final
#               newline is stripped. It launches train.py through torchrun with
#               3 processes per node and passes --output_dir
#               ../data/models/retnet/checkpoints.
#       - deps: train.py is the only declared dependency.
#       - outs: ../data/models/retnet/checkpoints, the same path given to
#               --output_dir.
#
# NOTE(review): deps/outs paths are presumably resolved relative to wdir (that
#   is why the output path starts with "../") - confirm against the DVC docs.
# NOTE(review): only train.py is tracked as a dependency; changes to training
#   data or other code under ./RetNet would presumably not invalidate the
#   stage - confirm this is intended.
# NOTE(review): --nproc_per_node=3 is hard-coded; presumably it matches the
#   GPU count of the training host - confirm.
# NOTE(review): the patch text was received collapsed onto a single line. Line
#   breaks and indentation below were reconstructed assuming the conventional
#   dvc.yaml layout (stage names at 2 spaces, stage keys at 4, list items and
#   folded command lines at 6). The hunk counts (4 old / 28 new lines) agree
#   with this reconstruction, but verify the whitespace before applying.
diff --git a/dvc.yaml b/dvc.yaml
index b8034ee..9e983b1 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -44,4 +44,28 @@ stages:
       - data/datasets/polemo2/lightning
       - data/hps/lightning/${item}/polemo2/best_params.yaml
     outs:
-      - data/models/lightning/${item}/polemo2/
\ No newline at end of file
+      - data/models/lightning/${item}/polemo2/
+
+  train_retnet:
+    wdir: ./RetNet
+    cmd: >-
+      torchrun --nproc_per_node=3
+      train.py
+      --model_size 300m
+      --output_dir ../data/models/retnet/checkpoints
+      --do_train
+      --do_eval
+      --prediction_loss_only
+      --remove_unused_columns False
+      --learning_rate 6e-4
+      --weight_decay 0.01
+      --max_steps 20000
+      --logging_steps 100
+      --eval_steps 1000
+      --save_steps 1000
+      --per_device_train_batch_size 4
+      --per_device_eval_batch_size 4
+    deps:
+      - train.py
+    outs:
+      - ../data/models/retnet/checkpoints