-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathgenerator.py
46 lines (36 loc) · 1.18 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import math
import random
import numpy as np
from tensorflow.keras.utils import to_categorical
def std_normalize(data):
    """Standardize ``data`` to zero mean and (when possible) unit std.

    The mean and standard deviation are computed over all elements of
    ``data`` (no axis is passed to numpy), using 64-bit floats to avoid
    numpy warnings during the arithmetic.

    Args:
        data: Array object providing ``astype`` and ``size`` (e.g. a
            ``numpy.ndarray``). It is not modified.

    Returns:
        A new array cast to ``np.float32``. If the standard deviation is
        zero the result is only mean-centered; empty input yields an empty
        float32 array.
    """
    # normalize as 64 bit, to avoid numpy warnings; astype already returns
    # a fresh array, so the caller's data is never touched.
    data = data.astype(np.float64)
    if data.size == 0:
        # Mean/std of an empty array would only emit RuntimeWarnings and
        # NaNs; there is nothing to normalize.
        return data.astype(np.float32)
    mean = np.mean(data)
    std = np.std(data)
    data = data - mean
    if std != 0.:
        # Skip the division for constant input to avoid dividing by zero.
        data = data / std
    return data.astype(np.float32)
def create_data_generator(features, labels, batch_size):
    """Yield shuffled ``(X, y)`` mini-batches forever.

    Sample indices ``0 .. len(labels) - 1`` are shuffled with the ``random``
    module and consumed ``batch_size`` at a time. When fewer than
    ``batch_size`` unused indices remain, the leftover samples of that pass
    are skipped, the indices are reshuffled and a new pass starts, so every
    batch has exactly ``batch_size`` samples.

    Args:
        features: Container indexable by integer sample index.
        labels: Container indexable by integer sample index; its ``len``
            defines the number of samples.
        batch_size: Number of samples per batch; must satisfy
            ``1 <= batch_size <= len(labels)``.

    Yields:
        Tuple ``(X, y)`` where ``X`` is the selected features and ``y`` the
        selected labels, each stacked along a new leading axis with
        ``np.stack``.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds the number
            of samples. Because this is a generator, the error surfaces on
            the first ``next()`` call.
    """
    length = len(labels)
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be positive, got {batch_size}"
        )
    if batch_size > length:
        # Without this guard the wrap-around below would index past the end
        # of the index list and fail with an opaque IndexError.
        raise ValueError(
            f"batch_size ({batch_size}) exceeds number of samples ({length})"
        )

    idx = list(range(length))
    random.shuffle(idx)
    i = 0
    while True:
        # Reshuffle after each epoch (a full batch no longer fits).
        if i + batch_size > length:
            i = 0
            random.shuffle(idx)
        # Take batch_size keys for this batch.
        batch_idx = idx[i:i + batch_size]
        # NOTE: per-sample standardization (std_normalize) is intentionally
        # not applied here; it was disabled because it may cause problems.
        X = [features[k] for k in batch_idx]
        y = [labels[k] for k in batch_idx]
        i += batch_size
        yield np.stack(X, axis=0), np.stack(y, axis=0)