-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhubconf.py
89 lines (63 loc) · 2.52 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
""" Configuration for torch hub usage """
from typing import Callable, Dict, Any
import torch
# Package names torch.hub looks up under the module-level name `dependencies`:
# they must be importable for the entry points defined in this file to load.
dependencies = ['torch', 'trecover']
def add_available_versions_to_docstring(handler: Callable[..., Callable]) -> Callable:
    """
    Append the available checkpoint versions to the docstring of `handler`.

    One line of the form ``'\\t* <version>'`` is added for every key of
    ``trecover.config.var.CHECKPOINT_URLS``, in the mapping's iteration order.

    Parameters
    ----------
    handler : Callable[..., Callable]
        Callable whose ``__doc__`` is extended in place.

    Returns
    -------
    handler : Callable
        The very same object that was passed in, so the function can be used as a decorator.

    """

    from trecover.config import var

    versions = '\n'.join(f'\t* {version}' for version in var.CHECKPOINT_URLS)

    # `__doc__` is None for undocumented callables and for every function when Python
    # runs with -OO; treat it as empty so decorating never fails at import time.
    handler.__doc__ = (handler.__doc__ or '') + versions

    return handler
@add_available_versions_to_docstring
def trecover(device: torch.device = torch.device('cpu'), version: str = 'latest'):
    """
    Load the TRecover model via torch.hub.

    Parameters
    ----------
    device : torch.device, default=torch.device('cpu')
        Device on which to allocate the model.
    version : str, default='latest'
        Model weights' version. A version that is not listed below
        silently falls back to 'latest'.

    Returns
    -------
    model : TRecover
        Model with specified weights' version, moved to `device`.

    Examples
    --------
    Show the docstring with available versions for the TRecover model:

    >>> print(torch.hub.help(github='alex-snd/TRecover', model='trecover'))

    Load the TRecover model:

    >>> torch.hub.load('alex-snd/TRecover', model='trecover', device=torch.device('cpu'), version='latest')

    Available Versions
    ------------------
    """

    import json
    from urllib.request import urlopen

    from trecover.config import var
    from trecover.model import TRecover

    if version not in var.CHECKPOINT_URLS:
        version = 'latest'

    checkpoint = var.CHECKPOINT_URLS[version]

    # The architecture hyper-parameters are published as JSON next to the weights.
    with urlopen(checkpoint['config']) as response:
        config = json.loads(response.read().decode())

    model = TRecover(token_size=config['token_size'], pe_max_len=config['pe_max_len'],
                     num_layers=config['num_layers'], d_model=config['d_model'], n_heads=config['n_heads'],
                     d_ff=config['d_ff'], dropout=config['dropout'])

    # `map_location` only decides where the downloaded tensors live; `load_state_dict`
    # copies them into the parameters the model already owns. Move the model itself so
    # that the result really ends up on `device`.
    # NOTE(review): assumes TRecover is a torch.nn.Module (it is used via load_state_dict) - confirm.
    model.to(device)

    state_dict = torch.hub.load_state_dict_from_url(url=checkpoint['model'],
                                                    progress=False,
                                                    map_location=device)
    model.load_state_dict(state_dict)

    return model
def collab_args() -> Dict[str, Any]:
    """
    Build the base arguments for collaborative training.

    Returns
    -------
    Dict[str, Any] :
        A new dictionary on every call with the keys 'initial_peers',
        'experiment_prefix', 'target_batch_size', 'min_noise' and 'max_noise'.

    """

    bootstrap_peers = [
        '/ip4/95.216.202.215/tcp/34234/p2p/QmS2mYPX8Q78RDQzxMf17phVYPashi4C6ixVbA3jrj9yxt',
    ]

    return dict(
        initial_peers=bootstrap_peers,
        experiment_prefix='trecover',
        target_batch_size=4096,
        min_noise=0,
        max_noise=1,  # TODO model params
    )