Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] : add flow1d correlation and correlation lookup #213

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions mmflow/models/utils/correlation1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch
from mmcv.runner import BaseModule
from torch import Tensor


class Correlation1D(BaseModule):
"""Correlation1D Module.

The neck of Flow1D, which calculates correlation tensor of input features
with the method of 3D cost volume.
"""

def __init__(self):
super().__init__()

def forward(
self,
feat1: Tensor,
feat2_x: Tensor,
feat2_y: Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
feat2_y: Tensor,
feat2_y: Tensor

) -> Sequence[Tensor]:
"""Forward function for Correlation1D.

Args:
feat1 (Tensor): The feature from first input image.
feat2_x (Tensor): The 1D cross attention feature2 on x direction.
feat2_y (Tensor): The 1D cross attention feature2 on y direction.

Returns:
Sequence[Tensor]: Correlation list, include x correlation
and y correlation.
"""
corr_x = self.corr_x(feat1, feat2_x)
corr_y = self.corr_y(feat1, feat2_y)
corr = [corr_x, corr_y]
return corr

@staticmethod
def corr_x(feature1: Tensor, feature2: Tensor) -> Tensor:
"""corr_x function for Correlation1D.

Args:
feature1 (Tensor): Input feature1.
feature2 (Tensor): Input feature2.

Returns:
Tensor: x correlation.
"""
b, c, h, w = feature1.shape # [B, C, H, W]
scale_factor = c**0.5

# x direction, corr shape is [B, H, W, W]
feature1 = feature1.permute(0, 2, 3, 1)
feature2 = feature2.permute(0, 2, 1, 3)
corr = torch.matmul(feature1, feature2)

# reshape to [B*H*W, 1, 1, W]
corr = corr.unsqueeze(3).unsqueeze(3)
corr = corr / scale_factor
corr = corr.flatten(0, 2)
MeowZheng marked this conversation as resolved.
Show resolved Hide resolved

return corr

@staticmethod
def corr_y(feature1: Tensor, feature2: Tensor) -> Tensor:
"""corr_y function for Correlation1D.

Args:
feature1 (Tensor): Input feature1.
feature2 (Tensor): Input feature2.

Returns:
Tensor: y correlation.
"""
b, c, h, w = feature1.shape # [B, C, H, W]
scale_factor = c**0.5

# y direction, corr shape is [B, W, H, H]
feature1 = feature1.permute(0, 3, 2, 1)
feature2 = feature2.permute(0, 3, 1, 2)
corr = torch.matmul(feature1, feature2)

# reshape to [B*H*W, 1, H, 1]
corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1)
corr = corr / scale_factor
corr = corr.flatten(0, 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1)
corr = corr / scale_factor
corr = corr.flatten(0, 2)
corr = corr.view(-1, 1, h, 1) / scale_factor


return corr
85 changes: 85 additions & 0 deletions mmflow/ops/corr_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,88 @@ def forward(self, corr_pyramid: Sequence[Tensor], flow: Tensor) -> Tensor:

out = torch.cat(out_corr_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()


@OPERATORS.register_module()
class CorrLookupFlow1D(nn.Module):
"""Correlation lookup operator for Flow1D.

This operator is used in `Flow1D<https://arxiv.org/pdf/2104.13918.pdf>`_

Args:
radius (int): the radius of the local neighborhood of the pixels.
Default to 4.
mode (str): interpolation mode to calculate output values 'bilinear'
| 'nearest' | 'bicubic'. Default: 'bilinear' Note: mode='bicubic'
supports only 4-D input.
padding_mode (str): padding mode for outside grid values 'zeros' |
'border' | 'reflection'. Default: 'zeros'
align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input’s corner
pixels. If set to False, they are instead considered as referring
to the corner points of the input’s corner pixels, making the
sampling more resolution agnostic. Default to True.
"""

def __init__(self,
radius: int = 4,
mode: str = 'bilinear',
padding_mode: str = 'zeros',
align_corners: bool = True) -> None:
super().__init__()
self.r = radius
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners

def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor:
"""Forward function of Correlation lookup for Flow1D.

Args:
corr (Sequence[Tensor]): Correlation on x and y direction.
flow (Tensor): Current estimated optical flow.

Returns:
Tensor: lookup cost volume on the correlation of x and y directions
concatenate together.
"""
corr_x = corr[0]
corr_y = corr[1]
B, _, H, W = flow.shape

# reshape flow to [B, H, W, 2]
flow = flow.permute(0, 2, 3, 1)
coords_x = flow[:, :, :, 0]
coords_y = flow[:, :, :, 1]
coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1)
coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1)
centroid_x = coords_x.view(B * H * W, 1, 1, 2)
centroid_y = coords_y.view(B * H * W, 1, 1, 2)

dx = torch.linspace(
-self.r, self.r, 2 * self.r + 1, device=flow.device)
dy = torch.linspace(
-self.r, self.r, 2 * self.r + 1, device=flow.device)

delta_x = torch.stack((dx, torch.zeros_like(dx)), dim=-1)
delta_y = torch.stack((torch.zeros_like(dy), dy), dim=-1)
# [1, 2r+1, 1, 2]
delta_y = delta_y.view(1, 2 * self.r + 1, 1, 2)

coords_x = centroid_x + delta_x
coords_y = centroid_y + delta_y

corr_x = bilinear_sample(corr_x, coords_x, self.mode,
self.padding_mode, self.align_corners)
corr_y = bilinear_sample(corr_y, coords_y, self.mode,
self.padding_mode, self.align_corners)

# shape is [B, 2r+1, H, W]
corr_x = corr_x.view(B, H, W, -1)
corr_x = corr_x.permute(0, 3, 1, 2).contiguous()
corr_y = corr_y.view(B, H, W, -1)
corr_y = corr_y.permute(0, 3, 1, 2).contiguous()

correlation = torch.cat((corr_x, corr_y), dim=1)

return correlation
211 changes: 211 additions & 0 deletions tests/test_models/test_utils/test_correlation1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

from mmflow.models.utils.correlation1d import Correlation1D

_feat1 = Tensor(
[[[[1.0154, 0.4896, 1.8628, 0.0762, 0.2545, -0.1868, 0.5853, 1.6154],
[-0.4458, -1.3631, -0.6748, 0.2643, 0.8796, 1.2195, -0.9295, 0.3636],
[0.2345, 0.1408, -0.2794, -2.2829, -1.8497, -0.4348, -0.1259, 1.2991],
[0.9833, 0.5806, 0.0429, -1.5982, 1.1363, 0.0071, -1.5662, 0.0415]],
[[-2.5624, 0.4736, 0.3118, -0.1595, 0.4542, -1.2495, -0.3464, -1.1194],
[0.1017, 1.1922, -1.2911, 0.6752, 1.4180, -0.3162, -0.3809, 1.4444],
[-0.8802, 1.5789, -0.7804, -0.2817, 0.3465, -0.6741, 0.1570, 0.1059],
[-0.8849, 0.3025, -0.3609, 0.7738, 0.8476, -0.2813, 1.5131, -1.4178]],
[[0.2065, -0.8124, -0.6505, 1.6508, 1.7852, 1.2732, 0.4985, -0.5486],
[2.7083, 1.0688, 0.4090, -0.1851, 1.0733, 1.1038, -1.4032, 0.2552],
[1.5166, -0.6669, 1.3872, -0.4971, 1.9420, -2.2243, -2.3078, -0.4577],
[-1.7597, 0.7735, 1.1435, -0.5766, 1.0973, -0.1990, -1.1990, 0.1093]],
[[0.2446, 1.8493, 1.7110, 1.1204, -1.7352, -1.3811, -0.2492, 0.8741],
[0.3271, 0.2713, -1.3248, -0.2370, 0.4934, -0.8729, -0.3618, 0.5313],
[0.8359, -0.2329, 0.4883, 0.1030, 0.2581, 0.3148, -0.9930, 0.2271],
[-1.1038, 0.0708, -0.4958, -1.1129, -0.9431, -0.0880, 1.0499,
-0.6881]]]])
_feat2 = Tensor(
[[[[1.3175, 1.4551, 1.6624, -0.5219, 0.3938, -1.4649, 0.9400, -0.4180],
[0.4486, 0.0388, -0.6881, -1.4353, 1.8669, 0.6907, 0.0128, 0.2979],
[1.7176, 0.3644, -1.2154, -1.9436, 0.9357, 2.0734, -0.3146, 0.1123],
[-0.7050, 1.4828, 0.8406, 0.3374, 0.7549, 0.4404, -0.1620, 0.3539]],
[[1.1737, -0.9930, -0.6959, -1.7765, -0.4785, -0.5701, -0.6154, 0.8447],
[2.2322, 1.2820, -0.9384, -0.2065, 0.1662, 0.9703, 0.1947, -0.7589],
[0.9334, -0.5888, 0.2904, -1.1869, -1.3860, -1.1149, -0.4794, -0.4440],
[1.0862, -1.1460, 0.9998, -1.3857, 1.0615, -0.1334, 1.4889, -0.2771]],
[[0.4017, 0.4662, 0.6031, 2.2982, -1.3094, -0.7295, -0.2682, 0.3263],
[-0.2803, 1.5200, -0.5896, 0.5558, -0.6111, -0.5191, -0.0100, 0.4099],
[0.3736, -1.0845, -0.9815, 0.9264, 0.5722, -2.2061, 0.9850, -0.2834],
[0.2425, 1.4829, -0.8054, 1.1259, -1.0513, 1.3195, -1.7388, 0.3673]],
[[0.0612, 0.3328, 0.1373, -1.9487, 0.8354, -0.7799, -0.4399, 1.7067],
[1.1250, -0.8651, -0.3540, 0.7884, 1.2341, -1.0060, 1.8890, 0.9911],
[0.9935, 0.3770, 1.4380, 0.0396, 0.2286, 2.2238, 0.1141, 0.0866],
[-0.1054, -0.4454, 0.1032, -1.1747, 0.5838, 1.2229, -0.2493, 1.0715]]]])
b, c, h, w = _feat1.size()


def test_correlation():
gt_corr_x = Tensor([[[[
-7.8589e-01, 2.0998e+00, 1.8146e+00, 2.0100e+00, 7.7996e-01,
-1.8402e-01, 1.1842e+00, -1.0520e+00
]]],
[[[
4.9387e-01, 2.3942e-01, 1.2414e-01, -3.2838e+00,
1.2874e+00, -9.1842e-01, -2.1343e-01, 1.5433e+00
]]],
[[[
1.3318e+00, 1.3336e+00, 1.3612e+00, -3.1777e+00,
1.4328e+00, -1.8832e+00, 4.9047e-01, 1.0963e+00
]]],
[[[
3.2244e-01, 7.0587e-01, 6.9355e-01, 9.2706e-01,
-5.5962e-01, -1.0494e+00, -3.8291e-01, 1.1421e+00
]]],
[[[
7.3966e-01, 8.7044e-02, 4.7271e-01, 3.2722e+00,
-1.9521e+00, -2.9039e-01, 1.2212e-01, -1.0508e+00
]]],
[[[
-6.4286e-01, 5.5144e-01, 5.6862e-01, 3.9673e+00,
-1.1483e+00, 5.6715e-01, 4.2971e-01, -1.4595e+00
]]],
[[[
2.7478e-01, 6.7256e-01, 7.4025e-01, 9.7059e-01,
-2.3234e-01, -4.1461e-01, 3.6964e-01, -3.9995e-01
]]],
[[[
3.2379e-01, 1.7486e+00, 1.6268e+00, -9.0931e-01,
1.3102e+00, -1.0049e+00, 9.8499e-01, -1.5399e-01
]]],
[[[
-1.8206e-01, 1.9734e+00, -7.5064e-01, 1.1910e+00,
-1.0334e+00, -9.7209e-01, 3.0245e-01, 6.1217e-01
]]],
[[[
1.0277e+00, 1.4327e+00, -4.5351e-01, 1.2591e+00,
-1.3325e+00, -3.0622e-01, 3.5824e-01, -3.0192e-01
]]],
[[[
-2.3949e+00, 4.3196e-02, 9.5187e-01, 2.0900e-01,
-1.6796e+00, -2.9920e-01, -1.3833e+00, -1.8328e-01
]]],
[[[
7.0550e-01, 3.9977e-01, -3.1122e-01, -4.0425e-01,
2.1314e-01, 5.8610e-01, -1.5550e-01, -3.7222e-01
]]],
[[[
1.9070e+00, 1.5283e+00, -1.3717e+00, -2.8489e-01,
9.1540e-01, 4.6496e-01, 6.0432e-01, 5.7434e-02
]]],
[[[
-7.2508e-01, 1.0374e+00, -4.4210e-01, -8.7988e-01,
2.3618e-01, 4.2033e-01, -8.5295e-01, 9.5285e-02
]]],
[[[
-6.4046e-01, -1.1721e+00, 9.7621e-01, 1.7381e-01,
-6.9380e-01, 4.0390e-02, -3.7773e-01, -4.6079e-01
]]],
[[[
1.9567e+00, 8.9705e-01, -9.7208e-01, -1.2971e-01,
7.0929e-01, 4.9284e-01, 6.4348e-01, -1.7833e-01
]]],
[[[
4.8913e-01, -3.6295e-01, -4.1357e-01, 1.0135e+00,
1.2491e+00, -9.6748e-03, 9.6871e-01, 2.9864e-02
]]],
[[[
6.1752e-01, -1.2145e-01, 3.0352e-01, -1.3873e+00,
-1.2457e+00, -2.5753e-01, -7.4235e-01, -2.5819e-01
]]],
[[[
-1.0247e-01, -4.8132e-01, -2.7320e-01, 1.3869e+00,
8.6279e-01, -8.4183e-01, 9.4207e-01, -1.7862e-02
]]],
[[[
-2.1337e+00, -4.4044e-02, 1.6644e+00, 2.1575e+00,
-1.0033e+00, -1.5468e+00, 1.8768e-01, 9.2515e-03
]]],
[[[
-9.3583e-01, -1.4434e+00, 4.0691e-01, 2.4966e+00,
-5.2040e-01, -3.9659e+00, 1.1791e+00, -4.4479e-01
]]],
[[[
-9.4713e-01, 1.3847e+00, 1.4843e+00, -2.0148e-01,
-3.3666e-01, 2.7286e+00, -8.4753e-01, 4.5405e-01
]]],
[[[
-9.5922e-01, 9.9506e-01, 5.1789e-01, -1.0595e+00,
-9.4146e-01, 1.2235e+00, -1.2111e+00, 2.4210e-01
]]],
[[[
1.1924e+00, 4.9652e-01, -3.8619e-01, -1.5328e+00,
4.2940e-01, 2.0451e+00, -4.4219e-01, 1.2412e-01
]]],
[[[
-9.8240e-01, 1.7715e-01, 6.2259e-01, 4.3668e-01,
5.0427e-01, -1.5603e+00, 9.2906e-01, -6.1793e-01
]]],
[[[
4.9682e-02, 8.1487e-01, 8.7411e-02, 2.8222e-01,
-6.2244e-03, 6.6128e-01, -5.0314e-01, 2.4081e-01
]]],
[[[
-4.6349e-02, 1.1969e+00, -6.4845e-01, 1.1922e+00,
-9.2116e-01, 4.8479e-01, -1.2045e+00, 1.9728e-03
]]],
[[[
9.7235e-01, -1.8080e+00, -1.1013e-01, -4.7668e-01,
-2.1431e-01, -1.4644e+00, 1.3455e+00, -1.0921e+00
]]],
[[[
2.4253e-01, 1.3804e+00, 4.1076e-01, 7.7609e-01,
2.6673e-02, 3.4096e-01, -2.9748e-01, -2.2011e-01
]]],
[[[
-1.7477e-01, 3.8498e-02, -6.2041e-02, 1.3576e-01,
-6.7703e-02, -1.6477e-01, -2.6009e-02, -4.3462e-02
]]],
[[[
1.1731e+00, -3.1510e+00, 6.3514e-01, -2.6042e+00,
1.1486e+00, -5.9488e-01, 2.1648e+00, -1.4449e-01
]]],
[[[
-7.3512e-01, 1.0774e+00, -7.7084e-01, 1.4550e+00,
-9.9514e-01, -2.4492e-01, -1.0681e+00, -1.4480e-01
]]]])
gt_corr_y = Tensor([[[[-0.7859], [-2.5235], [-0.1638], [-1.7374]]],
[[[0.2394], [-1.1043], [0.7389], [-0.9226]]],
[[[1.3612], [-0.8983], [0.4627], [1.2890]]],
[[[0.9271], [0.8622], [0.8074], [0.3946]]],
[[[-1.9521], [-1.3409], [0.1167], [-1.1078]]],
[[[0.5672], [-0.3065], [-2.4372], [0.0377]]],
[[[0.3696], [-0.2678], [0.2223], [-0.7076]]],
[[[-0.1540], [0.9861], [0.4548], [0.8085]]],
[[[0.3200], [-0.1821], [0.3330], [0.5235]]],
[[[-1.2894], [1.4327], [-1.1278], [-0.9617]]],
[[[-0.0793], [0.9519], [-0.9306], [-1.1621]]],
[[[-0.6505], [-0.4043], [-0.7480], [-0.3882]]],
[[[-0.6627], [0.9154], [-0.2077], [0.6645]]],
[[[-0.8653], [0.4203], [-0.7476], [0.4841]]],
[[[-0.0519], [-0.3777], [-0.4742], [1.0568]]],
[[[1.0291], [-0.1783], [-0.3134], [0.1957]]],
[[[-0.0319], [-0.6722], [0.4891], [-0.4209]]],
[[[-0.8757], [0.6087], [-0.1214], [-1.2429]]],
[[[0.4911], [-0.0331], [-0.2732], [-1.0410]]],
[[[0.1744], [1.5699], [2.1575], [-0.5303]]],
[[[-1.6107], [-2.1319], [-0.5204], [-1.4597]]],
[[[1.1992], [-0.0582], [2.7286], [-1.3258]]],
[[[0.4204], [-0.9119], [-1.2111], [2.2573]]],
[[[-0.1077], [0.1721], [0.1241], [0.2528]]],
[[[-0.2588], [-1.1413], [-0.4455], [-0.9824]]],
[[[0.4643], [0.7624], [-0.3894], [0.8149]]],
[[[0.4720], [-0.0948], [-0.9961], [-0.6485]]],
[[[0.1515], [0.4681], [0.8048], [-0.4767]]],
[[[-1.0914], [0.2139], [0.1504], [0.0267]]],
[[[0.1819], [-0.0381], [0.2858], [-0.1648]]],
[[[-1.2718], [1.1349], [-0.6469], [2.1648]]],
[[[-1.1768], [0.2256], [0.2718], [-0.1448]]]])
corr = Correlation1D()
correlation = corr(_feat1, _feat2, _feat2)
assert correlation[0].size() == (b * h * w, 1, 1, w)
assert correlation[1].size() == (b * h * w, 1, h, 1)
assert torch.allclose(correlation[0], gt_corr_x, atol=1e-4)
assert torch.allclose(correlation[1], gt_corr_y, atol=1e-4)
Loading