-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
Copy pathisa_head.py
143 lines (122 loc) · 4.81 KB
/
isa_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-Attention Module used by the ISA head.

    Query and key features are taken from the same input, and the attended
    context is passed through an extra projection (``output_project``) that
    keeps the channel count at ``in_channels``.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
        # Layer configs shared by the base block and the output projection.
        layer_cfgs = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        super().__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            # Separate key/query projections and no downsampling of either.
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            # Presumably disables the base block's own output layer; the
            # ``output_project`` built below takes its place.
            with_out=False,
            **layer_cfgs)
        self.output_project = self.build_project(
            in_channels,
            in_channels,
            num_convs=1,
            use_conv_module=True,
            **layer_cfgs)

    def forward(self, x):
        """Run self-attention on ``x`` and project the attended context.

        Args:
            x: Feature map fed to the base block as both of its inputs
                (presumably shaped (N, in_channels, H, W) -- confirm).

        Returns:
            The result of ``self.output_project`` applied to the context
            returned by the base block.
        """
        return self.output_project(super().forward(x, x))
@MODELS.register_module()
class ISAHead(BaseDecodeHead):
    """Interlaced Sparse Self-Attention for Semantic Segmentation.

    This head is the implementation of `ISA
    <https://arxiv.org/abs/1907.12273>`_.

    The feature map is tiled into local groups of ``down_factor`` pixels.
    Attention is applied in two stages:

    1. Global relation: among pixels sharing the same offset inside their
       local group, i.e. one pixel from every local group.
    2. Local relation: among pixels inside the same local group.

    Args:
        isa_channels (int): The channels of ISA Module.
        down_factor (tuple[int] | int): The local group size of ISA along
            the H and W axes. A single int ``d`` is treated as ``(d, d)``.
            Default: (8, 8).
    """

    def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
        super().__init__(**kwargs)
        # Accept a single int for square local groups. Tuples/lists are
        # stored unchanged so existing configs behave exactly as before.
        if isinstance(down_factor, int):
            down_factor = (down_factor, down_factor)
        self.down_factor = down_factor
        # self.in_channels / self.channels / self.*_cfg are read after
        # super().__init__(**kwargs); presumably set by BaseDecodeHead.
        self.in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Attention across local groups (one pixel per group).
        self.global_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Attention within a single local group.
        self.local_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses attended features with the pre-attention features; they are
        # concatenated along channels, hence ``self.channels * 2`` inputs.
        self.out_conv = ConvModule(
            self.channels * 2,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Passed to ``self._transform_inputs``; presumably the
                list of backbone feature maps -- confirm against
                BaseDecodeHead.

        Returns:
            The output of ``self.cls_seg`` on the fused features (presumably
            per-class segmentation logits).
        """
        x_ = self._transform_inputs(inputs)
        x = self.in_conv(x_)
        # Pre-attention features, concatenated back in before out_conv.
        residual = x

        n, c, h, w = x.size()
        loc_h, loc_w = self.down_factor  # size of local group in H- and W-axes
        glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
        pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
        if pad_h > 0 or pad_w > 0:  # pad if the size is not divisible
            # F.pad order for the last two dims is (left, right, top, bottom);
            # the padding is split as evenly as possible on both sides.
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                       pad_h - pad_h // 2)
            x = F.pad(x, padding)

        # global relation
        x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
        # do permutation to gather global group
        x = x.permute(0, 3, 5, 1, 2, 4)  # (n, loc_h, loc_w, c, glb_h, glb_w)
        x = x.reshape(-1, c, glb_h, glb_w)
        # apply attention within each global group
        x = self.global_relation(x)  # (n * loc_h * loc_w, c, glb_h, glb_w)

        # local relation
        x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
        # do permutation to gather local group
        x = x.permute(0, 4, 5, 3, 1, 2)  # (n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.reshape(-1, c, loc_h, loc_w)
        # apply attention within each local group
        x = self.local_relation(x)  # (n * glb_h * glb_w, c, loc_h, loc_w)

        # permute each pixel back to its original position
        x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (n, c, glb_h, loc_h, glb_w, loc_w)
        x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
        if pad_h > 0 or pad_w > 0:  # remove padding
            # Crop with the same offsets used when padding, restoring (h, w).
            x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]

        x = self.out_conv(torch.cat([x, residual], dim=1))
        out = self.cls_seg(x)

        return out