-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathmain.lua
49 lines (35 loc) · 1.37 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
require 'torch'
require 'cutorch'
require 'paths'
require 'nn'
require 'nngraph'
local DataLoader = require 'dataloader'
local checkpoints = require 'checkpoints'
local models = require 'models/init'
local Trainer = require 'train'
local opts = require 'opts'
-- Parse the command-line options first; everything below depends on `opt`.
local opt = opts.parse(arg)

-- Fix the global Torch / cutorch runtime state before any model or data
-- is created: float tensors by default, a single CPU thread, GPU #1,
-- heap tracking on, and both CPU and GPU RNGs seeded from the options.
do
   local seed = opt.manualSeed
   torch.setdefaulttensortype('torch.FloatTensor')
   torch.setnumthreads(1)
   cutorch.setDevice(1)
   torch.setheaptracking(true)
   torch.manualSeed(seed)
   cutorch.manualSeed(seed)
end
-- Resume from the most recent checkpoint, if there is one.
-- checkpoints.latest(opt) is destructured into (checkpoint, optimState);
-- presumably both are nil when there is nothing to resume from -- TODO
-- confirm against checkpoints.lua.
local checkpoint, optimState = checkpoints.latest(opt)

-- Only load the optimizer state from checkpoint.optimFile ourselves when
-- checkpoints.latest did not already return it. Re-declaring
-- `local optimState` here would shadow (and throw away) the value above.
if checkpoint and optimState == nil then
   optimState = torch.load(checkpoint.optimFile)
end

-- Build the network and loss. `checkpoint` (possibly nil) is forwarded,
-- presumably so models.setup can restore weights when resuming.
-- NOTE(review): meaning of the third argument `true` is not visible here.
local model, criterion = models.setup(opt, checkpoint, true)

-- Report the number of learnable scalars. getParameters() flattens all
-- parameters into one contiguous tensor; only its length is used here.
print('=> Model size: ', model:getParameters():size(1))
-- Data loading: only a training split is created; this script runs no
-- validation pass.
local trainLoader = DataLoader.create(opt, 'train')

-- NOTE(review): `netLogger` is not declared anywhere in this file, so this
-- reads a global that is nil unless some required module defines it --
-- confirm whether Trainer expects a real logger object here.
local trainer = Trainer(model, criterion, opt, optimState, netLogger)

-- Continue at the epoch after the checkpointed one when resuming,
-- otherwise start at opt.epochNumber.
local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber

-- With no validation pass, no epoch is ever flagged as the best model.
-- Declared explicitly rather than read from an undeclared global.
local bestModel = false

for epoch = startEpoch, opt.nEpochs do
   -- Train for a single epoch
   local trainLoss, trainAcc = trainer:train(epoch, trainLoader)
   print(string.format(' *Results loss: %6.6f acc: %6.6f ', trainLoss, trainAcc))

   -- Snapshot every opt.snapshot epochs; a value of 0 disables snapshots.
   if opt.snapshot ~= 0 and epoch % opt.snapshot == 0 then
      -- clearState() drops intermediate buffers (output, gradInput, ...)
      -- before the model is serialized.
      -- `opt` is passed so the saver has access to the run options (e.g. the
      -- output directory); Lua discards it if checkpoints.save does not take
      -- a fifth parameter. NOTE(review): confirm checkpoints.save signature.
      checkpoints.save(epoch, model:clearState(), trainer.optimState, bestModel, opt)
   end
end