-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path06_export_preprocessing_onnx.py
50 lines (41 loc) · 1.3 KB
/
06_export_preprocessing_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List
import onnx
import torch
import torch.nn as nn
from onnxsim import simplify
class Preprocess(nn.Module):
    """Resize and normalize an image batch, exportable as an ONNX graph.

    ``forward`` performs three steps:

    1. Resizes the spatial dimensions of ``x`` to ``input_shape[2:]`` with
       ``torch.nn.functional.interpolate``.
    2. Divides by ``255.0``.
    3. Normalizes per channel with ``(x - mean) / std``.

    ``mean`` and ``std`` are shaped ``(1, 3, 1, 1)``, so the input is
    expected to be a 4-D ``(N, 3, H, W)`` batch. The input is presumably
    RGB pixel values in ``[0, 255]`` (implied by the ``/ 255.0``); confirm
    with callers.

    Args:
        input_shape: Target shape as ``[N, C, H, W]``. Only ``H`` and ``W``
            (``input_shape[2:]``) are used, as the resize target.
        mode: Interpolation mode forwarded to ``interpolate``. Defaults to
            ``"nearest"``, which is the mode ``interpolate`` uses when none
            is given.
    """

    def __init__(self, input_shape: List[int], mode: str = "nearest"):
        super().__init__()
        self.input_shape = tuple(input_shape)
        self.mode = mode
        # NOTE(review): these look like rounded CLIP normalization constants;
        # confirm they match the downstream model's training preprocessing.
        # They are registered as buffers so that .to(device) / .half() /
        # .double() move them together with the module. persistent=False keeps
        # state_dict() identical to the old plain-attribute version.
        self.register_buffer(
            "mean",
            torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized to ``input_shape[2:]``, scaled to 0-1, then normalized."""
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
            mode=self.mode,
        )
        x = x / 255.0
        return (x - self.mean) / self.std
if __name__ == "__main__":
    # Export the preprocessing module to ONNX, then simplify the graph in place.
    # The graph accepts any batch size and any H x W. It always resizes the
    # spatial dimensions to input_shape[2:].
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"

    model = Preprocess(input_shape=input_shape)
    # Tracing only needs a tensor of the right rank and dtype; its values are
    # irrelevant to the exported graph.
    dummy_input = torch.randn(input_shape)

    torch.onnx.export(
        model,
        dummy_input,
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
            # Mark the output batch dimension dynamic explicitly as well.
            # Otherwise its declared shape depends on the exporter's shape
            # inference and may be fixed to the traced batch of 1.
            # Output H/W are the fixed resize target, so they stay static.
            "output_preprocessing": {
                0: "batch_size",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    # simplify() returns (model, check). check is False when the simplified
    # model could not be validated against the original. Refuse to overwrite
    # the exported file in that case.
    model_simplified, check_ok = simplify(model_onnx)
    if not check_ok:
        raise RuntimeError(
            f"onnxsim could not validate the simplified model for {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)