This repository has been archived by the owner on Apr 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathevaluate.py
66 lines (57 loc) · 2.19 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import logging
import os
import time

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

from model.inceptionv3 import AgenderNetInceptionV3
from model.mobilenetv2 import AgenderNetMobileNetV2
from model.ssrnet import AgenderSSRNet
from utils.image import align_one_face, resize_square_image
# Command-line interface for the evaluation script: which dataset CSV to
# evaluate and which network to run over it. Both options are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--db_name',
    required=True,
    help='name of dataset .csv file in data/db/ folder',
)
parser.add_argument(
    '--model',
    required=True,
    # main() selects the network/weights based on exactly these names.
    choices=['mobilenetv2', 'inceptionv3', 'ssrnet'],
    help="model name to be used",
)
def _load_model(model_name):
    """Construct the requested network and load its pretrained weights.

    Parameters
    ----------
    model_name : str
        One of ``'mobilenetv2'``, ``'inceptionv3'`` or ``'ssrnet'``
        (the values accepted by the ``--model`` option).

    Returns
    -------
    The model instance with weights loaded from ``model/weight/``.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names.
    """
    if model_name == 'mobilenetv2':
        model = AgenderNetMobileNetV2()
        model.load_weights('model/weight/mobilenetv2/model.10-3.8290-0.8965-6.9498.h5')
    elif model_name == 'inceptionv3':
        model = AgenderNetInceptionV3()
        model.load_weights('model/weight/inceptionv3/model.16-3.7887-0.9004-6.6744.h5')
    elif model_name == 'ssrnet':
        model = AgenderSSRNet(64, [3, 3, 3], 1.0, 1.0)
        model.load_weights('model/weight/ssrnet/model.37-7.3318-0.8643-7.1952.h5')
    else:
        # Explicit rather than silently falling back to ssrnet.
        raise ValueError('Unknown model name: {}'.format(model_name))
    return model


def _read_images(db_name, paths, size):
    """Read each image from ``<db_name>_aligned/`` and resize to ``size`` x ``size``.

    Parameters
    ----------
    db_name : str
        Dataset name; images are looked up under ``<db_name>_aligned/``.
    paths : iterable of str
        Image paths relative to that folder.
    size : int
        Target width and height passed to ``cv2.resize``.

    Returns
    -------
    numpy.ndarray
        The resized images stacked with ``np.array``.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read one of the images (``cv2.imread`` returns None).
    """
    images = []
    for path in tqdm(paths):
        filename = '{}_aligned/{}'.format(db_name, path)
        image = cv2.imread(filename)
        # cv2.imread signals failure by returning None instead of raising;
        # catch it here so the error names the offending file.
        if image is None:
            raise FileNotFoundError('Cannot read image {}'.format(filename))
        images.append(cv2.resize(image, (size, size)))
    return np.array(images)


def main():
    """Predict age and gender for every image of a dataset and save the results.

    Reads ``data/db/<db_name>.csv`` (its ``full_path`` column lists the
    images), loads them from ``<db_name>_aligned/``, runs the model chosen
    with ``--model`` and writes ``result/<db_name>.csv`` with the columns
    ``full_path``, ``age`` and ``gender``.
    """
    # Bind the logger locally: the module-level ``logger`` only exists when
    # the file runs as a script, so calling main() after an import would
    # otherwise raise NameError.
    logger = logging.getLogger(__name__)

    args = parser.parse_args()
    db_name = args.db_name
    data = pd.read_csv('data/db/{}.csv'.format(db_name))

    logger.info('Load model and weight')
    model = _load_model(args.model)

    logger.info('Read image')
    images = _read_images(db_name, data.full_path.values, model.input_size)
    images = model.prep_image(images)

    logger.info('Predict data')
    start = time.time()
    prediction = model.predict(images)
    pred_gender, pred_age = model.decode_prediction(prediction)
    elapsed = time.time() - start
    logger.info('Time elapsed {:.2f} sec'.format(elapsed))

    result = pd.DataFrame()
    result['full_path'] = data['full_path']
    result['age'] = pred_age
    result['gender'] = pred_gender
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs('result', exist_ok=True)
    result.to_csv('result/{}.csv'.format(db_name), index=False)
if __name__ == '__main__':
    # Module-level logger that main() writes its progress messages to.
    logger = logging.getLogger(__name__)
    # Emit INFO-level progress messages when run as a script.
    logging.basicConfig(level=logging.INFO)
    main()