From 358594b8cdb7d04c5f0090e15cd2011cce799caf Mon Sep 17 00:00:00 2001 From: Fc-idris <23311590278@qq.com> Date: Mon, 8 Aug 2022 20:47:17 +0800 Subject: [PATCH 1/5] feat: add flow1d correlation and correlation lookup --- mmflow/models/utils/correlation1d.py | 91 ++++++++ mmflow/ops/corr_lookup.py | 85 +++++++ .../test_utils/test_correlation1d.py | 211 ++++++++++++++++++ tests/test_op/test_corr_lookup.py | 128 +++++++++++ 4 files changed, 515 insertions(+) create mode 100644 mmflow/models/utils/correlation1d.py create mode 100644 tests/test_models/test_utils/test_correlation1d.py diff --git a/mmflow/models/utils/correlation1d.py b/mmflow/models/utils/correlation1d.py new file mode 100644 index 00000000..5e444b31 --- /dev/null +++ b/mmflow/models/utils/correlation1d.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +from mmcv.runner import BaseModule +from torch import Tensor + + +class Correlation1D(BaseModule): + """Correlation1D Module. + + The neck of Flow1D, which calculates correlation tensor of input features + with the method of 3D cost volume. + """ + + def __init__(self): + super().__init__() + + def forward( + self, + feat1: Tensor, + feat2_x: Tensor, + feat2_y: Tensor, + ) -> Sequence[Tensor]: + """Forward function for Correlation1D. + + Args: + feat1 (Tensor): The feature from first input image. + feat2_x (Tensor): The 1D cross attention feature2 on x direction. + feat2_y (Tensor): The 1D cross attention feature2 on y direction. + + Returns: + Sequence[Tensor]: Correlation list, include x correlation + and y correlation. + """ + corr_x = self.corr_x(feat1, feat2_x) + corr_y = self.corr_y(feat1, feat2_y) + corr = [corr_x, corr_y] + return corr + + @staticmethod + def corr_x(feature1: Tensor, feature2: Tensor) -> Tensor: + """corr_x function for Correlation1D. + + Args: + feature1 (Tensor): Input feature1. + feature2 (Tensor): Input feature2. + + Returns: + Tensor: x correlation. + """ + b, c, h, w = feature1.shape # [B, C, H, W] + scale_factor = c**0.5 + + # x direction, corr shape is [B, H, W, W] + feature1 = feature1.permute(0, 2, 3, 1) + feature2 = feature2.permute(0, 2, 1, 3) + corr = torch.matmul(feature1, feature2) + + # reshape to [B*H*W, 1, 1, W] + corr = corr.unsqueeze(3).unsqueeze(3) + corr = corr / scale_factor + corr = corr.flatten(0, 2) + + return corr + + @staticmethod + def corr_y(feature1: Tensor, feature2: Tensor) -> Tensor: + """corr_y function for Correlation1D. + + Args: + feature1 (Tensor): Input feature1. + feature2 (Tensor): Input feature2. + + Returns: + Tensor: y correlation. + """ + b, c, h, w = feature1.shape # [B, C, H, W] + scale_factor = c**0.5 + + # y direction, corr shape is [B, W, H, H] + feature1 = feature1.permute(0, 3, 2, 1) + feature2 = feature2.permute(0, 3, 1, 2) + corr = torch.matmul(feature1, feature2) + + # reshape to [B*H*W, 1, H, 1] + corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1) + corr = corr / scale_factor + corr = corr.flatten(0, 2) + + return corr diff --git a/mmflow/ops/corr_lookup.py b/mmflow/ops/corr_lookup.py index db8d375b..c1e68ca4 100644 --- a/mmflow/ops/corr_lookup.py +++ b/mmflow/ops/corr_lookup.py @@ -135,3 +135,88 @@ def forward(self, corr_pyramid: Sequence[Tensor], flow: Tensor) -> Tensor: out = torch.cat(out_corr_pyramid, dim=-1) return out.permute(0, 3, 1, 2).contiguous().float() + + +@OPERATORS.register_module() +class CorrLookupFlow1D(nn.Module): + """Correlation lookup operator for Flow1D. 
+ + This operator is used in `Flow1D`_ + + Args: + radius (int): the radius of the local neighborhood of the pixels. + Default to 4. + mode (str): interpolation mode to calculate output values 'bilinear' + | 'nearest' | 'bicubic'. Default: 'bilinear' Note: mode='bicubic' + supports only 4-D input. + padding_mode (str): padding mode for outside grid values 'zeros' | + 'border' | 'reflection'. Default: 'zeros' + align_corners (bool): If set to True, the extrema (-1 and 1) are + considered as referring to the center points of the input’s corner + pixels. If set to False, they are instead considered as referring + to the corner points of the input’s corner pixels, making the + sampling more resolution agnostic. Default to True. + """ + + def __init__(self, + radius: int = 4, + mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = True) -> None: + super().__init__() + self.r = radius + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + + def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor: + """Forward function of Correlation lookup for Flow1D. + + Args: + corr (Sequence[Tensor]): Correlation on x and y direction. + flow (Tensor): Current estimated optical flow. + + Returns: + Tensor: lookup cost volume on the correlation of x and y directions + concatenate together. + """ + corr_x = corr[0] + corr_y = corr[1] + B, _, H, W = flow.shape + + # reshape flow to [B, H, W, 2] + flow = flow.permute(0, 2, 3, 1) + coords_x = flow[:, :, :, 0] + coords_y = flow[:, :, :, 1] + coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1) + coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1) + centroid_x = coords_x.view(B * H * W, 1, 1, 2) + centroid_y = coords_y.view(B * H * W, 1, 1, 2) + + dx = torch.linspace( + -self.r, self.r, 2 * self.r + 1, device=flow.device) + dy = torch.linspace( + -self.r, self.r, 2 * self.r + 1, device=flow.device) + + delta_x = torch.stack((dx, torch.zeros_like(dx)), dim=-1) + delta_y = torch.stack((torch.zeros_like(dy), dy), dim=-1) + # [1, 2r+1, 1, 2] + delta_y = delta_y.view(1, 2 * self.r + 1, 1, 2) + + coords_x = centroid_x + delta_x + coords_y = centroid_y + delta_y + + corr_x = bilinear_sample(corr_x, coords_x, self.mode, + self.padding_mode, self.align_corners) + corr_y = bilinear_sample(corr_y, coords_y, self.mode, + self.padding_mode, self.align_corners) + + # shape is [B, 2r+1, H, W] + corr_x = corr_x.view(B, H, W, -1) + corr_x = corr_x.permute(0, 3, 1, 2).contiguous() + corr_y = corr_y.view(B, H, W, -1) + corr_y = corr_y.permute(0, 3, 1, 2).contiguous() + + correlation = torch.cat((corr_x, corr_y), dim=1) + + return correlation diff --git a/tests/test_models/test_utils/test_correlation1d.py b/tests/test_models/test_utils/test_correlation1d.py new file mode 100644 index 00000000..fc6a669f --- /dev/null +++ b/tests/test_models/test_utils/test_correlation1d.py @@ -0,0 +1,211 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
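+# Note: the test below feeds fixed 1x4x4x8 features through Correlation1D and
+# checks the results against precomputed values; corr_x is expected with shape
+# [B*H*W, 1, 1, W] and corr_y with shape [B*H*W, 1, H, 1], both scaled by
+# 1/sqrt(C).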
+import torch +from torch import Tensor + +from mmflow.models.utils.correlation1d import Correlation1D + +_feat1 = Tensor( + [[[[1.0154, 0.4896, 1.8628, 0.0762, 0.2545, -0.1868, 0.5853, 1.6154], + [-0.4458, -1.3631, -0.6748, 0.2643, 0.8796, 1.2195, -0.9295, 0.3636], + [0.2345, 0.1408, -0.2794, -2.2829, -1.8497, -0.4348, -0.1259, 1.2991], + [0.9833, 0.5806, 0.0429, -1.5982, 1.1363, 0.0071, -1.5662, 0.0415]], + [[-2.5624, 0.4736, 0.3118, -0.1595, 0.4542, -1.2495, -0.3464, -1.1194], + [0.1017, 1.1922, -1.2911, 0.6752, 1.4180, -0.3162, -0.3809, 1.4444], + [-0.8802, 1.5789, -0.7804, -0.2817, 0.3465, -0.6741, 0.1570, 0.1059], + [-0.8849, 0.3025, -0.3609, 0.7738, 0.8476, -0.2813, 1.5131, -1.4178]], + [[0.2065, -0.8124, -0.6505, 1.6508, 1.7852, 1.2732, 0.4985, -0.5486], + [2.7083, 1.0688, 0.4090, -0.1851, 1.0733, 1.1038, -1.4032, 0.2552], + [1.5166, -0.6669, 1.3872, -0.4971, 1.9420, -2.2243, -2.3078, -0.4577], + [-1.7597, 0.7735, 1.1435, -0.5766, 1.0973, -0.1990, -1.1990, 0.1093]], + [[0.2446, 1.8493, 1.7110, 1.1204, -1.7352, -1.3811, -0.2492, 0.8741], + [0.3271, 0.2713, -1.3248, -0.2370, 0.4934, -0.8729, -0.3618, 0.5313], + [0.8359, -0.2329, 0.4883, 0.1030, 0.2581, 0.3148, -0.9930, 0.2271], + [-1.1038, 0.0708, -0.4958, -1.1129, -0.9431, -0.0880, 1.0499, + -0.6881]]]]) +_feat2 = Tensor( + [[[[1.3175, 1.4551, 1.6624, -0.5219, 0.3938, -1.4649, 0.9400, -0.4180], + [0.4486, 0.0388, -0.6881, -1.4353, 1.8669, 0.6907, 0.0128, 0.2979], + [1.7176, 0.3644, -1.2154, -1.9436, 0.9357, 2.0734, -0.3146, 0.1123], + [-0.7050, 1.4828, 0.8406, 0.3374, 0.7549, 0.4404, -0.1620, 0.3539]], + [[1.1737, -0.9930, -0.6959, -1.7765, -0.4785, -0.5701, -0.6154, 0.8447], + [2.2322, 1.2820, -0.9384, -0.2065, 0.1662, 0.9703, 0.1947, -0.7589], + [0.9334, -0.5888, 0.2904, -1.1869, -1.3860, -1.1149, -0.4794, -0.4440], + [1.0862, -1.1460, 0.9998, -1.3857, 1.0615, -0.1334, 1.4889, -0.2771]], + [[0.4017, 0.4662, 0.6031, 2.2982, -1.3094, -0.7295, -0.2682, 0.3263], + [-0.2803, 1.5200, -0.5896, 0.5558, -0.6111, -0.5191, -0.0100, 0.4099], + [0.3736, -1.0845, -0.9815, 0.9264, 0.5722, -2.2061, 0.9850, -0.2834], + [0.2425, 1.4829, -0.8054, 1.1259, -1.0513, 1.3195, -1.7388, 0.3673]], + [[0.0612, 0.3328, 0.1373, -1.9487, 0.8354, -0.7799, -0.4399, 1.7067], + [1.1250, -0.8651, -0.3540, 0.7884, 1.2341, -1.0060, 1.8890, 0.9911], + [0.9935, 0.3770, 1.4380, 0.0396, 0.2286, 2.2238, 0.1141, 0.0866], + [-0.1054, -0.4454, 0.1032, -1.1747, 0.5838, 1.2229, -0.2493, 1.0715]]]]) +b, c, h, w = _feat1.size() + + +def test_correlation(): + gt_corr_x = Tensor([[[[ + -7.8589e-01, 2.0998e+00, 1.8146e+00, 2.0100e+00, 7.7996e-01, + -1.8402e-01, 1.1842e+00, -1.0520e+00 + ]]], + [[[ + 4.9387e-01, 2.3942e-01, 1.2414e-01, -3.2838e+00, + 1.2874e+00, -9.1842e-01, -2.1343e-01, 1.5433e+00 + ]]], + [[[ + 1.3318e+00, 1.3336e+00, 1.3612e+00, -3.1777e+00, + 1.4328e+00, -1.8832e+00, 4.9047e-01, 1.0963e+00 + ]]], + [[[ + 3.2244e-01, 7.0587e-01, 6.9355e-01, 9.2706e-01, + -5.5962e-01, -1.0494e+00, -3.8291e-01, 1.1421e+00 + ]]], + [[[ + 7.3966e-01, 8.7044e-02, 4.7271e-01, 3.2722e+00, + -1.9521e+00, -2.9039e-01, 1.2212e-01, -1.0508e+00 + ]]], + [[[ + -6.4286e-01, 5.5144e-01, 5.6862e-01, 3.9673e+00, + -1.1483e+00, 5.6715e-01, 4.2971e-01, -1.4595e+00 + ]]], + [[[ + 2.7478e-01, 6.7256e-01, 7.4025e-01, 9.7059e-01, + -2.3234e-01, -4.1461e-01, 3.6964e-01, -3.9995e-01 + ]]], + [[[ + 3.2379e-01, 1.7486e+00, 1.6268e+00, -9.0931e-01, + 1.3102e+00, -1.0049e+00, 9.8499e-01, -1.5399e-01 + ]]], + [[[ + -1.8206e-01, 1.9734e+00, -7.5064e-01, 1.1910e+00, + -1.0334e+00, -9.7209e-01, 3.0245e-01, 
6.1217e-01 + ]]], + [[[ + 1.0277e+00, 1.4327e+00, -4.5351e-01, 1.2591e+00, + -1.3325e+00, -3.0622e-01, 3.5824e-01, -3.0192e-01 + ]]], + [[[ + -2.3949e+00, 4.3196e-02, 9.5187e-01, 2.0900e-01, + -1.6796e+00, -2.9920e-01, -1.3833e+00, -1.8328e-01 + ]]], + [[[ + 7.0550e-01, 3.9977e-01, -3.1122e-01, -4.0425e-01, + 2.1314e-01, 5.8610e-01, -1.5550e-01, -3.7222e-01 + ]]], + [[[ + 1.9070e+00, 1.5283e+00, -1.3717e+00, -2.8489e-01, + 9.1540e-01, 4.6496e-01, 6.0432e-01, 5.7434e-02 + ]]], + [[[ + -7.2508e-01, 1.0374e+00, -4.4210e-01, -8.7988e-01, + 2.3618e-01, 4.2033e-01, -8.5295e-01, 9.5285e-02 + ]]], + [[[ + -6.4046e-01, -1.1721e+00, 9.7621e-01, 1.7381e-01, + -6.9380e-01, 4.0390e-02, -3.7773e-01, -4.6079e-01 + ]]], + [[[ + 1.9567e+00, 8.9705e-01, -9.7208e-01, -1.2971e-01, + 7.0929e-01, 4.9284e-01, 6.4348e-01, -1.7833e-01 + ]]], + [[[ + 4.8913e-01, -3.6295e-01, -4.1357e-01, 1.0135e+00, + 1.2491e+00, -9.6748e-03, 9.6871e-01, 2.9864e-02 + ]]], + [[[ + 6.1752e-01, -1.2145e-01, 3.0352e-01, -1.3873e+00, + -1.2457e+00, -2.5753e-01, -7.4235e-01, -2.5819e-01 + ]]], + [[[ + -1.0247e-01, -4.8132e-01, -2.7320e-01, 1.3869e+00, + 8.6279e-01, -8.4183e-01, 9.4207e-01, -1.7862e-02 + ]]], + [[[ + -2.1337e+00, -4.4044e-02, 1.6644e+00, 2.1575e+00, + -1.0033e+00, -1.5468e+00, 1.8768e-01, 9.2515e-03 + ]]], + [[[ + -9.3583e-01, -1.4434e+00, 4.0691e-01, 2.4966e+00, + -5.2040e-01, -3.9659e+00, 1.1791e+00, -4.4479e-01 + ]]], + [[[ + -9.4713e-01, 1.3847e+00, 1.4843e+00, -2.0148e-01, + -3.3666e-01, 2.7286e+00, -8.4753e-01, 4.5405e-01 + ]]], + [[[ + -9.5922e-01, 9.9506e-01, 5.1789e-01, -1.0595e+00, + -9.4146e-01, 1.2235e+00, -1.2111e+00, 2.4210e-01 + ]]], + [[[ + 1.1924e+00, 4.9652e-01, -3.8619e-01, -1.5328e+00, + 4.2940e-01, 2.0451e+00, -4.4219e-01, 1.2412e-01 + ]]], + [[[ + -9.8240e-01, 1.7715e-01, 6.2259e-01, 4.3668e-01, + 5.0427e-01, -1.5603e+00, 9.2906e-01, -6.1793e-01 + ]]], + [[[ + 4.9682e-02, 8.1487e-01, 8.7411e-02, 2.8222e-01, + -6.2244e-03, 6.6128e-01, -5.0314e-01, 2.4081e-01 + ]]], + [[[ + -4.6349e-02, 1.1969e+00, -6.4845e-01, 1.1922e+00, + -9.2116e-01, 4.8479e-01, -1.2045e+00, 1.9728e-03 + ]]], + [[[ + 9.7235e-01, -1.8080e+00, -1.1013e-01, -4.7668e-01, + -2.1431e-01, -1.4644e+00, 1.3455e+00, -1.0921e+00 + ]]], + [[[ + 2.4253e-01, 1.3804e+00, 4.1076e-01, 7.7609e-01, + 2.6673e-02, 3.4096e-01, -2.9748e-01, -2.2011e-01 + ]]], + [[[ + -1.7477e-01, 3.8498e-02, -6.2041e-02, 1.3576e-01, + -6.7703e-02, -1.6477e-01, -2.6009e-02, -4.3462e-02 + ]]], + [[[ + 1.1731e+00, -3.1510e+00, 6.3514e-01, -2.6042e+00, + 1.1486e+00, -5.9488e-01, 2.1648e+00, -1.4449e-01 + ]]], + [[[ + -7.3512e-01, 1.0774e+00, -7.7084e-01, 1.4550e+00, + -9.9514e-01, -2.4492e-01, -1.0681e+00, -1.4480e-01 + ]]]]) + gt_corr_y = Tensor([[[[-0.7859], [-2.5235], [-0.1638], [-1.7374]]], + [[[0.2394], [-1.1043], [0.7389], [-0.9226]]], + [[[1.3612], [-0.8983], [0.4627], [1.2890]]], + [[[0.9271], [0.8622], [0.8074], [0.3946]]], + [[[-1.9521], [-1.3409], [0.1167], [-1.1078]]], + [[[0.5672], [-0.3065], [-2.4372], [0.0377]]], + [[[0.3696], [-0.2678], [0.2223], [-0.7076]]], + [[[-0.1540], [0.9861], [0.4548], [0.8085]]], + [[[0.3200], [-0.1821], [0.3330], [0.5235]]], + [[[-1.2894], [1.4327], [-1.1278], [-0.9617]]], + [[[-0.0793], [0.9519], [-0.9306], [-1.1621]]], + [[[-0.6505], [-0.4043], [-0.7480], [-0.3882]]], + [[[-0.6627], [0.9154], [-0.2077], [0.6645]]], + [[[-0.8653], [0.4203], [-0.7476], [0.4841]]], + [[[-0.0519], [-0.3777], [-0.4742], [1.0568]]], + [[[1.0291], [-0.1783], [-0.3134], [0.1957]]], + [[[-0.0319], [-0.6722], [0.4891], [-0.4209]]], + [[[-0.8757], [0.6087], 
[-0.1214], [-1.2429]]], + [[[0.4911], [-0.0331], [-0.2732], [-1.0410]]], + [[[0.1744], [1.5699], [2.1575], [-0.5303]]], + [[[-1.6107], [-2.1319], [-0.5204], [-1.4597]]], + [[[1.1992], [-0.0582], [2.7286], [-1.3258]]], + [[[0.4204], [-0.9119], [-1.2111], [2.2573]]], + [[[-0.1077], [0.1721], [0.1241], [0.2528]]], + [[[-0.2588], [-1.1413], [-0.4455], [-0.9824]]], + [[[0.4643], [0.7624], [-0.3894], [0.8149]]], + [[[0.4720], [-0.0948], [-0.9961], [-0.6485]]], + [[[0.1515], [0.4681], [0.8048], [-0.4767]]], + [[[-1.0914], [0.2139], [0.1504], [0.0267]]], + [[[0.1819], [-0.0381], [0.2858], [-0.1648]]], + [[[-1.2718], [1.1349], [-0.6469], [2.1648]]], + [[[-1.1768], [0.2256], [0.2718], [-0.1448]]]]) + corr = Correlation1D() + correlation = corr(_feat1, _feat2, _feat2) + assert correlation[0].size() == (b * h * w, 1, 1, w) + assert correlation[1].size() == (b * h * w, 1, h, 1) + assert torch.allclose(correlation[0], gt_corr_x, atol=1e-4) + assert torch.allclose(correlation[1], gt_corr_y, atol=1e-4) diff --git a/tests/test_op/test_corr_lookup.py b/tests/test_op/test_corr_lookup.py index f31f4a2b..162ebdb7 100644 --- a/tests/test_op/test_corr_lookup.py +++ b/tests/test_op/test_corr_lookup.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch +from torch import Tensor from mmflow.models.decoders.raft_decoder import CorrelationPyramid +from mmflow.models.utils.correlation1d import Correlation1D from mmflow.ops.builder import build_operators from mmflow.ops.corr_lookup import bilinear_sample, coords_grid @@ -56,3 +59,128 @@ def test_corr_lookup(): corr_lpt = corr_lookup_op(corr_pyramid, torch.randn(1, 2, H, W)) assert corr_lpt.shape == torch.Size((1, 81 * 4, H, W)) + + +@pytest.mark.parametrize('mode', ['bilinear', 'nearest', 'bicubic']) +@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection']) +@pytest.mark.parametrize('align_corners', [True, False]) +def test_corr_lookup_flow1d(mode, padding_mode, align_corners): + corr_block = Correlation1D() + feat1 = torch.arange(0, 64) + feat1 = feat1.view(1, 2, 4, 8) + feat2 = feat1 + 1 + flow = torch.ones_like(feat1) + b, c, h, w = feat1.size() + radius = 1 + + # gronud truth + gt_corr_x = Tensor([[[[ + 746.7048, 770.7464, 794.7880, 818.8297, 842.8713, 866.9129, 890.9546, + 914.9962 + ], + [ + 1210.5668, 1245.9221, 1281.2775, 1316.6328, + 1351.9882, 1387.3435, 1422.6989, 1458.0542 + ], + [ + 1855.4482, 1902.1173, 1948.7864, 1995.4553, + 2042.1244, 2088.7935, 2135.4624, 2182.1316 + ], + [ + 2681.3489, 2739.3318, 2797.3145, 2855.2971, + 2913.2800, 2971.2627, 3029.2456, 3087.2283 + ]], + [[ + 769.3322, 794.7880, 820.2439, 845.6997, 871.1556, + 896.6114, 922.0673, 947.5231 + ], + [ + 1244.5079, 1281.2775, 1318.0471, 1354.8167, + 1391.5862, 1428.3557, 1465.1252, 1501.8948 + ], + [ + 1900.7030, 1948.7864, 1996.8696, 2044.9529, + 2093.0359, 2141.1194, 2189.2026, 2237.2859 + ], + [ + 2737.9175, 2797.3145, 2856.7114, 2916.1084, + 2975.5054, 3034.9023, 3094.2993, 3153.6963 + ]], + [[ + 791.9596, 818.8297, 845.6997, 872.5698, 899.4398, + 926.3099, 953.1799, 980.0500 + ], + [ + 1278.4491, 1316.6328, 1354.8167, 1393.0004, + 1431.1842, 1469.3679, 1507.5516, 1545.7355 + ], + [ + 1945.9579, 1995.4553, 2044.9529, 2094.4504, + 2143.9478, 2193.4453, 2242.9426, 2292.4402 + ], + [ + 2794.4861, 2855.2971, 2916.1084, 2976.9197, + 3037.7307, 3098.5420, 3159.3533, 3220.1643 + ]]]]) + gt_corr_y = Tensor([[[[ + 746.7048, 794.7880, 845.6997, 899.4398, 956.0084, 1015.4053, 1077.6307, + 1142.6846 + ], + [ + 939.0378, 998.4348, 
1060.6602, 1125.7140, + 1193.5963, 1264.3070, 1337.8461, 1414.2136 + ], + [ + 1131.3708, 1202.0815, 1275.6206, 1351.9882, + 1431.1842, 1513.2085, 1598.0614, 1685.7426 + ], + [ + 1323.7039, 1405.7283, 1490.5812, 1578.2623, + 1668.7720, 1762.1101, 1858.2766, 1957.2716 + ]], + [[ + 927.7241, 987.1211, 1049.3464, 1114.4003, + 1182.2826, 1252.9933, 1326.5323, 1402.8999 + ], + [ + 1210.5668, 1281.2775, 1354.8167, 1431.1842, + 1510.3801, 1592.4045, 1677.2573, 1764.9386 + ], + [ + 1493.4095, 1575.4340, 1660.2867, 1747.9680, + 1838.4777, 1931.8158, 2027.9823, 2126.9773 + ], + [ + 1776.2523, 1869.5903, 1965.7568, 2064.7520, + 2166.5752, 2271.2271, 2378.7073, 2489.0159 + ]], + [[ + 1108.7434, 1179.4541, 1252.9933, 1329.3607, + 1408.5568, 1490.5812, 1575.4340, 1663.1152 + ], + [ + 1482.0958, 1564.1202, 1648.9730, 1736.6543, + 1827.1639, 1920.5021, 2016.6686, 2115.6636 + ], + [ + 1855.4482, 1948.7864, 2044.9529, 2143.9478, + 2245.7712, 2350.4231, 2457.9033, 2568.2119 + ], + [ + 2228.8005, 2333.4524, 2440.9326, 2551.2412, + 2664.3784, 2780.3440, 2899.1379, 3020.7603 + ]]]]) + gt_corr = torch.cat((gt_corr_x, gt_corr_y), dim=1) + correlation = corr_block(feat1, feat2, feat2) + + corr_lookup_cfg = dict( + type='CorrLookupFlow1D', + radius=radius, + mode=mode, + padding_mode=padding_mode, + align_corners=True) + corr_lookup_op = build_operators(corr_lookup_cfg) + + corr_xy = corr_lookup_op(correlation, flow) + assert corr_xy.size() == (b, 2 * (2 * radius + 1), h, w) + assert torch.allclose(gt_corr, corr_xy, atol=1e-4) From 4e04e9efaed91fb134bfc31c507bd055476cdc77 Mon Sep 17 00:00:00 2001 From: Fc-idris <23311590278@qq.com> Date: Tue, 16 Aug 2022 20:22:58 +0800 Subject: [PATCH 2/5] feat: add flow1d correlation and correlation lookup --- mmflow/models/utils/correlation1d.py | 84 ++----- mmflow/ops/corr_lookup.py | 6 +- .../test_utils/test_correlation1d.py | 232 +++--------------- tests/test_op/test_corr_lookup.py | 124 ++-------- 4 files changed, 77 insertions(+), 369 deletions(-) diff --git a/mmflow/models/utils/correlation1d.py b/mmflow/models/utils/correlation1d.py index 5e444b31..d25e1e49 100644 --- a/mmflow/models/utils/correlation1d.py +++ b/mmflow/models/utils/correlation1d.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Sequence import torch from mmcv.runner import BaseModule @@ -16,76 +15,29 @@ class Correlation1D(BaseModule): def __init__(self): super().__init__() - def forward( - self, - feat1: Tensor, - feat2_x: Tensor, - feat2_y: Tensor, - ) -> Sequence[Tensor]: + def forward(self, + feat1: Tensor, + feat2: Tensor, + y_direction: bool = False) -> Tensor: """Forward function for Correlation1D. Args: feat1 (Tensor): The feature from first input image. - feat2_x (Tensor): The 1D cross attention feature2 on x direction. - feat2_y (Tensor): The 1D cross attention feature2 on y direction. - - Returns: - Sequence[Tensor]: Correlation list, include x correlation - and y correlation. - """ - corr_x = self.corr_x(feat1, feat2_x) - corr_y = self.corr_y(feat1, feat2_y) - corr = [corr_x, corr_y] - return corr - - @staticmethod - def corr_x(feature1: Tensor, feature2: Tensor) -> Tensor: - """corr_x function for Correlation1D. - - Args: - feature1 (Tensor): Input feature1. - feature2 (Tensor): Input feature2. - - Returns: - Tensor: x correlation. 
- """ - b, c, h, w = feature1.shape # [B, C, H, W] - scale_factor = c**0.5 - - # x direction, corr shape is [B, H, W, W] - feature1 = feature1.permute(0, 2, 3, 1) - feature2 = feature2.permute(0, 2, 1, 3) - corr = torch.matmul(feature1, feature2) - - # reshape to [B*H*W, 1, 1, W] - corr = corr.unsqueeze(3).unsqueeze(3) - corr = corr / scale_factor - corr = corr.flatten(0, 2) - - return corr - - @staticmethod - def corr_y(feature1: Tensor, feature2: Tensor) -> Tensor: - """corr_y function for Correlation1D. - - Args: - feature1 (Tensor): Input feature1. - feature2 (Tensor): Input feature2. - + feat2 (Tensor): The 1D cross attention feat2 on x or y direction. + y_direction (bool): whether y direction or not. Returns: - Tensor: y correlation. + Tensor: Correlation of x correlation or y correlation. """ - b, c, h, w = feature1.shape # [B, C, H, W] + b, c, h, w = feat1.shape scale_factor = c**0.5 - - # y direction, corr shape is [B, W, H, H] - feature1 = feature1.permute(0, 3, 2, 1) - feature2 = feature2.permute(0, 3, 1, 2) - corr = torch.matmul(feature1, feature2) - - # reshape to [B*H*W, 1, H, 1] - corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1) - corr = corr / scale_factor - corr = corr.flatten(0, 2) - + if y_direction: + # y direction, corr shape is [B, W, H, H] + feat1 = feat1.permute(0, 3, 2, 1) + feat2 = feat2.permute(0, 3, 1, 2) + corr = torch.matmul(feat1, feat2) / scale_factor + else: + # x direction, corr shape is [B, H, W, W] + feat1 = feat1.permute(0, 2, 3, 1) + feat2 = feat2.permute(0, 2, 1, 3) + corr = torch.matmul(feat1, feat2) / scale_factor return corr diff --git a/mmflow/ops/corr_lookup.py b/mmflow/ops/corr_lookup.py index c1e68ca4..67aaf0c1 100644 --- a/mmflow/ops/corr_lookup.py +++ b/mmflow/ops/corr_lookup.py @@ -180,9 +180,11 @@ def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor: Tensor: lookup cost volume on the correlation of x and y directions concatenate together. 
""" - corr_x = corr[0] - corr_y = corr[1] B, _, H, W = flow.shape + # reshape corr_x to [B*H*W, 1, 1, W] + corr_x = corr[0].view(-1, 1, 1, W) + # reshape corr_y to [B*H*W, 1, H, 1] + corr_y = corr[1].permute(0, 2, 1, 3).contiguous().view(-1, 1, H, 1) # reshape flow to [B, H, W, 2] flow = flow.permute(0, 2, 3, 1) diff --git a/tests/test_models/test_utils/test_correlation1d.py b/tests/test_models/test_utils/test_correlation1d.py index fc6a669f..fa9f4ffd 100644 --- a/tests/test_models/test_utils/test_correlation1d.py +++ b/tests/test_models/test_utils/test_correlation1d.py @@ -4,208 +4,40 @@ from mmflow.models.utils.correlation1d import Correlation1D -_feat1 = Tensor( - [[[[1.0154, 0.4896, 1.8628, 0.0762, 0.2545, -0.1868, 0.5853, 1.6154], - [-0.4458, -1.3631, -0.6748, 0.2643, 0.8796, 1.2195, -0.9295, 0.3636], - [0.2345, 0.1408, -0.2794, -2.2829, -1.8497, -0.4348, -0.1259, 1.2991], - [0.9833, 0.5806, 0.0429, -1.5982, 1.1363, 0.0071, -1.5662, 0.0415]], - [[-2.5624, 0.4736, 0.3118, -0.1595, 0.4542, -1.2495, -0.3464, -1.1194], - [0.1017, 1.1922, -1.2911, 0.6752, 1.4180, -0.3162, -0.3809, 1.4444], - [-0.8802, 1.5789, -0.7804, -0.2817, 0.3465, -0.6741, 0.1570, 0.1059], - [-0.8849, 0.3025, -0.3609, 0.7738, 0.8476, -0.2813, 1.5131, -1.4178]], - [[0.2065, -0.8124, -0.6505, 1.6508, 1.7852, 1.2732, 0.4985, -0.5486], - [2.7083, 1.0688, 0.4090, -0.1851, 1.0733, 1.1038, -1.4032, 0.2552], - [1.5166, -0.6669, 1.3872, -0.4971, 1.9420, -2.2243, -2.3078, -0.4577], - [-1.7597, 0.7735, 1.1435, -0.5766, 1.0973, -0.1990, -1.1990, 0.1093]], - [[0.2446, 1.8493, 1.7110, 1.1204, -1.7352, -1.3811, -0.2492, 0.8741], - [0.3271, 0.2713, -1.3248, -0.2370, 0.4934, -0.8729, -0.3618, 0.5313], - [0.8359, -0.2329, 0.4883, 0.1030, 0.2581, 0.3148, -0.9930, 0.2271], - [-1.1038, 0.0708, -0.4958, -1.1129, -0.9431, -0.0880, 1.0499, - -0.6881]]]]) -_feat2 = Tensor( - [[[[1.3175, 1.4551, 1.6624, -0.5219, 0.3938, -1.4649, 0.9400, -0.4180], - [0.4486, 0.0388, -0.6881, -1.4353, 1.8669, 0.6907, 0.0128, 0.2979], - [1.7176, 0.3644, -1.2154, -1.9436, 0.9357, 2.0734, -0.3146, 0.1123], - [-0.7050, 1.4828, 0.8406, 0.3374, 0.7549, 0.4404, -0.1620, 0.3539]], - [[1.1737, -0.9930, -0.6959, -1.7765, -0.4785, -0.5701, -0.6154, 0.8447], - [2.2322, 1.2820, -0.9384, -0.2065, 0.1662, 0.9703, 0.1947, -0.7589], - [0.9334, -0.5888, 0.2904, -1.1869, -1.3860, -1.1149, -0.4794, -0.4440], - [1.0862, -1.1460, 0.9998, -1.3857, 1.0615, -0.1334, 1.4889, -0.2771]], - [[0.4017, 0.4662, 0.6031, 2.2982, -1.3094, -0.7295, -0.2682, 0.3263], - [-0.2803, 1.5200, -0.5896, 0.5558, -0.6111, -0.5191, -0.0100, 0.4099], - [0.3736, -1.0845, -0.9815, 0.9264, 0.5722, -2.2061, 0.9850, -0.2834], - [0.2425, 1.4829, -0.8054, 1.1259, -1.0513, 1.3195, -1.7388, 0.3673]], - [[0.0612, 0.3328, 0.1373, -1.9487, 0.8354, -0.7799, -0.4399, 1.7067], - [1.1250, -0.8651, -0.3540, 0.7884, 1.2341, -1.0060, 1.8890, 0.9911], - [0.9935, 0.3770, 1.4380, 0.0396, 0.2286, 2.2238, 0.1141, 0.0866], - [-0.1054, -0.4454, 0.1032, -1.1747, 0.5838, 1.2229, -0.2493, 1.0715]]]]) +_feat1 = torch.arange(0, 24).view(1, 2, 3, 4) +_feat2 = _feat1 + 1 b, c, h, w = _feat1.size() def test_correlation(): - gt_corr_x = Tensor([[[[ - -7.8589e-01, 2.0998e+00, 1.8146e+00, 2.0100e+00, 7.7996e-01, - -1.8402e-01, 1.1842e+00, -1.0520e+00 - ]]], - [[[ - 4.9387e-01, 2.3942e-01, 1.2414e-01, -3.2838e+00, - 1.2874e+00, -9.1842e-01, -2.1343e-01, 1.5433e+00 - ]]], - [[[ - 1.3318e+00, 1.3336e+00, 1.3612e+00, -3.1777e+00, - 1.4328e+00, -1.8832e+00, 4.9047e-01, 1.0963e+00 - ]]], - [[[ - 3.2244e-01, 7.0587e-01, 6.9355e-01, 9.2706e-01, - 
-5.5962e-01, -1.0494e+00, -3.8291e-01, 1.1421e+00 - ]]], - [[[ - 7.3966e-01, 8.7044e-02, 4.7271e-01, 3.2722e+00, - -1.9521e+00, -2.9039e-01, 1.2212e-01, -1.0508e+00 - ]]], - [[[ - -6.4286e-01, 5.5144e-01, 5.6862e-01, 3.9673e+00, - -1.1483e+00, 5.6715e-01, 4.2971e-01, -1.4595e+00 - ]]], - [[[ - 2.7478e-01, 6.7256e-01, 7.4025e-01, 9.7059e-01, - -2.3234e-01, -4.1461e-01, 3.6964e-01, -3.9995e-01 - ]]], - [[[ - 3.2379e-01, 1.7486e+00, 1.6268e+00, -9.0931e-01, - 1.3102e+00, -1.0049e+00, 9.8499e-01, -1.5399e-01 - ]]], - [[[ - -1.8206e-01, 1.9734e+00, -7.5064e-01, 1.1910e+00, - -1.0334e+00, -9.7209e-01, 3.0245e-01, 6.1217e-01 - ]]], - [[[ - 1.0277e+00, 1.4327e+00, -4.5351e-01, 1.2591e+00, - -1.3325e+00, -3.0622e-01, 3.5824e-01, -3.0192e-01 - ]]], - [[[ - -2.3949e+00, 4.3196e-02, 9.5187e-01, 2.0900e-01, - -1.6796e+00, -2.9920e-01, -1.3833e+00, -1.8328e-01 - ]]], - [[[ - 7.0550e-01, 3.9977e-01, -3.1122e-01, -4.0425e-01, - 2.1314e-01, 5.8610e-01, -1.5550e-01, -3.7222e-01 - ]]], - [[[ - 1.9070e+00, 1.5283e+00, -1.3717e+00, -2.8489e-01, - 9.1540e-01, 4.6496e-01, 6.0432e-01, 5.7434e-02 - ]]], - [[[ - -7.2508e-01, 1.0374e+00, -4.4210e-01, -8.7988e-01, - 2.3618e-01, 4.2033e-01, -8.5295e-01, 9.5285e-02 - ]]], - [[[ - -6.4046e-01, -1.1721e+00, 9.7621e-01, 1.7381e-01, - -6.9380e-01, 4.0390e-02, -3.7773e-01, -4.6079e-01 - ]]], - [[[ - 1.9567e+00, 8.9705e-01, -9.7208e-01, -1.2971e-01, - 7.0929e-01, 4.9284e-01, 6.4348e-01, -1.7833e-01 - ]]], - [[[ - 4.8913e-01, -3.6295e-01, -4.1357e-01, 1.0135e+00, - 1.2491e+00, -9.6748e-03, 9.6871e-01, 2.9864e-02 - ]]], - [[[ - 6.1752e-01, -1.2145e-01, 3.0352e-01, -1.3873e+00, - -1.2457e+00, -2.5753e-01, -7.4235e-01, -2.5819e-01 - ]]], - [[[ - -1.0247e-01, -4.8132e-01, -2.7320e-01, 1.3869e+00, - 8.6279e-01, -8.4183e-01, 9.4207e-01, -1.7862e-02 - ]]], - [[[ - -2.1337e+00, -4.4044e-02, 1.6644e+00, 2.1575e+00, - -1.0033e+00, -1.5468e+00, 1.8768e-01, 9.2515e-03 - ]]], - [[[ - -9.3583e-01, -1.4434e+00, 4.0691e-01, 2.4966e+00, - -5.2040e-01, -3.9659e+00, 1.1791e+00, -4.4479e-01 - ]]], - [[[ - -9.4713e-01, 1.3847e+00, 1.4843e+00, -2.0148e-01, - -3.3666e-01, 2.7286e+00, -8.4753e-01, 4.5405e-01 - ]]], - [[[ - -9.5922e-01, 9.9506e-01, 5.1789e-01, -1.0595e+00, - -9.4146e-01, 1.2235e+00, -1.2111e+00, 2.4210e-01 - ]]], - [[[ - 1.1924e+00, 4.9652e-01, -3.8619e-01, -1.5328e+00, - 4.2940e-01, 2.0451e+00, -4.4219e-01, 1.2412e-01 - ]]], - [[[ - -9.8240e-01, 1.7715e-01, 6.2259e-01, 4.3668e-01, - 5.0427e-01, -1.5603e+00, 9.2906e-01, -6.1793e-01 - ]]], - [[[ - 4.9682e-02, 8.1487e-01, 8.7411e-02, 2.8222e-01, - -6.2244e-03, 6.6128e-01, -5.0314e-01, 2.4081e-01 - ]]], - [[[ - -4.6349e-02, 1.1969e+00, -6.4845e-01, 1.1922e+00, - -9.2116e-01, 4.8479e-01, -1.2045e+00, 1.9728e-03 - ]]], - [[[ - 9.7235e-01, -1.8080e+00, -1.1013e-01, -4.7668e-01, - -2.1431e-01, -1.4644e+00, 1.3455e+00, -1.0921e+00 - ]]], - [[[ - 2.4253e-01, 1.3804e+00, 4.1076e-01, 7.7609e-01, - 2.6673e-02, 3.4096e-01, -2.9748e-01, -2.2011e-01 - ]]], - [[[ - -1.7477e-01, 3.8498e-02, -6.2041e-02, 1.3576e-01, - -6.7703e-02, -1.6477e-01, -2.6009e-02, -4.3462e-02 - ]]], - [[[ - 1.1731e+00, -3.1510e+00, 6.3514e-01, -2.6042e+00, - 1.1486e+00, -5.9488e-01, 2.1648e+00, -1.4449e-01 - ]]], - [[[ - -7.3512e-01, 1.0774e+00, -7.7084e-01, 1.4550e+00, - -9.9514e-01, -2.4492e-01, -1.0681e+00, -1.4480e-01 - ]]]]) - gt_corr_y = Tensor([[[[-0.7859], [-2.5235], [-0.1638], [-1.7374]]], - [[[0.2394], [-1.1043], [0.7389], [-0.9226]]], - [[[1.3612], [-0.8983], [0.4627], [1.2890]]], - [[[0.9271], [0.8622], [0.8074], [0.3946]]], - [[[-1.9521], [-1.3409], [0.1167], 
[-1.1078]]], - [[[0.5672], [-0.3065], [-2.4372], [0.0377]]], - [[[0.3696], [-0.2678], [0.2223], [-0.7076]]], - [[[-0.1540], [0.9861], [0.4548], [0.8085]]], - [[[0.3200], [-0.1821], [0.3330], [0.5235]]], - [[[-1.2894], [1.4327], [-1.1278], [-0.9617]]], - [[[-0.0793], [0.9519], [-0.9306], [-1.1621]]], - [[[-0.6505], [-0.4043], [-0.7480], [-0.3882]]], - [[[-0.6627], [0.9154], [-0.2077], [0.6645]]], - [[[-0.8653], [0.4203], [-0.7476], [0.4841]]], - [[[-0.0519], [-0.3777], [-0.4742], [1.0568]]], - [[[1.0291], [-0.1783], [-0.3134], [0.1957]]], - [[[-0.0319], [-0.6722], [0.4891], [-0.4209]]], - [[[-0.8757], [0.6087], [-0.1214], [-1.2429]]], - [[[0.4911], [-0.0331], [-0.2732], [-1.0410]]], - [[[0.1744], [1.5699], [2.1575], [-0.5303]]], - [[[-1.6107], [-2.1319], [-0.5204], [-1.4597]]], - [[[1.1992], [-0.0582], [2.7286], [-1.3258]]], - [[[0.4204], [-0.9119], [-1.2111], [2.2573]]], - [[[-0.1077], [0.1721], [0.1241], [0.2528]]], - [[[-0.2588], [-1.1413], [-0.4455], [-0.9824]]], - [[[0.4643], [0.7624], [-0.3894], [0.8149]]], - [[[0.4720], [-0.0948], [-0.9961], [-0.6485]]], - [[[0.1515], [0.4681], [0.8048], [-0.4767]]], - [[[-1.0914], [0.2139], [0.1504], [0.0267]]], - [[[0.1819], [-0.0381], [0.2858], [-0.1648]]], - [[[-1.2718], [1.1349], [-0.6469], [2.1648]]], - [[[-1.1768], [0.2256], [0.2718], [-0.1448]]]]) + gt_corr_x = Tensor([[[[110.3087, 118.7939, 127.2792, 135.7645], + [120.2082, 130.1077, 140.0071, 149.9066], + [130.1077, 141.4214, 152.7351, 164.0488], + [140.0071, 152.7351, 165.4630, 178.1909]], + [[206.4752, 220.6173, 234.7595, 248.9016], + [222.0315, 237.5879, 253.1442, 268.7006], + [237.5879, 254.5584, 271.5290, 288.4996], + [253.1442, 271.5290, 289.9138, 308.2986]], + [[347.8965, 367.6955, 387.4945, 407.2935], + [369.1097, 390.3229, 411.5362, 432.7494], + [390.3229, 412.9504, 435.5778, 458.2052], + [411.5362, 435.5778, 459.6194, 483.6610]]]]) + gt_corr_y = Tensor([[[[110.3087, 144.2498, 178.1909], + [149.9066, 206.4752, 263.0437], + [189.5046, 268.7006, 347.8965]], + [[130.1077, 169.7056, 209.3036], + [175.3625, 237.5879, 299.8133], + [220.6173, 305.4701, 390.3229]], + [[152.7351, 197.9899, 243.2447], + [203.6468, 271.5290, 339.4113], + [254.5584, 345.0681, 435.5778]], + [[178.1909, 229.1026, 280.0143], + [234.7595, 308.2986, 381.8377], + [291.3280, 387.4945, 483.6610]]]]) corr = Correlation1D() - correlation = corr(_feat1, _feat2, _feat2) - assert correlation[0].size() == (b * h * w, 1, 1, w) - assert correlation[1].size() == (b * h * w, 1, h, 1) - assert torch.allclose(correlation[0], gt_corr_x, atol=1e-4) - assert torch.allclose(correlation[1], gt_corr_y, atol=1e-4) + corr_x = corr(_feat1, _feat2, False) + corr_y = corr(_feat1, _feat2, True) + assert corr_x.size() == (b, h, w, w) + assert corr_y.size() == (b, w, h, h) + assert torch.allclose(corr_x, gt_corr_x, atol=1e-4) + assert torch.allclose(corr_y, gt_corr_y, atol=1e-4) diff --git a/tests/test_op/test_corr_lookup.py b/tests/test_op/test_corr_lookup.py index 162ebdb7..e5cb99a9 100644 --- a/tests/test_op/test_corr_lookup.py +++ b/tests/test_op/test_corr_lookup.py @@ -20,7 +20,6 @@ def test_coords_grid(): assert grid.shape == torch.Size((2, 2, H, W)) for i in range(H): for j in range(W): - assert torch.all(grid[0, :, i, j] == torch.Tensor((j, i))) @@ -66,113 +65,36 @@ def test_corr_lookup(): @pytest.mark.parametrize('align_corners', [True, False]) def test_corr_lookup_flow1d(mode, padding_mode, align_corners): corr_block = Correlation1D() - feat1 = torch.arange(0, 64) - feat1 = feat1.view(1, 2, 4, 8) + feat1 = torch.arange(0, 24) + feat1 
= feat1.view(1, 2, 3, 4) feat2 = feat1 + 1 flow = torch.ones_like(feat1) b, c, h, w = feat1.size() radius = 1 # gronud truth - gt_corr_x = Tensor([[[[ - 746.7048, 770.7464, 794.7880, 818.8297, 842.8713, 866.9129, 890.9546, - 914.9962 - ], - [ - 1210.5668, 1245.9221, 1281.2775, 1316.6328, - 1351.9882, 1387.3435, 1422.6989, 1458.0542 - ], - [ - 1855.4482, 1902.1173, 1948.7864, 1995.4553, - 2042.1244, 2088.7935, 2135.4624, 2182.1316 - ], - [ - 2681.3489, 2739.3318, 2797.3145, 2855.2971, - 2913.2800, 2971.2627, 3029.2456, 3087.2283 - ]], - [[ - 769.3322, 794.7880, 820.2439, 845.6997, 871.1556, - 896.6114, 922.0673, 947.5231 - ], - [ - 1244.5079, 1281.2775, 1318.0471, 1354.8167, - 1391.5862, 1428.3557, 1465.1252, 1501.8948 - ], - [ - 1900.7030, 1948.7864, 1996.8696, 2044.9529, - 2093.0359, 2141.1194, 2189.2026, 2237.2859 - ], - [ - 2737.9175, 2797.3145, 2856.7114, 2916.1084, - 2975.5054, 3034.9023, 3094.2993, 3153.6963 - ]], - [[ - 791.9596, 818.8297, 845.6997, 872.5698, 899.4398, - 926.3099, 953.1799, 980.0500 - ], - [ - 1278.4491, 1316.6328, 1354.8167, 1393.0004, - 1431.1842, 1469.3679, 1507.5516, 1545.7355 - ], - [ - 1945.9579, 1995.4553, 2044.9529, 2094.4504, - 2143.9478, 2193.4453, 2242.9426, 2292.4402 - ], - [ - 2794.4861, 2855.2971, 2916.1084, 2976.9197, - 3037.7307, 3098.5420, 3159.3533, 3220.1643 - ]]]]) - gt_corr_y = Tensor([[[[ - 746.7048, 794.7880, 845.6997, 899.4398, 956.0084, 1015.4053, 1077.6307, - 1142.6846 - ], - [ - 939.0378, 998.4348, 1060.6602, 1125.7140, - 1193.5963, 1264.3070, 1337.8461, 1414.2136 - ], - [ - 1131.3708, 1202.0815, 1275.6206, 1351.9882, - 1431.1842, 1513.2085, 1598.0614, 1685.7426 - ], - [ - 1323.7039, 1405.7283, 1490.5812, 1578.2623, - 1668.7720, 1762.1101, 1858.2766, 1957.2716 - ]], - [[ - 927.7241, 987.1211, 1049.3464, 1114.4003, - 1182.2826, 1252.9933, 1326.5323, 1402.8999 - ], - [ - 1210.5668, 1281.2775, 1354.8167, 1431.1842, - 1510.3801, 1592.4045, 1677.2573, 1764.9386 - ], - [ - 1493.4095, 1575.4340, 1660.2867, 1747.9680, - 1838.4777, 1931.8158, 2027.9823, 2126.9773 - ], - [ - 1776.2523, 1869.5903, 1965.7568, 2064.7520, - 2166.5752, 2271.2271, 2378.7073, 2489.0159 - ]], - [[ - 1108.7434, 1179.4541, 1252.9933, 1329.3607, - 1408.5568, 1490.5812, 1575.4340, 1663.1152 - ], - [ - 1482.0958, 1564.1202, 1648.9730, 1736.6543, - 1827.1639, 1920.5021, 2016.6686, 2115.6636 - ], - [ - 1855.4482, 1948.7864, 2044.9529, 2143.9478, - 2245.7712, 2350.4231, 2457.9033, 2568.2119 - ], - [ - 2228.8005, 2333.4524, 2440.9326, 2551.2412, - 2664.3784, 2780.3440, 2899.1379, 3020.7603 - ]]]]) + gt_corr_x = Tensor([[[[110.3087, 120.2082, 130.1077, 140.0071], + [206.4752, 222.0315, 237.5879, 253.1442], + [347.8965, 369.1097, 390.3229, 411.5362]], + [[118.7939, 130.1077, 141.4214, 152.7351], + [220.6173, 237.5879, 254.5584, 271.5290], + [367.6955, 390.3229, 412.9504, 435.5778]], + [[127.2792, 140.0071, 152.7351, 165.4630], + [234.7595, 253.1442, 271.5290, 289.9138], + [387.4945, 411.5362, 435.5778, 459.6194]]]]) + gt_corr_y = Tensor([[[[110.3087, 130.1077, 152.7351, 178.1909], + [149.9066, 175.3625, 203.6468, 234.7595], + [189.5046, 220.6173, 254.5584, 291.3280]], + [[144.2498, 169.7056, 197.9899, 229.1026], + [206.4752, 237.5879, 271.5290, 308.2986], + [268.7006, 305.4701, 345.0681, 387.4945]], + [[178.1909, 209.3036, 243.2447, 280.0143], + [263.0437, 299.8133, 339.4113, 381.8377], + [347.8965, 390.3229, 435.5778, 483.6610]]]]) gt_corr = torch.cat((gt_corr_x, gt_corr_y), dim=1) - correlation = corr_block(feat1, feat2, feat2) - + correlation_x = corr_block(feat1, feat2, False) + 
correlation_y = corr_block(feat1, feat2, True) + correlation = [correlation_x, correlation_y] corr_lookup_cfg = dict( type='CorrLookupFlow1D', radius=radius, From 7a74868fcdb95ff32b5806a245482af1188679fd Mon Sep 17 00:00:00 2001 From: Fc-idris <23311590278@qq.com> Date: Thu, 18 Aug 2022 16:05:42 +0800 Subject: [PATCH 3/5] feat: add flow1d correlation and correlation lookup --- mmflow/models/utils/correlation1d.py | 3 +-- mmflow/ops/corr_lookup.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mmflow/models/utils/correlation1d.py b/mmflow/models/utils/correlation1d.py index d25e1e49..1a170255 100644 --- a/mmflow/models/utils/correlation1d.py +++ b/mmflow/models/utils/correlation1d.py @@ -34,10 +34,9 @@ def forward(self, # y direction, corr shape is [B, W, H, H] feat1 = feat1.permute(0, 3, 2, 1) feat2 = feat2.permute(0, 3, 1, 2) - corr = torch.matmul(feat1, feat2) / scale_factor else: # x direction, corr shape is [B, H, W, W] feat1 = feat1.permute(0, 2, 3, 1) feat2 = feat2.permute(0, 2, 1, 3) - corr = torch.matmul(feat1, feat2) / scale_factor + corr = torch.matmul(feat1, feat2) / scale_factor return corr diff --git a/mmflow/ops/corr_lookup.py b/mmflow/ops/corr_lookup.py index 67aaf0c1..5b555908 100644 --- a/mmflow/ops/corr_lookup.py +++ b/mmflow/ops/corr_lookup.py @@ -181,9 +181,9 @@ def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor: concatenate together. """ B, _, H, W = flow.shape - # reshape corr_x to [B*H*W, 1, 1, W] + # reshape corr_x from [B, H, W, W] to [B*H*W, 1, 1, W] corr_x = corr[0].view(-1, 1, 1, W) - # reshape corr_y to [B*H*W, 1, H, 1] + # reshape corr_y from [B, W, H, H]to [B*H*W, 1, H, 1] corr_y = corr[1].permute(0, 2, 1, 3).contiguous().view(-1, 1, H, 1) # reshape flow to [B, H, W, 2] From 0c6f7e1e64a092f9a8fc86478e560fb6a60dbcf2 Mon Sep 17 00:00:00 2001 From: Fc-idris <23311590278@qq.com> Date: Thu, 13 Oct 2022 12:52:47 +0800 Subject: [PATCH 4/5] [feat]: add flow1d decoder and flow1d config --- configs/_base_/models/flow1d.py | 46 ++ .../flow1d_8x2_100k_flyingchairs_368x496.py | 24 + mmflow/models/decoders/__init__.py | 3 +- mmflow/models/decoders/flow1d_decoder.py | 425 ++++++++++++++++++ mmflow/models/flow_estimators/__init__.py | 3 +- mmflow/models/flow_estimators/flow1d.py | 156 +++++++ .../test_decoder/test_flow1d_decoder.py | 66 +++ tests/test_models/test_flow_estimator.py | 3 +- 8 files changed, 723 insertions(+), 3 deletions(-) create mode 100644 configs/_base_/models/flow1d.py create mode 100644 configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py create mode 100644 mmflow/models/decoders/flow1d_decoder.py create mode 100644 mmflow/models/flow_estimators/flow1d.py create mode 100644 tests/test_models/test_decoder/test_flow1d_decoder.py diff --git a/configs/_base_/models/flow1d.py b/configs/_base_/models/flow1d.py new file mode 100644 index 00000000..b2450184 --- /dev/null +++ b/configs/_base_/models/flow1d.py @@ -0,0 +1,46 @@ +model = dict( + type='Flow1D', + radius=4, + cxt_channels=128, + h_channels=128, + encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + norm_cfg=dict(type='IN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), + dict(type='Constant', layer=['InstanceNorm2d'], val=1, bias=0) + ]), + cxt_encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + norm_cfg=dict(type='SyncBN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), 
+ dict(type='Constant', layer=['SyncBatchNorm2d'], val=1, bias=0) + ]), + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=4, + iters=12, + corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss'), + act_cfg=dict(type='ReLU')), + freeze_bn=False, + train_cfg=dict(), + test_cfg=dict(), +) diff --git a/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py b/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py new file mode 100644 index 00000000..5b3d089d --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/flyingchairs_raft_368x496.py', + '../_base_/default_runtime.py' +] + +optimizer = dict( + type='AdamW', + lr=0.0004, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.0001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.0004, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') diff --git a/mmflow/models/decoders/__init__.py b/mmflow/models/decoders/__init__.py index 7c210fd9..5fb99bb7 100644 --- a/mmflow/models/decoders/__init__.py +++ b/mmflow/models/decoders/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .context_net import ContextNet +from .flow1d_decoder import Flow1DDecoder from .flownet_decoder import FlowNetCDecoder, FlowNetSDecoder from .gma_decoder import GMADecoder from .irr_refine import FlowRefine, OccRefine, OccShuffleUpsample @@ -13,5 +14,5 @@ 'FlowNetCDecoder', 'FlowNetSDecoder', 'PWCNetDecoder', 'MaskFlowNetSDecoder', 'NetE', 'ContextNet', 'RAFTDecoder', 'FlowRefine', 'OccRefine', 'OccShuffleUpsample', 'IRRPWCDecoder', 'MaskFlowNetDecoder', - 'GMADecoder' + 'GMADecoder', 'Flow1DDecoder' ] diff --git a/mmflow/models/decoders/flow1d_decoder.py b/mmflow/models/decoders/flow1d_decoder.py new file mode 100644 index 00000000..343467ca --- /dev/null +++ b/mmflow/models/decoders/flow1d_decoder.py @@ -0,0 +1,425 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmflow.ops import build_operators +from ..builder import DECODERS, build_loss +from ..utils.attention1d import Attention1D +from ..utils.correlation1d import Correlation1D +from .base_decoder import BaseDecoder +from .raft_decoder import ConvGRU, XHead + + +class MotionEncoderFlow1D(BaseModule): + """The module of motion encoder for Flow1D. + + An encoder which consists of several convolution layers and outputs + features as GRU's input. + + Args: + num_levels (int): Number of levels used when calculating correlation + tensor. Default: 4. + radius (int): Radius used when calculating correlation tensor. + Default: 4. + net_type (str): Type of the net. Choices: ['Basic', 'Small']. + Default: 'Basic'. 
+ """ + _corr_channels = {'Basic': (256, 192), 'Small': 96} + _corr_kernel = {'Basic': (1, 3), 'Small': 1} + _corr_padding = {'Basic': (0, 1), 'Small': 0} + + _flow_channels = {'Basic': (128, 64), 'Small': (64, 32)} + _flow_kernel = {'Basic': (7, 3), 'Small': (7, 3)} + _flow_padding = {'Basic': (3, 1), 'Small': (3, 1)} + + _out_channels = {'Basic': 126, 'Small': 80} + _out_kernel = {'Basic': 3, 'Small': 3} + _out_padding = {'Basic': 1, 'Small': 1} + + def __init__(self, + radius: int = 4, + net_type: str = 'Basic', + **kwargs) -> None: + super().__init__() + assert net_type in ['Basic', 'Small'] + corr_channels = self._corr_channels.get(net_type) if isinstance( + self._corr_channels[net_type], + (tuple, list)) else [self._corr_channels[net_type]] + corr_kernel = self._corr_kernel.get(net_type) if isinstance( + self._corr_kernel.get(net_type), + (tuple, list)) else [self._corr_kernel.get(net_type)] + corr_padding = self._corr_padding.get(net_type) if isinstance( + self._corr_padding.get(net_type), + (tuple, list)) else [self._corr_padding.get(net_type)] + + flow_channels = self._flow_channels.get(net_type) + flow_kernel = self._flow_kernel.get(net_type) + flow_padding = self._flow_padding.get(net_type) + + self.out_channels = self._out_channels.get(net_type) if isinstance( + self._out_channels.get(net_type), + (tuple, list)) else [self._out_channels.get(net_type)] + out_kernel = self._out_kernel.get(net_type) if isinstance( + self._out_kernel.get(net_type), + (tuple, list)) else [self._out_kernel.get(net_type)] + out_padding = self._out_padding.get(net_type) if isinstance( + self._out_padding.get(net_type), + (tuple, list)) else [self._out_padding.get(net_type)] + + corr_inch = 2 * (2 * radius + 1) + corr_net = self._make_encoder(corr_inch, corr_channels, corr_kernel, + corr_padding, **kwargs) + self.corr_net = nn.Sequential(*corr_net) + + flow_inch = 2 + flow_net = self._make_encoder(flow_inch, flow_channels, flow_kernel, + flow_padding, **kwargs) + self.flow_net = nn.Sequential(*flow_net) + + out_inch = corr_channels[-1] + flow_channels[-1] + out_net = self._make_encoder(out_inch, self.out_channels, out_kernel, + out_padding, **kwargs) + self.out_net = nn.Sequential(*out_net) + + def _make_encoder(self, in_channel: int, channels: int, kernels: int, + paddings: int, conv_cfg: dict, norm_cfg: dict, + act_cfg: dict) -> None: + encoder = [] + + for ch, k, p in zip(channels, kernels, paddings): + encoder.append( + ConvModule( + in_channels=in_channel, + out_channels=ch, + kernel_size=k, + padding=p, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + in_channel = ch + return encoder + + def forward(self, corr: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: + """Forward function for MotionEncoder. + + Args: + corr (Tensor): The correlation feature. + flow (Tensor): The last estimated optical flow. + + Returns: + Tensor: The output feature of motion encoder. + """ + corr_feat = self.corr_net(corr) + flow_feat = self.flow_net(flow) + + out = self.out_net(torch.cat([corr_feat, flow_feat], dim=1)) + return torch.cat([out, flow], dim=1) + + +class PositionEmbeddingSine(nn.Module): + """refer to the standard version of position embedding used by the + Attention is all you need paper, generalized to work on images. 
+ + https://github.com/facebookresearch/detr/blob/main/models/position_encod + """ + + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=True, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +@DECODERS.register_module() +class Flow1DDecoder(BaseDecoder): + """The decoder of Flow1D Net. + + The decoder of Flow1D Net, which outputs list of upsampled flow estimation. + + Args: + net_type (str): Type of the net. Choices: ['Basic', 'Small']. + radius (int): Radius used when calculating correlation tensor. + iters (int): Total iteration number of iterative update of RAFTDecoder. + corr_op_cfg (dict): Config dict of correlation operator. + Default: dict(type='CorrLookup'). + gru_type (str): Type of the GRU module. Choices: ['Conv', 'SeqConv']. + Default: 'SeqConv'. + feat_channels (Sequence(int)): features channels of prediction module. + mask_channels (int): Output channels of mask prediction layer. + Default: 64. + conv_cfg (dict, optional): Config dict of convolution layers in motion + encoder. Default: None. + norm_cfg (dict, optional): Config dict of norm layer in motion encoder. + Default: None. + act_cfg (dict, optional): Config dict of activation layer in motion + encoder. Default: None. 
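+        flow_loss (dict, optional): Config dict of the flow loss.
+            Default: None.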
+ """ + _h_channels = {'Basic': 128, 'Small': 96} + _cxt_channels = {'Basic': 128, 'Small': 64} + + def __init__( + self, + net_type: str, + radius: int, + iters: int, + corr_op_cfg: dict = dict(type='CorrLookupFlow1D', align_corners=True), + gru_type: str = 'SeqConv', + feat_channels: Union[int, Sequence[int]] = 256, + mask_channels: int = 64, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + act_cfg: Optional[dict] = None, + flow_loss: Optional[dict] = None, + ) -> None: + super().__init__() + assert net_type in ['Basic', 'Small'] + assert type(feat_channels) in (int, tuple, list) + self.attn_block_x = Attention1D( + in_channels=feat_channels, + y_attention=False, + double_cross_attn=True) + self.attn_block_y = Attention1D( + in_channels=feat_channels, + y_attention=True, + double_cross_attn=True) + self.corr_block = Correlation1D() + + feat_channels = feat_channels if isinstance(tuple, + list) else [feat_channels] + self.feat_channels = feat_channels + self.net_type = net_type + self.radius = radius + self.h_channels = self._h_channels.get(net_type) + self.cxt_channels = self._cxt_channels.get(net_type) + self.iters = iters + self.mask_channels = mask_channels * (2 * radius + 1) + corr_op_cfg['radius'] = radius + self.corr_lookup = build_operators(corr_op_cfg) + self.encoder = MotionEncoderFlow1D( + radius=radius, + net_type=net_type, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gru_type = gru_type + self.gru = self.make_gru_block() + self.flow_pred = XHead(self.h_channels, feat_channels, 2, x='flow') + + if net_type == 'Basic': + self.mask_pred = XHead( + self.h_channels, feat_channels, self.mask_channels, x='mask') + + if flow_loss is not None: + self.flow_loss = build_loss(flow_loss) + + def make_gru_block(self): + return ConvGRU( + self.h_channels, + self.encoder.out_channels[0] + 2 + self.cxt_channels, + net_type=self.gru_type) + + def _upsample(self, + flow: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex + combination. + + Args: + flow (Tensor): The optical flow with the shape [N, 2, H/8, W/8]. + mask (Tensor, optional): The learnable mask with shape + [N, grid_size x scale x scale, H/8, H/8]. + + Returns: + Tensor: The output optical flow with the shape [N, 2, H, W]. + """ + scale = 8 + grid_size = self.radius * 2 + 1 + grid_side = int(math.sqrt(grid_size)) + N, _, H, W = flow.shape + if mask is None: + new_size = (scale * H, scale * W) + return scale * F.interpolate( + flow, size=new_size, mode='bilinear', align_corners=True) + # predict a (Nx8×8×9xHxW) mask + mask = mask.view(N, 1, grid_size, scale, scale, H, W) + mask = torch.softmax(mask, dim=2) + + # extract local grid with 3x3 side padding = grid_side//2 + upflow = F.unfold(scale * flow, [grid_side, grid_side], padding=1) + # upflow with shape N, 2, 9, 1, 1, H, W + upflow = upflow.view(N, 2, grid_size, 1, 1, H, W) + + # take a weighted combination over the neighborhood grid 3x3 + # upflow with shape N, 2, 8, 8, H, W + upflow = torch.sum(mask * upflow, dim=2) + upflow = upflow.permute(0, 1, 4, 2, 5, 3) + return upflow.reshape(N, 2, scale * H, scale * W) + + def forward(self, feat1: torch.Tensor, feat2: torch.Tensor, + flow: torch.Tensor, h: torch.Tensor, + cxt_feat: torch.Tensor) -> Sequence[torch.Tensor]: + """Forward function for Flow1D. + + Args: + feat1 (Tensor): The feature from the first input image. + feat2 (Tensor): The feature from the second input image. 
+ flow (Tensor): The initialized flow when warm start. + h (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + + Returns: + Sequence[Tensor]: The list of predicted optical flow. + """ + pos_encoding = PositionEmbeddingSine(self.feat_channels[0] // 2) + position = pos_encoding(feat1) + + # attention + feat2_x, attn_x = self.attn_block_x(feat1, feat2, position, None) + feat2_y, attn_y = self.attn_block_y(feat1, feat2, position, None) + correlation_x = self.corr_block(feat1, feat2_x, False) + correlation_y = self.corr_block(feat1, feat2_y, True) + corrleation1d = [correlation_x, correlation_y] + upflow_preds = [] + delta_flow = torch.zeros_like(flow) + for _ in range(self.iters): + flow = flow.detach() + corr = self.corr_lookup(corrleation1d, flow) + motion_feat = self.encoder(corr, flow) + x = torch.cat([cxt_feat, motion_feat], dim=1) + h = self.gru(h, x) + + delta_flow = self.flow_pred(h) + flow = flow + delta_flow + + if hasattr(self, 'mask_pred'): + mask = .25 * self.mask_pred(h) + else: + mask = None + + upflow = self._upsample(flow, mask) + upflow_preds.append(upflow) + + return upflow_preds + + def forward_train( + self, + feat1: torch.Tensor, + feat2: torch.Tensor, + flow: torch.Tensor, + h_feat: torch.Tensor, + cxt_feat: torch.Tensor, + flow_gt: torch.Tensor, + valid: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """Forward function when model training. + + Args: + feat1 (Tensor): The feature from the first input image. + feat2 (Tensor): The feature from the second input image. + flow (Tensor): The last estimated flow from GRU cell. + h (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + flow_gt (Tensor): The ground truth of optical flow. + Defaults to None. + valid (Tensor, optional): The valid mask. Defaults to None. + + Returns: + Dict[str, Tensor]: The losses of model. + """ + + flow_pred = self.forward(feat1, feat2, flow, h_feat, cxt_feat) + + return self.losses(flow_pred, flow_gt, valid=valid) + + def forward_test(self, + feat1: torch.Tensor, + feat2: torch.Tensor, + flow: torch.Tensor, + h_feat: torch.Tensor, + cxt_feat: torch.Tensor, + img_metas=None) -> Sequence[Dict[str, np.ndarray]]: + """Forward function when model training. + + Args: + feat1 (Tensor): The feature from the first input image. + feat2 (Tensor): The feature from the second input image. + flow (Tensor): The last estimated flow from GRU cell. + h_feat (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + img_metas (Sequence[dict], optional): meta data of image to revert + the flow to original ground truth size. Defaults to None. + + Returns: + Sequence[Dict[str, ndarray]]: The batch of predicted optical flow + with the same size of images before augmentation. + """ + flow_pred = self.forward(feat1, feat2, flow, h_feat, cxt_feat) + + flow_result = flow_pred[-1] + # flow maps with the shape [H, W, 2] + flow_result = flow_result.permute(0, 2, 3, 1).cpu().data.numpy() + # unravel batch dim + flow_result = list(flow_result) + flow_result = [dict(flow=f) for f in flow_result] + return self.get_flow(flow_result, img_metas=img_metas) + + def losses(self, + flow_pred: Sequence[torch.Tensor], + flow_gt: torch.Tensor, + valid: torch.Tensor = None) -> Dict[str, torch.Tensor]: + """Compute optical flow loss. + + Args: + flow_pred (Sequence[Tensor]): The list of predicted optical flow. + flow_gt (Tensor): The ground truth of optical flow. 
+ valid (Tensor, optional): The valid mask. Defaults to None. + + Returns: + Dict[str, Tensor]: The dict of losses. + """ + + loss = dict() + loss['loss_flow'] = self.flow_loss(flow_pred, flow_gt, valid) + return loss diff --git a/mmflow/models/flow_estimators/__init__.py b/mmflow/models/flow_estimators/__init__.py index 7020cff2..7e8052d2 100644 --- a/mmflow/models/flow_estimators/__init__.py +++ b/mmflow/models/flow_estimators/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .flow1d import Flow1D from .flownet import FlowNetC, FlowNetS from .flownet2 import FlowNet2, FlowNetCSS from .irrpwc import IRRPWC @@ -9,5 +10,5 @@ __all__ = [ 'FlowNetC', 'FlowNetS', 'LiteFlowNet', 'PWCNet', 'MaskFlowNetS', 'RAFT', - 'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet' + 'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet', 'Flow1D' ] diff --git a/mmflow/models/flow_estimators/flow1d.py b/mmflow/models/flow_estimators/flow1d.py new file mode 100644 index 00000000..49b287cd --- /dev/null +++ b/mmflow/models/flow_estimators/flow1d.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from numpy import ndarray + +from ..builder import FLOW_ESTIMATORS, build_encoder +from .pwcnet import PWCNet + + +@FLOW_ESTIMATORS.register_module() +class Flow1D(PWCNet): + """Flow1D model. + + Args: + radius (int): Number of radius in . + cxt_channels (int): Number of channels of context feature. + h_channels (int): Number of channels of hidden feature in . + cxt_encoder (dict): Config dict for building context encoder. + freeze_bn (bool, optional): Whether to freeze batchnorm layer or not. + Default: False. + """ + + def __init__(self, + radius: int, + cxt_channels: int, + h_channels: int, + cxt_encoder: dict, + freeze_bn: bool = False, + **kwargs) -> None: + super().__init__(**kwargs) + self.radius = radius + self.context = build_encoder(cxt_encoder) + self.h_channels = h_channels + self.cxt_channels = cxt_channels + + assert self.radius == self.decoder.radius + assert self.h_channels == self.decoder.h_channels + assert self.cxt_channels == self.decoder.cxt_channels + assert self.h_channels + self.cxt_channels == self.context.out_channels + if freeze_bn: + self.freeze_bn() + + def freeze_bn(self) -> None: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def extract_feat( + self, imgs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Extract features from images. + + Args: + imgs (Tensor): The concatenated input images. + + Returns: + Tuple[Tensor, Tensor, Tensor, Tensor]: The feature from the first + image, the feature from the second image, the hidden state + feature for GRU cell and the contextual feature. + """ + in_channels = self.encoder.in_channels + img1 = imgs[:, :in_channels, ...] + img2 = imgs[:, in_channels:, ...] + + feat1 = self.encoder(img1) + feat2 = self.encoder(img2) + cxt_feat = self.context(img1) + + h_feat, cxt_feat = torch.split( + cxt_feat, [self.h_channels, self.cxt_channels], dim=1) + h_feat = torch.tanh(h_feat) + cxt_feat = torch.relu(cxt_feat) + + return feat1, feat2, h_feat, cxt_feat + + def forward_train( + self, + imgs: torch.Tensor, + flow_gt: torch.Tensor, + valid: torch.Tensor, + flow_init: Optional[torch.Tensor] = None, + img_metas: Optional[Sequence[dict]] = None + ) -> Dict[str, torch.Tensor]: + """Forward function for Flow1D when model training. 
+
+        Args:
+            imgs (Tensor): The concatenated input images.
+            flow_gt (Tensor): The ground truth of optical flow.
+            valid (Tensor, optional): The valid mask. Defaults to None.
+            flow_init (Tensor, optional): The initialized flow when warm start.
+                Defaults to None.
+            img_metas (Sequence[dict], optional): meta data of image to revert
+                the flow to original ground truth size. Defaults to None.
+
+        Returns:
+            Dict[str, Tensor]: The losses of model.
+        """
+
+        feat1, feat2, h_feat, cxt_feat = self.extract_feat(imgs)
+        B, _, H, W = feat1.shape
+
+        if flow_init is None:
+            flow_init = torch.zeros((B, 2, H, W), device=feat1.device)
+
+        return self.decoder.forward_train(
+            feat1,
+            feat2,
+            flow=flow_init,
+            h_feat=h_feat,
+            cxt_feat=cxt_feat,
+            flow_gt=flow_gt,
+            valid=valid)
+
+    def forward_test(
+            self,
+            imgs: torch.Tensor,
+            flow_init: Optional[torch.Tensor] = None,
+            img_metas: Optional[Sequence[dict]] = None
+    ) -> Sequence[Dict[str, ndarray]]:
+        """Forward function for Flow1D when model testing.
+
+        Args:
+            imgs (Tensor): The concatenated input images.
+            flow_init (Tensor, optional): The initialized flow when warm start.
+                Defaults to None.
+            img_metas (Sequence[dict], optional): meta data of image to revert
+                the flow to original ground truth size. Defaults to None.
+
+        Returns:
+            Sequence[Dict[str, ndarray]]: the batch of predicted optical flow
+                with the same size of images after augmentation.
+        """
+        train_iter = self.decoder.iters
+        if self.test_cfg is not None and self.test_cfg.get(
+                'iters') is not None:
+            self.decoder.iters = self.test_cfg.get('iters')
+
+        feat1, feat2, h_feat, cxt_feat = self.extract_feat(imgs)
+        B, _, H, W = feat1.shape
+
+        if flow_init is None:
+            flow_init = torch.zeros((B, 2, H, W), device=feat1.device)
+
+        results = self.decoder.forward_test(
+            feat1=feat1,
+            feat2=feat2,
+            flow=flow_init,
+            h_feat=h_feat,
+            cxt_feat=cxt_feat,
+            img_metas=img_metas)
+        # recover the number of iterations used in training
+        self.decoder.iters = train_iter
+
+        return results
diff --git a/tests/test_models/test_decoder/test_flow1d_decoder.py b/tests/test_models/test_decoder/test_flow1d_decoder.py
new file mode 100644
index 00000000..bded217f
--- /dev/null
+++ b/tests/test_models/test_decoder/test_flow1d_decoder.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmflow.models.decoders.flow1d_decoder import (Flow1DDecoder,
+                                                   MotionEncoderFlow1D)
+
+
+@pytest.mark.parametrize('net_type', ['Basic', 'Small'])
+def test_motion_encoder(net_type):
+
+    # test invalid net_type
+    with pytest.raises(AssertionError):
+        MotionEncoderFlow1D(net_type='invalid value')
+
+    module = MotionEncoderFlow1D(
+        net_type=net_type, conv_cfg=None, norm_cfg=None, act_cfg=None)
+    radius = 4
+
+    input_corr = torch.randn((1, 2 * (2 * radius + 1), 56, 56))
+    input_flow = torch.randn((1, 2, 56, 56))
+
+    corr_feat = module.corr_net(input_corr)
+    flow_feat = module.flow_net(input_flow)
+    out_feat = module.out_net(torch.cat([corr_feat, flow_feat], dim=1))
+
+    if net_type == 'Basic':
+        assert corr_feat.shape == torch.Size((1, 192, 56, 56))
+        assert flow_feat.shape == torch.Size((1, 64, 56, 56))
+        assert out_feat.shape == torch.Size((1, 126, 56, 56))
+    elif net_type == 'Small':
+        assert corr_feat.shape == torch.Size((1, 96, 56, 56))
+        assert flow_feat.shape == torch.Size((1, 32, 56, 56))
+        assert out_feat.shape == torch.Size((1, 80, 56, 56))
+
+
+def test_flow1d_decoder():
+    model = Flow1DDecoder(
+        net_type='Basic',
+        radius=4,
+        iters=12,
+        flow_loss=dict(type='SequenceLoss'))
+    mask = torch.ones((1, 64 * 9, 10, 10))
+    flow = torch.randn((1, 2, 10, 10))
+    assert model._upsample(flow, mask).shape == torch.Size((1, 2, 80, 80))
+
+    feat1 = torch.randn(1, 256, 8, 8)
+    feat2 = torch.randn(1, 256, 8, 8)
+    h_feat = torch.randn(1, 128, 8, 8)
+    cxt_feat = torch.randn(1, 128, 8, 8)
+    flow = torch.zeros((1, 2, 8, 8))
+
+    flow_gt = torch.randn(1, 2, 64, 64)
+    # test forward function
+    out = model(feat1, feat2, flow, h_feat, cxt_feat)
+    assert isinstance(out, list)
+    assert out[0].shape == torch.Size((1, 2, 64, 64))
+
+    # test forward train
+    loss = model.forward_train(
+        feat1, feat2, flow, h_feat, cxt_feat, flow_gt=flow_gt)
+    assert float(loss['loss_flow']) > 0.
+
+    # test forward test
+    out = model.forward_test(feat1, feat2, flow, h_feat, cxt_feat)
+    assert out[0]['flow'].shape == (64, 64, 2)
diff --git a/tests/test_models/test_flow_estimator.py b/tests/test_models/test_flow_estimator.py
index dc3016dd..eb1db15e 100644
--- a/tests/test_models/test_flow_estimator.py
+++ b/tests/test_models/test_flow_estimator.py
@@ -44,6 +44,7 @@ def test_flow_estimator(cfg_file):
 
 @pytest.mark.parametrize('cfg_file', [
     '../../configs/_base_/models/raft.py',
+    '../../configs/_base_/models/flow1d.py',
     '../../configs/_base_/models/flownets.py',
     '../../configs/_base_/models/flownet2/flownet2sd.py',
     '../../configs/_base_/models/gma/gma.py',
@@ -57,7 +58,7 @@ def test_flow_estimator_without_cuda(cfg_file):
     cfg_file = osp.join(osp.dirname(__file__), cfg_file)
    cfg = Config.fromfile(cfg_file)
 
-    if cfg.model.type == 'RAFT':
+    if cfg.model.type in ('RAFT', 'Flow1D'):
         # Replace SyncBN with BN to inference on CPU
         cfg.model.cxt_encoder.norm_cfg = dict(type='BN', requires_grad=True)
 

From e0897450752aafa4c8730614377fd5178ad3eb15 Mon Sep 17 00:00:00 2001
From: Fc-idris <23311590278@qq.com>
Date: Mon, 31 Oct 2022 18:31:24 +0800
Subject: [PATCH 5/5] [feat]: update flow1d training config. Optimize flow1d_decoder.py.
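
The default lookup radius of Flow1D is raised from 4 to 32, both in
configs/_base_/models/flow1d.py and in MotionEncoderFlow1D.
MotionEncoderFlow1D now inherits from RAFT's MotionEncoder and only rebuilds
the correlation branch, whose input has 2 * (2 * radius + 1) channels (130
for radius 32). The convex-upsampling mask keeps a fixed 3x3 neighbourhood,
so mask_channels * 9 and grid_size = 9 no longer depend on the radius. A
FlyingThings3D 400x720 dataset config and training configs for
FlyingThings3D, FlyingThings3D + Sintel and the mixed dataset are added,
using AdamW with a OneCycle schedule, and the residual additions in
res_layer.py are rewritten out-of-place.

For reference only (not part of the patch), a minimal shape check of the new
default radius; it assumes the patched mmflow package is importable, and the
cost-volume shapes follow the Correlation1D / CorrLookupFlow1D code added in
the first patch of this series:

    import torch

    from mmflow.ops.corr_lookup import CorrLookupFlow1D

    B, H, W, radius = 2, 8, 8, 32
    # 1D cost volumes, shaped as Correlation1D produces them:
    # x direction [B*H*W, 1, 1, W], y direction [B*H*W, 1, H, 1]
    corr_x = torch.randn(B * H * W, 1, 1, W)
    corr_y = torch.randn(B * H * W, 1, H, 1)
    flow = torch.zeros(B, 2, H, W)

    lookup = CorrLookupFlow1D(radius=radius)
    cost = lookup([corr_x, corr_y], flow)
    # x and y lookups are concatenated: 2 * (2 * 32 + 1) = 130 channels
    assert cost.shape == (B, 2 * (2 * radius + 1), H, W)
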
--- .../datasets/flyingthings3d_flow1d_400x720.py | 106 ++++++++++++++++++ configs/_base_/models/flow1d.py | 4 +- .../flow1d_8x2_100k_flyingthings3d_400x720.py | 29 +++++ ..._8x2_100k_flyingthings3d_sintel_368x768.py | 41 +++++++ .../flow1d/flow1d_8x2_100k_mixed_368x768.py | 42 +++++++ mmflow/models/decoders/flow1d_decoder.py | 86 ++------------ mmflow/models/utils/res_layer.py | 4 +- 7 files changed, 229 insertions(+), 83 deletions(-) create mode 100644 configs/_base_/datasets/flyingthings3d_flow1d_400x720.py create mode 100644 configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py create mode 100644 configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py create mode 100644 configs/flow1d/flow1d_8x2_100k_mixed_368x768.py diff --git a/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py b/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py new file mode 100644 index 00000000..915ca93b --- /dev/null +++ b/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py @@ -0,0 +1,106 @@ +train_dataset_type = 'FlyingThings3D' +train_data_root = 'data/flyingthings3d' +test_dataset_type = 'Sintel' +test_data_root = 'data/Sintel' + +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='ColorJitter', + asymmetric_prob=0.2, + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.5 / 3.14), + dict(type='Erase', prob=0.5, bounds=[50, 100], max_num=3), + dict( + type='SpacialTransform', + spacial_prob=0.8, + stretch_prob=0.8, + crop_size=(400, 720), + min_scale=-0.4, + max_scale=0.8, + max_stretch=0.2), + dict(type='RandomCrop', crop_size=(400, 720)), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='RandomFlip', prob=0.1, direction='vertical'), + dict(type='Validation', max_flow=1000.), + dict(type='Normalize', **img_norm_cfg), + dict(type='DefaultFormatBundle'), + dict( + type='Collect', + keys=['imgs', 'flow_gt', 'valid'], + meta_keys=[ + 'filename1', 'filename2', 'ori_filename1', 'ori_filename2', + 'filename_flow', 'ori_filename_flow', 'ori_shape', 'img_shape', + 'erase_bounds', 'erase_num', 'scale_factor' + ]) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='InputPad', exponent=3), + dict(type='Normalize', **img_norm_cfg), + dict(type='TestFormatBundle'), + dict( + type='Collect', + keys=['imgs'], + meta_keys=[ + 'flow_gt', 'filename1', 'filename2', 'ori_filename1', + 'ori_filename2', 'ori_shape', 'img_shape', 'img_norm_cfg', + 'scale_factor', 'pad_shape', 'pad' + ]) +] + +train_dataset_cleanpass = dict( + type=train_dataset_type, + data_root=train_data_root, + pipeline=train_pipeline, + test_mode=False, + pass_style='clean', + scene='left') +train_dataset_finalpass = dict( + type=train_dataset_type, + data_root=train_data_root, + pipeline=train_pipeline, + test_mode=False, + pass_style='final', + scene='left') +test_data_cleanpass = dict( + type=test_dataset_type, + data_root=test_data_root, + pipeline=test_pipeline, + test_mode=True, + pass_style='clean') +test_data_finalpass = dict( + type=test_dataset_type, + data_root=test_data_root, + pipeline=test_pipeline, + test_mode=True, + pass_style='final') + +data = dict( + train_dataloader=dict( + samples_per_gpu=2, + workers_per_gpu=2, + shuffle=False, + drop_last=True, + persistent_workers=True), + val_dataloader=dict( + samples_per_gpu=1, + workers_per_gpu=2, + shuffle=False, + persistent_workers=True), + 
test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2, shuffle=False), + train=[train_dataset_cleanpass, train_dataset_finalpass], + val=dict( + type='ConcatDataset', + datasets=[test_data_cleanpass, test_data_finalpass], + separate_eval=True), + test=dict( + type='ConcatDataset', + datasets=[test_data_cleanpass, test_data_finalpass], + separate_eval=True)) diff --git a/configs/_base_/models/flow1d.py b/configs/_base_/models/flow1d.py index b2450184..6a347a0c 100644 --- a/configs/_base_/models/flow1d.py +++ b/configs/_base_/models/flow1d.py @@ -1,6 +1,6 @@ model = dict( type='Flow1D', - radius=4, + radius=32, cxt_channels=128, h_channels=128, encoder=dict( @@ -34,7 +34,7 @@ decoder=dict( type='Flow1DDecoder', net_type='Basic', - radius=4, + radius=32, iters=12, corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), gru_type='SeqConv', diff --git a/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py new file mode 100644 index 00000000..bc711844 --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py @@ -0,0 +1,29 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/flyingthings3d_raft_400x720.py', + '../_base_/default_runtime.py' +] + +model = dict(freeze_bn=True, test_cfg=dict(iters=32)) + +optimizer = dict( + type='AdamW', + lr=0.000125, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.00001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.000125, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') + +# Train on FlyingChairs and finetune on FlyingThings3D +load_from = "work-dir/flow1d/flyingchair/latest.pth" diff --git a/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py new file mode 100644 index 00000000..640410f4 --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py @@ -0,0 +1,41 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/sintel_cleanx100_sintel_fianlx100_flyingthings3d_raft_368x768.py', # noqa + '../_base_/default_runtime.py' +] + +model = dict( + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=32, + iters=12, + corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss', gamma=0.85), + act_cfg=dict(type='ReLU')), + freeze_bn=True, + test_cfg=dict(iters=32)) + +optimizer = dict( + type='AdamW', + lr=0.000125, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.00001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.000125, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') + +# Train on FlyingChairs and FlyingThings3D, and finetune on FlyingThings3D +# and Sintel +load_from = "work-dir/flow1d/flyingthingd/latest.pth" \ No newline at end of file diff --git a/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py b/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py new file mode 100644 index 00000000..ca12fad7 --- /dev/null +++ 
b/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py
@@ -0,0 +1,42 @@
+_base_ = [
+    '../_base_/models/flow1d.py',
+    '../_base_/datasets/sintel_cleanx100_sintel_fianlx100_kitti2015x200_hd1kx5_flyingthings3d_raft_384x768.py',  # noqa
+    '../_base_/default_runtime.py'
+]
+
+model = dict(
+    decoder=dict(
+        type='Flow1DDecoder',
+        net_type='Basic',
+        radius=32,
+        iters=12,
+        corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True),
+        gru_type='SeqConv',
+        flow_loss=dict(type='SequenceLoss', gamma=0.85),
+        act_cfg=dict(type='ReLU')),
+    freeze_bn=True,
+    test_cfg=dict(iters=32))
+
+optimizer = dict(
+    type='AdamW',
+    lr=0.000125,
+    betas=(0.9, 0.999),
+    eps=1e-08,
+    weight_decay=0.00001,
+    amsgrad=False)
+optimizer_config = dict(grad_clip=dict(max_norm=1.))
+lr_config = dict(
+    policy='OneCycle',
+    max_lr=0.000125,
+    total_steps=100100,
+    pct_start=0.05,
+    anneal_strategy='linear')
+
+runner = dict(type='IterBasedRunner', max_iters=100000)
+checkpoint_config = dict(by_epoch=False, interval=10000)
+evaluation = dict(interval=10000, metric='EPE')
+
+# Train on FlyingChairs and FlyingThings3D, and finetune on the mixed dataset
+# of FlyingThings3D, Sintel, KITTI2015 and HD1K
+load_from = "work-dir/flow1d/sintel/latest.pth"
+
diff --git a/mmflow/models/decoders/flow1d_decoder.py b/mmflow/models/decoders/flow1d_decoder.py
index 343467ca..a97cf5b1 100644
--- a/mmflow/models/decoders/flow1d_decoder.py
+++ b/mmflow/models/decoders/flow1d_decoder.py
@@ -14,41 +14,27 @@
 from ..utils.attention1d import Attention1D
 from ..utils.correlation1d import Correlation1D
 from .base_decoder import BaseDecoder
-from .raft_decoder import ConvGRU, XHead
+from .raft_decoder import ConvGRU, XHead, MotionEncoder
 
 
-class MotionEncoderFlow1D(BaseModule):
+class MotionEncoderFlow1D(MotionEncoder):
     """The module of motion encoder for Flow1D.
 
     An encoder which consists of several convolution layers and outputs
     features as GRU's input.
 
     Args:
-        num_levels (int): Number of levels used when calculating correlation
-            tensor. Default: 4.
         radius (int): Radius used when calculating correlation tensor.
-            Default: 4.
+            Default: 32.
         net_type (str): Type of the net. Choices: ['Basic', 'Small'].
             Default: 'Basic'.
""" - _corr_channels = {'Basic': (256, 192), 'Small': 96} - _corr_kernel = {'Basic': (1, 3), 'Small': 1} - _corr_padding = {'Basic': (0, 1), 'Small': 0} - - _flow_channels = {'Basic': (128, 64), 'Small': (64, 32)} - _flow_kernel = {'Basic': (7, 3), 'Small': (7, 3)} - _flow_padding = {'Basic': (3, 1), 'Small': (3, 1)} - - _out_channels = {'Basic': 126, 'Small': 80} - _out_kernel = {'Basic': 3, 'Small': 3} - _out_padding = {'Basic': 1, 'Small': 1} def __init__(self, - radius: int = 4, + radius: int = 32, net_type: str = 'Basic', **kwargs) -> None: - super().__init__() - assert net_type in ['Basic', 'Small'] + super().__init__(radius=radius, net_type=net_type, **kwargs) corr_channels = self._corr_channels.get(net_type) if isinstance( self._corr_channels[net_type], (tuple, list)) else [self._corr_channels[net_type]] @@ -59,69 +45,11 @@ def __init__(self, self._corr_padding.get(net_type), (tuple, list)) else [self._corr_padding.get(net_type)] - flow_channels = self._flow_channels.get(net_type) - flow_kernel = self._flow_kernel.get(net_type) - flow_padding = self._flow_padding.get(net_type) - - self.out_channels = self._out_channels.get(net_type) if isinstance( - self._out_channels.get(net_type), - (tuple, list)) else [self._out_channels.get(net_type)] - out_kernel = self._out_kernel.get(net_type) if isinstance( - self._out_kernel.get(net_type), - (tuple, list)) else [self._out_kernel.get(net_type)] - out_padding = self._out_padding.get(net_type) if isinstance( - self._out_padding.get(net_type), - (tuple, list)) else [self._out_padding.get(net_type)] - corr_inch = 2 * (2 * radius + 1) corr_net = self._make_encoder(corr_inch, corr_channels, corr_kernel, corr_padding, **kwargs) self.corr_net = nn.Sequential(*corr_net) - flow_inch = 2 - flow_net = self._make_encoder(flow_inch, flow_channels, flow_kernel, - flow_padding, **kwargs) - self.flow_net = nn.Sequential(*flow_net) - - out_inch = corr_channels[-1] + flow_channels[-1] - out_net = self._make_encoder(out_inch, self.out_channels, out_kernel, - out_padding, **kwargs) - self.out_net = nn.Sequential(*out_net) - - def _make_encoder(self, in_channel: int, channels: int, kernels: int, - paddings: int, conv_cfg: dict, norm_cfg: dict, - act_cfg: dict) -> None: - encoder = [] - - for ch, k, p in zip(channels, kernels, paddings): - encoder.append( - ConvModule( - in_channels=in_channel, - out_channels=ch, - kernel_size=k, - padding=p, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg)) - in_channel = ch - return encoder - - def forward(self, corr: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: - """Forward function for MotionEncoder. - - Args: - corr (Tensor): The correlation feature. - flow (Tensor): The last estimated optical flow. - - Returns: - Tensor: The output feature of motion encoder. - """ - corr_feat = self.corr_net(corr) - flow_feat = self.flow_net(flow) - - out = self.out_net(torch.cat([corr_feat, flow_feat], dim=1)) - return torch.cat([out, flow], dim=1) - class PositionEmbeddingSine(nn.Module): """refer to the standard version of position embedding used by the @@ -235,7 +163,7 @@ def __init__( self.h_channels = self._h_channels.get(net_type) self.cxt_channels = self._cxt_channels.get(net_type) self.iters = iters - self.mask_channels = mask_channels * (2 * radius + 1) + self.mask_channels = mask_channels * 9 corr_op_cfg['radius'] = radius self.corr_lookup = build_operators(corr_op_cfg) self.encoder = MotionEncoderFlow1D( @@ -276,7 +204,7 @@ def _upsample(self, Tensor: The output optical flow with the shape [N, 2, H, W]. 
""" scale = 8 - grid_size = self.radius * 2 + 1 + grid_size = 9 grid_side = int(math.sqrt(grid_size)) N, _, H, W = flow.shape if mask is None: diff --git a/mmflow/models/utils/res_layer.py b/mmflow/models/utils/res_layer.py index 14b60d46..6e8ff86a 100644 --- a/mmflow/models/utils/res_layer.py +++ b/mmflow/models/utils/res_layer.py @@ -78,7 +78,7 @@ def _inner_forward(x): if self.downsample is not None: identity = self.downsample(x) - out += identity + out = out + identity return out @@ -288,7 +288,7 @@ def _inner_forward(x): if self.downsample is not None: identity = self.downsample(x) - out += identity + out = out + identity return out