diff --git a/test.py b/test.py index c99cd7e..6bed1f5 100644 --- a/test.py +++ b/test.py @@ -5,9 +5,10 @@ import numpy as np import torch import architecture as arch +from pathlib import Path model_path = sys.argv[1] # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth -device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu +device = torch.device('cpu') # if you want to run on CPU, change 'cuda' -> cpu # device = torch.device('cpu') test_img_folder = 'LR/*' @@ -19,9 +20,12 @@ for k, v in model.named_parameters(): v.requires_grad = False model = model.to(device) +model_used = Path('{:s}'.format(model_path)).stem # get model name to simple string for filename use print('Model path {:s}. \nTesting...'.format(model_path)) + +amend = 2 #File already exists, lets amend a 2 to the end of the file name idx = 0 for path in glob.glob(test_img_folder): idx += 1 @@ -37,4 +41,13 @@ output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy() output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) output = (output * 255.0).round() - cv2.imwrite('results/{:s}_rlt.png'.format(base), output) + + while os.path.exists('results/{:s}_%s_%s.png'.format(base) % (model_used, amend)): + amend += 1 + + if os.path.exists('results/{:s}_%s.png'.format(base) % model_used): + cv2.imwrite('results/{:s}_%s_%s.png'.format(base) % (model_used, amend), output) + + else: + cv2.imwrite('results/{:s}_%s.png'.format(base) % model_used, output) +