-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
105 lines (95 loc) · 3.2 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#coding:utf8
import os
from os.path import join, exists
import warnings
class DefaultConfig:
    """Default settings: dataset paths, model options and training hyper-parameters.

    Every setting is a plain class attribute, so it can be read without an
    instance. Derived values (``imgdir``, the label directories, the tile-id
    file paths, ``outdim`` and ``tag``) are computed once, when the class body
    runs, from the defaults defined above them. Per-instance overrides are
    applied with ``setattr`` by ``parse``.

    NOTE: attribute order matters -- ``parse`` prints/saves the settings in
    the order they appear in this class body.
    """

    # ---- dataset ----
    patchsize = 256
    dataset = "datasetCNNLC"
    labelName = "GID15"  # alternative: "GID24"
    dataRoot = "./dataset/data"
    imgdir = join(dataRoot, "img")
    label5dir = join(dataRoot, "label-5")
    # Fine-grained label directory; parse() points it at "label-24" for GID24.
    label15dir = join(dataRoot, "label-15")
    divideDir = "./dataset/divide"
    trainTileIds = join(divideDir, "train.txt")
    valTileIds = join(divideDir, "val.txt")
    testTileIds = join(divideDir, "test.txt")
    lcClassNum = 5  # equals len(label5List)
    label5List = ["builtup", "farmland", "forest", "meadow", "water"]
    # 15 abbreviated class names (labelNameDict["GID15"] == 15).
    label15List = [
        "Ind. land", "Urb. resid.", "Rur. resid.", "Traff. land", "P. field", "Irr. land", "Dry cropl.", "Garden",
        "Arb. forest", "Shr. land", "Nat. mead.", "Art. mead.", "River", "Lake", "Pond"
    ]
    # 22 abbreviated class names (labelNameDict["GID24"] == 22).
    label24List = [
        "Indu", "Urba", "Rura", "Road", "Padd", "Irri", "Dryc", "Gard", "Arbo", "Shru", "Natu", "Arti", "Rive",
        "Lake", "Pond", "Stad", "Squa", "Over", "Rail", "Airp", "Park", "Fish"
    ]
    outputDir = "./outputs"

    # ---- model ----
    # Number of output classes for each supported label set.
    labelNameDict = {"GID15": 15, "GID24": 22}
    indim = 3  # input channels (presumably RGB -- confirm against the data loader)
    outdim = labelNameDict[labelName]
    model = "FCN"  # model name
    backbone = "resnet50"
    isBackBoneFrozen = True
    isLCEncoder = True
    backboneFeatureLayer = "c3"
    isMultiClassifier = True
    GCNLayerNum = 1
    nodeNumRate = 1
    # Output tag: parse() stores hyperParas.txt under <outputDir>/<tag>/.
    tag = model

    # ---- train ----
    seed = 2023
    train = False
    createTif = False
    loss = "CrossEntropyLoss"
    isAccCal = True
    testModel = None
    batchSize = 8
    useGpu = True
    deviceId = None  # None: use the last one by default
    numWorkers = 2
    saveFreq = 5
    valStep = 5
    maxEpoch = 50
    lrMax = 0.001
    lrMode = "poly"  # alternative: "const"
    weightDecay = 1e-4
def parse(self, kwargs):
    """Override config attributes from ``kwargs``, refresh derived values, and report them.

    Meant to be attached to ``DefaultConfig`` as a method.

    Args:
        kwargs: mapping of attribute name -> new value. A name the config does
            not already have triggers a warning but is still set.

    Behaviour:
        * All overrides are applied first; derived values are recomputed
          afterwards, so the result does not depend on key order.
        * ``labelName`` given: ``outdim`` and both label directories are
          recomputed. ``dataRoot`` given: ``imgdir`` and both label
          directories are recomputed. ``divideDir`` given: the
          train/val/test tile-id paths are recomputed. A derived attribute
          passed explicitly in ``kwargs`` is never overwritten.
        * Creates ``<outputDir>/<tag>/`` and prints every class-level setting.
          When ``self.train`` is true, the same table is written to
          ``<outputDir>/<tag>/hyperParas.txt`` (replacing any previous file).

    Raises:
        ValueError: if a label directory must be recomputed and ``labelName``
            is not "GID15" or "GID24".
    """
    # 1) Apply every override before deriving anything from them.
    for k, v in kwargs.items():
        if not hasattr(self, k):
            warnings.warn(f"opt has no attribute {k}")
        setattr(self, k, v)

    # 2) Recompute values that depend on the overridden roots.
    label_subdirs = {"GID15": "label-15", "GID24": "label-24"}
    label_changed = "labelName" in kwargs
    root_changed = "dataRoot" in kwargs
    if (label_changed or root_changed) and self.labelName not in label_subdirs:
        raise ValueError("labelName must be GID15 or GID24")

    derived = {}
    if label_changed:
        derived["outdim"] = self.labelNameDict[self.labelName]
    if label_changed or root_changed:
        derived["label5dir"] = join(self.dataRoot, "label-5")
        # Attribute keeps the name label15dir even for the 24-class label set.
        derived["label15dir"] = join(self.dataRoot, label_subdirs[self.labelName])
    if root_changed:
        derived["imgdir"] = join(self.dataRoot, "img")
    if "divideDir" in kwargs:
        for split in ("train", "val", "test"):
            derived[f"{split}TileIds"] = join(self.divideDir, f"{split}.txt")
    for name, value in derived.items():
        if name not in kwargs:  # an explicit override beats the derived default
            setattr(self, name, value)

    # 3) Report the effective configuration; persist it only for training runs.
    para_save_dir = join(self.outputDir, self.tag)
    os.makedirs(para_save_dir, exist_ok=True)
    para_save_path = join(para_save_dir, "hyperParas.txt")
    tplt = "{0:>20}\t{1:<10}"
    lines = [
        tplt.format(k, str(getattr(self, k)))
        for k in self.__class__.__dict__
        if not k.startswith("__") and k != "parse"
    ]
    print("user config:")
    for line in lines:
        print(line)
    if self.train:
        with open(para_save_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
# Bind the module-level parse() onto DefaultConfig so instances can call .parse(kwargs).
setattr(DefaultConfig, "parse", parse)
# Module-level default configuration instance; opt.parse({...}) updates it in place.
opt = DefaultConfig()