-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpred_and_plot_image.py
68 lines (56 loc) · 2.67 KB
/
pred_and_plot_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#Libraries
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from typing import List, Tuple
# Setup target device: prefer the GPU when CUDA is usable, otherwise use the CPU.
# The resulting string is the default `device` argument of the helper below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# 1. Function
def pred_and_plot_image(model: torch.nn.Module,
                        image_path: str,
                        class_names: List[str],
                        image_size: Tuple[int, int] = (224, 224),
                        transform: transforms.Compose = None,
                        device: torch.device = device):
    """
    Predicts the class label and plots an image with the predicted label and probability.

    Args:
        model (torch.nn.Module): PyTorch model for image classification.
        image_path (str): Path to the input image file.
        class_names (List[str]): List of class names for mapping predicted labels.
        image_size (Tuple[int, int]): Size to which the input image is resized by the
            default transform (default is (224, 224)). Ignored when `transform` is given.
        transform (transforms.Compose): Image transformation pipeline applied to the PIL
            image (default is None, which selects Resize -> ToTensor -> Normalize with the
            mean/std values hard-coded below).
        device (torch.device): Device to perform the inference on. The default is the
            module-level `device` value captured when this function was defined.

    Returns:
        None. The result is shown as a matplotlib figure whose title holds the predicted
        class name and its probability.

    Raises:
        IndexError: If the predicted class index is outside `class_names`.

    Side effects:
        The model is moved to `device` and left in evaluation mode.
    """
    # 2. Open image
    # Image.open() is lazy and keeps the file open; copy the pixel data inside a
    # context manager so the file handle is released before returning.
    with Image.open(image_path) as opened_image:
        img = opened_image.copy()

    # 3. Create transformation for image (if one doesn't exist)
    if transform is not None:
        image_transform = transform
        # A caller-supplied pipeline receives the image in its original mode.
        model_input = img
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # The default Normalize uses 3-channel statistics, so grayscale, RGBA or
        # palette images must be converted to RGB first or normalization fails.
        model_input = img.convert("RGB")

    ### Predict on image ###
    # 4. Make sure the model is on the target device
    model.to(device)

    # 5. Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # 6. Transform and add an extra dimension to image
        #    (model requires samples in [batch_size, color_channels, height, width])
        transformed_image = image_transform(model_input).unsqueeze(dim=0)

        # 7. Make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(transformed_image.to(device))

        # 8. Convert logits -> prediction probabilities
        #    (using torch.softmax() for multi-class classification)
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

        # 9. Convert prediction probabilities -> prediction label.
        #    .item() yields plain Python numbers, so the list lookup and the string
        #    formatting below do not rely on implicit tensor conversions.
        pred_index = int(torch.argmax(target_image_pred_probs, dim=1).item())
        pred_prob = target_image_pred_probs.max().item()

    # 10. Plot image with predicted label and probability
    plt.figure()
    plt.imshow(img)
    plt.title(f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}")
    plt.axis(False)