-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinit.py
30 lines (25 loc) · 998 Bytes
/
init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Model setup."""
from model import create_model
import os
import torch
import utils
def setup(opt, checkpoint, map_location=None):
    """Create a new model, or reload a model and optimizer state from a checkpoint.

    Args:
        opt: Options object. When creating a new model it is passed to
            ``create_model``; when resuming, ``opt.resume`` is the directory
            that the checkpoint's file names are joined onto.
        checkpoint: ``None`` to build a fresh model. Otherwise an object
            indexable by ``'model_file'`` and ``'optim_file'`` (e.g. a dict)
            whose values are file names relative to ``opt.resume``.
        map_location: Optional value forwarded unchanged to ``torch.load`` for
            both files (e.g. ``'cpu'`` to resume a GPU-saved checkpoint on a
            CPU-only machine). The default ``None`` is ``torch.load``'s own
            default, so omitting it preserves the previous behaviour.

    Returns:
        A ``(model, optim_state)`` tuple. ``optim_state`` is ``None`` when
        ``checkpoint`` is ``None``; otherwise it is whatever ``torch.load``
        returns for the optimizer file.

    Raises:
        KeyError: if ``checkpoint`` lacks ``'model_file'`` or ``'optim_file'``.
        Any error raised by ``utils.check_file`` or ``torch.load`` propagates.
    """
    # Log lines are indented by four spaces under the caller's own output.
    prefix = " " * 4

    if checkpoint is None:
        print(prefix + "=> Creating new model")
        return create_model(opt), None

    # Resume path: restore the model first, then the optimizer state. Each path
    # goes through utils.check_file before loading (presumably an existence
    # check -- confirm in utils).
    # NOTE(review): torch.load unpickles its input, so only resume from
    # trusted checkpoint directories. If the model file was written with
    # torch.save(model) (a whole pickled module), PyTorch >= 2.6 defaults to
    # weights_only=True and may refuse it -- confirm against the torch
    # version in use.
    model_path = os.path.join(opt.resume, checkpoint['model_file'])
    utils.check_file(model_path)
    print(prefix + "=> Resuming model from %s" % model_path)
    model = torch.load(model_path, map_location=map_location)

    optim_path = os.path.join(opt.resume, checkpoint['optim_file'])
    utils.check_file(optim_path)
    print(prefix + "=> Resuming optim_state from %s" % optim_path)
    optim_state = torch.load(optim_path, map_location=map_location)

    return model, optim_state