-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_cyclegan.py
83 lines (66 loc) · 2.83 KB
/
test_cyclegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import click
import os
import torch
import random
from datetime import datetime
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.autograd import Variable
from torchvision import transforms
import numpy as np
from utils import init_device_seed
from datasets import TypesDataset
from model_cyclegan import CycleGANGenerator, CycleGANDiscriminator
def _collect_input_files(image_path):
    """Return the list of files to translate for ``image_path``.

    A directory yields every regular file directly inside it, sorted by path
    (sub-directories are skipped so ``Image.open`` never sees a directory).
    Anything else is treated as a single input file.
    """
    if not os.path.isdir(image_path):
        return [image_path]
    candidates = (os.path.join(image_path, name) for name in os.listdir(image_path))
    return sorted(path for path in candidates if os.path.isfile(path))


def _make_output_dir(image_path):
    """Create and return the directory that translated images are saved into.

    Directory input gets a fresh timestamped folder under ``./result`` so
    repeated runs do not overwrite each other; a single input file is written
    directly into ``./result``.
    """
    if os.path.isdir(image_path):
        output_dir = './result/{}'.format(datetime.now().strftime('%Y-%m-%d %H_%M_%S'))
    else:
        output_dir = './result'
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


@click.command()
@click.option('--dataset_type', default='summer2winter_yosemite')
@click.option('--image_path', default='./data/summer2winter_yosemite/testA')
@click.option('--model_type', default='x2y')
@click.option('--is_crop', is_flag=True, default=False)
@click.option('--cuda_visible', default='0')
def test(dataset_type, image_path, model_type, is_crop, cuda_visible):
    """Translate images with a trained CycleGAN generator and save them as JPEGs.

    Args:
        dataset_type: suffix of the checkpoint file ``./model/cyclegan_<dataset_type>``.
        image_path: a single image file, or a directory whose files are all
            translated.
        model_type: checkpoint key prefix; generator weights are read from
            ``checkpoint[model_type + '_state_dict']`` (default ``'x2y'``).
        is_crop: accepted on the command line (the signature always required
            it) but not used by this function.
        cuda_visible: passed to ``init_device_seed`` together with seed 1234;
            presumably selects the visible CUDA device(s) — confirm in utils.
    """
    device = init_device_seed(1234, cuda_visible)
    os.makedirs('./result', exist_ok=True)

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load('./model/cyclegan_' + dataset_type, map_location=device)
    generator = CycleGANGenerator().to(device)
    generator.load_state_dict(checkpoint[model_type + '_state_dict'])
    generator.eval()

    # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # Exact inverse of the above: (x + 1) / 2 maps [-1, 1] back to [0, 1].
    to_pil = transforms.Compose([
        transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
        transforms.ToPILImage(),
    ])

    files_list = _collect_input_files(image_path)
    # Defined for both directory and single-file input (previously NameError for a file).
    output_dir = _make_output_dir(image_path)
    total = len(files_list)

    for idx, file_path in enumerate(files_list, start=1):
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        print('\r{}/{} {}'.format(idx, total, file_name), end=' ')

        try:
            with Image.open(file_path) as img:
                # Force 3 channels so grayscale/RGBA/palette inputs match the
                # 3-channel Normalize; convert() returns a new loaded image,
                # so it stays valid after the file is closed.
                rgb = img.convert('RGB')
        except OSError as err:
            # Skip files PIL cannot decode (e.g. stray non-image files in a folder).
            print('\nskipping {}: {}'.format(file_path, err))
            continue

        image = to_tensor(rgb).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only: no autograd graph needed
            output = generator(image)

        # Clamp so out-of-range values cannot wrap when ToPILImage casts to uint8.
        output = output.cpu()[0].clamp(-1, 1)
        to_pil(output).save(os.path.join(output_dir, file_name + '.jpg'))

    # Terminate the '\r' progress line.
    print()
if __name__ == '__main__':
    # Script entry point: the click command parses its options from sys.argv.
    test()