# File: configs/_base_/models/spynet.py
# NOTE: the same change set also imports `SpyNetDecoder` and `SpyNet` in
# mmflow/models/__init__.py and appends both names to its `__all__`.

# Base model config for SpyNet. The estimator and its decoder each receive
# their own list of pyramid level names.
model = {
    'type': 'SpyNet',
    # Channels of a single image of the input pair.
    'img_channels': 3,
    'pyramid_levels': [
        'level0', 'level1', 'level2', 'level3', 'level4', 'level5'
    ],
    'decoder': {
        'type': 'SpyNetDecoder',
        # The decoder concatenates img1, the warped img2 and the current
        # 2-channel flow: 3 + 3 + 2 = 8 input channels.
        'in_channels': 8,
        'pyramid_levels': [
            'level0', 'level1', 'level2', 'level3', 'level4', 'level5'
        ],
        # Output channels of the successive convolutions of every level;
        # the last entry (2) is the residual flow.
        'out_channels': (32, 64, 32, 16, 2),
        'kernel_size': 7,
        'stride': 1,
        'warp_cfg': {'type': 'Warp', 'align_corners': True},
        'act_cfg': {'type': 'ReLU'},
    },
}
# File: mmflow/models/decoders/spynet_decoder.py
# NOTE: the same change set also imports `SpyNetDecoder` in
# mmflow/models/decoders/__init__.py and appends it to that `__all__`.
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from mmflow.ops.builder import build_operators
from ..builder import DECODERS
from .base_decoder import BaseDecoder


class BasicLayers(BaseModule):
    """Convolution stack predicting the residual flow of one pyramid level.

    Every entry of ``out_channels`` except the last becomes a ``ConvModule``
    built with ``act_cfg``; the last entry becomes a plain ``nn.Conv2d`` so
    the prediction is not passed through an activation.

    Args:
        in_channels (int): Number of channels of the input tensor.
        out_channels (Sequence[int]): Output channels of the successive
            convolutions. Defaults to (32, 64, 32, 16, 2).
        kernel_size (int): Kernel size of every convolution; the padding is
            ``kernel_size // 2``. Defaults to 7.
        stride (int): Stride of the ``ConvModule`` layers. The final
            ``nn.Conv2d`` keeps its default stride. Defaults to 1.
        act_cfg (dict): Activation config handed to every ``ConvModule``.
            Defaults to dict(type='ReLU', inplace=False).
        init_cfg (dict or list, optional): Initialization config forwarded
            to ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: Sequence[int] = (32, 64, 32, 16, 2),
                 kernel_size: int = 7,
                 stride: int = 1,
                 act_cfg: dict = dict(type='ReLU', inplace=False),
                 init_cfg: Optional[Union[dict, list]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        padding = kernel_size // 2
        convs = []
        in_ch = in_channels
        for out_ch in out_channels[:-1]:
            convs.append(
                ConvModule(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    act_cfg=act_cfg))
            in_ch = out_ch
        # Last layer: bare convolution, no activation on the prediction.
        convs.append(
            nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_channels[-1],
                kernel_size=kernel_size,
                padding=padding))
        self.layers = nn.Sequential(*convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the convolution stack."""
        return self.layers(x)


@DECODERS.register_module()
class SpyNetDecoder(BaseDecoder):
    """Coarse-to-fine decoder that refines the flow with one net per level.

    Args:
        in_channels (int): Input channels of every per-level network, i.e.
            the channels of ``cat(img1, warped_img2, flow)``.
        pyramid_levels (Sequence[str]): Names of the pyramid levels. A
            sorted copy is stored; the levels are processed from the last
            sorted name to the first.
        out_channels (Sequence[int]): Output channels of the convolutions
            of every per-level network. Defaults to (32, 64, 32, 16, 2).
        kernel_size (int): Kernel size of the convolutions. Defaults to 7.
        stride (int): Stride of the convolutions. Defaults to 1.
        warp_cfg (dict): Config of the warp operator, built with
            ``build_operators``.
            Defaults to dict(type='Warp', align_corners=True).
        act_cfg (dict): Activation config of the per-level networks.
            Defaults to dict(type='ReLU').
        init_cfg (dict or list, optional): Initialization config.
            Defaults to None.
        flow_loss (dict, optional): Config of the loss called by
            :meth:`losses`. It is needed for training unless ``flow_loss``
            is provided elsewhere (e.g. by a parent class).
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 pyramid_levels: Sequence[str],
                 out_channels: Sequence[int] = (32, 64, 32, 16, 2),
                 kernel_size: int = 7,
                 stride: int = 1,
                 warp_cfg: dict = dict(type='Warp', align_corners=True),
                 act_cfg: dict = dict(type='ReLU'),
                 init_cfg: Optional[Union[dict, list]] = None,
                 flow_loss: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        # Keep a sorted copy; sorting in place would silently reorder the
        # list owned by the caller (usually the config).
        self.pyramid_levels = sorted(pyramid_levels)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.act_cfg = act_cfg

        self.warp = build_operators(warp_cfg)

        if flow_loss is not None:
            # Imported lazily so that nothing changes when no loss is
            # configured.
            # NOTE(review): assumes `build_loss` lives in `..builder` next
            # to `DECODERS` - confirm.
            from ..builder import build_loss
            self.flow_loss = build_loss(flow_loss)

        self.decoders = nn.ModuleDict(
            {level: self.make_layers()
             for level in self.pyramid_levels})

    def make_layers(self) -> BasicLayers:
        """Build the network of one level from the configured settings."""
        return BasicLayers(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            act_cfg=self.act_cfg)

    def forward(
        self, imgs1: Dict[str, torch.Tensor], imgs2: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str,
                                                           torch.Tensor]]:
        """Estimate the flow from the last pyramid level to the first.

        The flow is upsampled by a factor of 2 between consecutive levels
        and then concatenated with that level's image, so each level is
        expected to be exactly twice as large as the next one.

        Args:
            imgs1 (Dict[str, Tensor]): First image at every pyramid level,
                each of shape (N, C, H, W).
            imgs2 (Dict[str, Tensor]): Second image at every pyramid level.

        Returns:
            tuple: The final flow, a dict with the residual flow predicted
            at every level, and a dict with the flow every level started
            from (before its residual was added).
        """
        flow = None

        residual_flow_preds = dict()
        previous_flow_preds = dict()
        for level in reversed(self.pyramid_levels):

            img1 = imgs1[level]
            img2 = imgs2[level]
            N, _, H, W = img1.shape

            if flow is None:
                # Start from zero flow with the batch size, dtype and
                # device of the input images.
                flow = img1.new_zeros((N, 2, H, W))
            else:
                # Twice the resolution means displacements are twice as
                # long when measured in pixels.
                flow = F.interpolate(
                    flow,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=False) * 2.0

            warped_img2 = self.warp(img2, flow)
            previous_flow_preds[level] = flow

            in_feat = torch.cat((img1, warped_img2, flow), dim=1)

            residual_flow = self.decoders[level](in_feat)
            # Out-of-place add: an in-place update would also overwrite
            # the tensor just stored in `previous_flow_preds`.
            flow = flow + residual_flow

            residual_flow_preds[level] = residual_flow

        return flow, residual_flow_preds, previous_flow_preds

    def losses(
            self,
            residual_flow_preds: Dict[str, torch.Tensor],
            previous_flow_preds: Dict[str, torch.Tensor],
            flow_gt: torch.Tensor,
            valid: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Compute optical flow loss.

        Args:
            residual_flow_preds (Dict[str, Tensor]): Residual flow
                predicted at every pyramid level.
            previous_flow_preds (Dict[str, Tensor]): Flow every pyramid
                level started from, before its residual was added.
            flow_gt (Tensor): The ground truth of optical flow.
            valid (Tensor, optional): The valid mask. Defaults to None.

        Returns:
            Dict[str, Tensor]: The dict of losses.
        """
        loss = dict()
        loss['loss_flow'] = self.flow_loss(residual_flow_preds,
                                           previous_flow_preds, flow_gt, valid)
        return loss

    def forward_train(self, imgs1, imgs2, flow_gt, valid=None):
        """Run the decoder and return the training losses.

        Args:
            imgs1 (Dict[str, Tensor]): First image at every pyramid level.
            imgs2 (Dict[str, Tensor]): Second image at every pyramid level.
            flow_gt (Tensor): The ground truth of optical flow.
            valid (Tensor, optional): The valid mask. Defaults to None.

        Returns:
            Dict[str, Tensor]: The dict of losses.
        """
        _, residual_flow_preds, previous_flow_preds = self.forward(
            imgs1=imgs1, imgs2=imgs2)

        return self.losses(
            residual_flow_preds=residual_flow_preds,
            previous_flow_preds=previous_flow_preds,
            flow_gt=flow_gt,
            valid=valid)

    def forward_test(self, imgs1, imgs2, img_metas=None):
        """Run the decoder and post-process the final flow.

        Args:
            imgs1 (Dict[str, Tensor]): First image at every pyramid level.
            imgs2 (Dict[str, Tensor]): Second image at every pyramid level.
            img_metas (optional): Image meta information handed to
                ``get_flow``. Defaults to None.

        Returns:
            The result of ``get_flow`` applied to one ``dict(flow=...)``
            per sample, each flow being a channel-last numpy array.
        """
        flow, _, _ = self.forward(imgs1=imgs1, imgs2=imgs2)
        # (N, 2, H, W) tensor -> (N, H, W, 2) numpy array on the CPU.
        flow_result = flow.permute(0, 2, 3, 1).detach().cpu().numpy()

        # unravel batch dim
        flow_result = [dict(flow=f) for f in flow_result]

        return self.get_flow(flow_result, img_metas=img_metas)


# NOTE: the same change set also imports `SpyNet` in
# mmflow/models/flow_estimators/__init__.py and appends it to that `__all__`.
# File: mmflow/models/flow_estimators/spynet.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F

from ..builder import FLOW_ESTIMATORS, build_decoder
from .base import FlowEstimator


@FLOW_ESTIMATORS.register_module()
class SpyNet(FlowEstimator):
    """Flow estimator that feeds an average-pooled image pyramid to a decoder.

    Args:
        pyramid_levels (Sequence[str]): Names of the pyramid levels. A
            sorted copy is stored; the first sorted name holds the input
            resolution and every following level is the previous one
            average-pooled with kernel size 2 and stride 2.
        decoder (dict): Config of the decoder, built with ``build_decoder``.
        img_channels (int): Channels of a single image. The input stacks
            both images along the channel axis and is split at this index.
            Defaults to 3.
        **kwargs: Forwarded to ``FlowEstimator``.
    """

    def __init__(self,
                 pyramid_levels,
                 decoder,
                 img_channels=3,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        # Keep a sorted copy; sorting in place would silently reorder the
        # list owned by the caller (usually the config).
        self.pyramid_levels = sorted(pyramid_levels)
        self.img_channels = img_channels
        self.decoder = build_decoder(decoder)

    def downsample_images(self, imgs):
        """Split the stacked image pair and build both image pyramids.

        Args:
            imgs (Tensor): Both images concatenated along dim 1; the first
                ``img_channels`` channels are the first image and the
                remaining channels the second one.

        Returns:
            tuple(dict, dict): For each of the two images, a dict mapping
            every pyramid level name to the image at that level.
        """
        img1 = imgs[:, :self.img_channels, ...]
        img2 = imgs[:, self.img_channels:, ...]

        # The first level is the input resolution itself.
        base_level = self.pyramid_levels[0]
        imgs1 = {base_level: img1}
        imgs2 = {base_level: img2}

        for level in self.pyramid_levels[1:]:
            # Each further level halves the previous one.
            img1 = F.avg_pool2d(img1, kernel_size=2, stride=2)
            img2 = F.avg_pool2d(img2, kernel_size=2, stride=2)
            imgs1[level] = img1
            imgs2[level] = img2

        return imgs1, imgs2

    def forward_train(self, imgs, flow_gt, valid=None, img_meta=None):
        """Build the pyramids and return the decoder's training output.

        Args:
            imgs (Tensor): The stacked image pair.
            flow_gt (Tensor): The ground truth of optical flow.
            valid (Tensor, optional): The valid mask. Defaults to None.
            img_meta (optional): Accepted but not used by this method.
                Defaults to None.

        Returns:
            The return value of ``decoder.forward_train``.
        """
        imgs1, imgs2 = self.downsample_images(imgs)

        return self.decoder.forward_train(
            imgs1=imgs1, imgs2=imgs2, flow_gt=flow_gt, valid=valid)

    def forward_test(self, imgs, img_metas=None):
        """Build the pyramids and return the decoder's test output.

        Args:
            imgs (Tensor): The stacked image pair.
            img_metas (optional): Image meta information forwarded to the
                decoder. Defaults to None.

        Returns:
            The return value of ``decoder.forward_test``.
        """
        imgs1, imgs2 = self.downsample_images(imgs)

        return self.decoder.forward_test(
            imgs1=imgs1, imgs2=imgs2, img_metas=img_metas)