-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathearlystopping_by_plateau.py
48 lines (36 loc) · 1.29 KB
/
earlystopping_by_plateau.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import numpy as np
# Positive infinity as a 0-dim float tensor. ``np.Inf`` was removed in
# NumPy 2.0; ``np.inf`` is the canonical spelling in every NumPy version.
torch_inf = torch.tensor(np.inf)
class Earlystopping_by_plateau():
    """Signal early stopping once a monitored metric stops improving.

    Call :meth:`early_stop` once per epoch with the metric value. A value
    counts as an improvement when it beats the best value seen so far by more
    than ``min_delta`` in the monitored direction. After ``patience``
    consecutive epochs without improvement, :meth:`early_stop` returns
    ``True``.

    Args:
        monitor: Name of the monitored quantity. ``'acc'`` always selects
            maximisation.
        min_delta: Minimum change that counts as an improvement. Only its
            magnitude is used; the direction comes from the mode.
        patience: Number of consecutive non-improving epochs tolerated
            before stopping.
        mode: ``'max'`` if larger values are better. Any other value means
            ``'min'``, unless ``monitor == 'acc'``.
    """

    def __init__(self,
                 monitor='loss',
                 min_delta=0.0001,
                 patience=5,
                 mode='min'):
        self.wait = 0
        self.patience = patience
        self.stopped_epoch = 0
        self.monitor = monitor
        # Honour an explicit mode='max'; monitor='acc' also implies
        # maximisation for backward compatibility with existing callers.
        if monitor == 'acc' or mode == 'max':
            self.mode = 'max'
        else:
            self.mode = 'min'
        self.monitor_op = torch.gt if self.mode == 'max' else torch.lt
        # Store min_delta signed so that comparing ``current - min_delta``
        # against ``best`` with monitor_op requires an improvement of at
        # least |min_delta| in the chosen direction. abs() stops a negative
        # argument from inverting the threshold.
        delta = abs(min_delta)
        self.min_delta = delta if self.mode == 'max' else -delta
        inf = torch.tensor(float('inf'))
        self.best = inf if self.mode == 'min' else -inf

    def early_stop(self, current_epoch, current_val):
        """Record one epoch's metric and report whether training should stop.

        Args:
            current_epoch: Epoch index, stored in ``stopped_epoch`` when
                stopping is triggered.
            current_val: Metric value for this epoch, either a tensor or any
                value accepted by ``torch.tensor`` (e.g. a Python float).

        Returns:
            ``True`` once ``patience`` consecutive epochs have passed without
            improvement, otherwise ``False``.
        """
        if not isinstance(current_val, torch.Tensor):
            current_val = torch.tensor(current_val)
        # Detach so that keeping ``best`` around does not retain an autograd
        # graph (e.g. when a raw loss tensor is passed in) between epochs.
        current_val = current_val.detach()
        if self.monitor_op(current_val - self.min_delta, self.best):
            self.best = current_val
            self.wait = 0
            return False
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = current_epoch
            return True
        return False