diff --git a/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py b/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py new file mode 100644 index 00000000..915ca93b --- /dev/null +++ b/configs/_base_/datasets/flyingthings3d_flow1d_400x720.py @@ -0,0 +1,106 @@ +train_dataset_type = 'FlyingThings3D' +train_data_root = 'data/flyingthings3d' +test_dataset_type = 'Sintel' +test_data_root = 'data/Sintel' + +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='ColorJitter', + asymmetric_prob=0.2, + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.5 / 3.14), + dict(type='Erase', prob=0.5, bounds=[50, 100], max_num=3), + dict( + type='SpacialTransform', + spacial_prob=0.8, + stretch_prob=0.8, + crop_size=(400, 720), + min_scale=-0.4, + max_scale=0.8, + max_stretch=0.2), + dict(type='RandomCrop', crop_size=(400, 720)), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='RandomFlip', prob=0.1, direction='vertical'), + dict(type='Validation', max_flow=1000.), + dict(type='Normalize', **img_norm_cfg), + dict(type='DefaultFormatBundle'), + dict( + type='Collect', + keys=['imgs', 'flow_gt', 'valid'], + meta_keys=[ + 'filename1', 'filename2', 'ori_filename1', 'ori_filename2', + 'filename_flow', 'ori_filename_flow', 'ori_shape', 'img_shape', + 'erase_bounds', 'erase_num', 'scale_factor' + ]) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='InputPad', exponent=3), + dict(type='Normalize', **img_norm_cfg), + dict(type='TestFormatBundle'), + dict( + type='Collect', + keys=['imgs'], + meta_keys=[ + 'flow_gt', 'filename1', 'filename2', 'ori_filename1', + 'ori_filename2', 'ori_shape', 'img_shape', 'img_norm_cfg', + 'scale_factor', 'pad_shape', 'pad' + ]) +] + +train_dataset_cleanpass = dict( + type=train_dataset_type, + data_root=train_data_root, + pipeline=train_pipeline, + test_mode=False, + pass_style='clean', + scene='left') +train_dataset_finalpass = dict( + type=train_dataset_type, + data_root=train_data_root, + pipeline=train_pipeline, + test_mode=False, + pass_style='final', + scene='left') +test_data_cleanpass = dict( + type=test_dataset_type, + data_root=test_data_root, + pipeline=test_pipeline, + test_mode=True, + pass_style='clean') +test_data_finalpass = dict( + type=test_dataset_type, + data_root=test_data_root, + pipeline=test_pipeline, + test_mode=True, + pass_style='final') + +data = dict( + train_dataloader=dict( + samples_per_gpu=2, + workers_per_gpu=2, + shuffle=False, + drop_last=True, + persistent_workers=True), + val_dataloader=dict( + samples_per_gpu=1, + workers_per_gpu=2, + shuffle=False, + persistent_workers=True), + test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2, shuffle=False), + train=[train_dataset_cleanpass, train_dataset_finalpass], + val=dict( + type='ConcatDataset', + datasets=[test_data_cleanpass, test_data_finalpass], + separate_eval=True), + test=dict( + type='ConcatDataset', + datasets=[test_data_cleanpass, test_data_finalpass], + separate_eval=True)) diff --git a/configs/_base_/models/flow1d.py b/configs/_base_/models/flow1d.py new file mode 100644 index 00000000..6a347a0c --- /dev/null +++ b/configs/_base_/models/flow1d.py @@ -0,0 +1,46 @@ +model = dict( + type='Flow1D', + radius=32, + cxt_channels=128, + h_channels=128, + encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + 
norm_cfg=dict(type='IN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), + dict(type='Constant', layer=['InstanceNorm2d'], val=1, bias=0) + ]), + cxt_encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + norm_cfg=dict(type='SyncBN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), + dict(type='Constant', layer=['SyncBatchNorm2d'], val=1, bias=0) + ]), + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=32, + iters=12, + corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss'), + act_cfg=dict(type='ReLU')), + freeze_bn=False, + train_cfg=dict(), + test_cfg=dict(), +) diff --git a/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py b/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py new file mode 100644 index 00000000..5b3d089d --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingchairs_368x496.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/flyingchairs_raft_368x496.py', + '../_base_/default_runtime.py' +] + +optimizer = dict( + type='AdamW', + lr=0.0004, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.0001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.0004, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') diff --git a/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py new file mode 100644 index 00000000..bc711844 --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_400x720.py @@ -0,0 +1,29 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/flyingthings3d_raft_400x720.py', + '../_base_/default_runtime.py' +] + +model = dict(freeze_bn=True, test_cfg=dict(iters=32)) + +optimizer = dict( + type='AdamW', + lr=0.000125, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.00001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.000125, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') + +# Train on FlyingChairs and finetune on FlyingThings3D +load_from = "work-dir/flow1d/flyingchair/latest.pth" diff --git a/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py new file mode 100644 index 00000000..640410f4 --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py @@ -0,0 +1,41 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/sintel_cleanx100_sintel_fianlx100_flyingthings3d_raft_368x768.py', # noqa + '../_base_/default_runtime.py' +] + +model = dict( + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=32, + iters=12, + corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss', gamma=0.85), + act_cfg=dict(type='ReLU')), + freeze_bn=True, + test_cfg=dict(iters=32)) + +optimizer = dict( + type='AdamW', + lr=0.000125, + 
betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.00001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.000125, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') + +# Train on FlyingChairs and FlyingThings3D, and finetune on FlyingThings3D +# and Sintel +load_from = "work-dir/flow1d/flyingthingd/latest.pth" \ No newline at end of file diff --git a/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py b/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py new file mode 100644 index 00000000..ca12fad7 --- /dev/null +++ b/configs/flow1d/flow1d_8x2_100k_mixed_368x768.py @@ -0,0 +1,42 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/sintel_cleanx100_sintel_fianlx100_kitti2015x200_hd1kx5_flyingthings3d_raft_384x768.py', # noqa + '../_base_/default_runtime.py' +] + +model = dict( + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=32, + iters=12, + corr_op_cfg=dict(type='CorrLookupFlow1D', align_corners=True), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss', gamma=0.85), + act_cfg=dict(type='ReLU')), + freeze_bn=True, + test_cfg=dict(iters=32)) + +optimizer = dict( + type='AdamW', + lr=0.000125, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.00001, + amsgrad=False) +optimizer_config = dict(grad_clip=dict(max_norm=1.)) +lr_config = dict( + policy='OneCycle', + max_lr=0.000125, + total_steps=100100, + pct_start=0.05, + anneal_strategy='linear') + +runner = dict(type='IterBasedRunner', max_iters=100000) +checkpoint_config = dict(by_epoch=False, interval=10000) +evaluation = dict(interval=10000, metric='EPE') + +# Train on FlyingChairs and FlyingThings3D, and finetune on +# and Sintel, KITTI2015 and HD1K +load_from = "work-dir/flow1d/sintel/latest.pth" + diff --git a/mmflow/models/decoders/__init__.py b/mmflow/models/decoders/__init__.py index 7c210fd9..5fb99bb7 100644 --- a/mmflow/models/decoders/__init__.py +++ b/mmflow/models/decoders/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .context_net import ContextNet +from .flow1d_decoder import Flow1DDecoder from .flownet_decoder import FlowNetCDecoder, FlowNetSDecoder from .gma_decoder import GMADecoder from .irr_refine import FlowRefine, OccRefine, OccShuffleUpsample @@ -13,5 +14,5 @@ 'FlowNetCDecoder', 'FlowNetSDecoder', 'PWCNetDecoder', 'MaskFlowNetSDecoder', 'NetE', 'ContextNet', 'RAFTDecoder', 'FlowRefine', 'OccRefine', 'OccShuffleUpsample', 'IRRPWCDecoder', 'MaskFlowNetDecoder', - 'GMADecoder' + 'GMADecoder', 'Flow1DDecoder' ] diff --git a/mmflow/models/decoders/flow1d_decoder.py b/mmflow/models/decoders/flow1d_decoder.py new file mode 100644 index 00000000..a97cf5b1 --- /dev/null +++ b/mmflow/models/decoders/flow1d_decoder.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +from typing import Dict, Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmflow.ops import build_operators +from ..builder import DECODERS, build_loss +from ..utils.attention1d import Attention1D +from ..utils.correlation1d import Correlation1D +from .base_decoder import BaseDecoder +from .raft_decoder import ConvGRU, XHead, MotionEncoder + + +class MotionEncoderFlow1D(MotionEncoder): + """The module of motion encoder for Flow1D. + + An encoder which consists of several convolution layers and outputs + features as GRU's input. + + Args: + radius (int): Radius used when calculating correlation tensor. + Default: 32. + net_type (str): Type of the net. Choices: ['Basic', 'Small']. + Default: 'Basic'. + """ + + def __init__(self, + radius: int = 32, + net_type: str = 'Basic', + **kwargs) -> None: + super().__init__(radius=radius, net_type=net_type, **kwargs) + corr_channels = self._corr_channels.get(net_type) if isinstance( + self._corr_channels[net_type], + (tuple, list)) else [self._corr_channels[net_type]] + corr_kernel = self._corr_kernel.get(net_type) if isinstance( + self._corr_kernel.get(net_type), + (tuple, list)) else [self._corr_kernel.get(net_type)] + corr_padding = self._corr_padding.get(net_type) if isinstance( + self._corr_padding.get(net_type), + (tuple, list)) else [self._corr_padding.get(net_type)] + + corr_inch = 2 * (2 * radius + 1) + corr_net = self._make_encoder(corr_inch, corr_channels, corr_kernel, + corr_padding, **kwargs) + self.corr_net = nn.Sequential(*corr_net) + + +class PositionEmbeddingSine(nn.Module): + """refer to the standard version of position embedding used by the + Attention is all you need paper, generalized to work on images. + + https://github.com/facebookresearch/detr/blob/main/models/position_encod + """ + + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=True, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +@DECODERS.register_module() +class Flow1DDecoder(BaseDecoder): + """The decoder of Flow1D Net. + + The decoder of Flow1D Net, which outputs list of upsampled flow estimation. + + Args: + net_type (str): Type of the net. 
Choices: ['Basic', 'Small']. + radius (int): Radius used when calculating correlation tensor. + iters (int): Total iteration number of iterative update of Flow1DDecoder. + corr_op_cfg (dict): Config dict of correlation operator. + Default: dict(type='CorrLookupFlow1D', align_corners=True). + gru_type (str): Type of the GRU module. Choices: ['Conv', 'SeqConv']. + Default: 'SeqConv'. + feat_channels (Sequence(int)): Feature channels of the prediction module. + mask_channels (int): Output channels of mask prediction layer. + Default: 64. + conv_cfg (dict, optional): Config dict of convolution layers in motion + encoder. Default: None. + norm_cfg (dict, optional): Config dict of norm layer in motion encoder. + Default: None. + act_cfg (dict, optional): Config dict of activation layer in motion + encoder. Default: None. + """ + _h_channels = {'Basic': 128, 'Small': 96} + _cxt_channels = {'Basic': 128, 'Small': 64} + + def __init__( + self, + net_type: str, + radius: int, + iters: int, + corr_op_cfg: dict = dict(type='CorrLookupFlow1D', align_corners=True), + gru_type: str = 'SeqConv', + feat_channels: Union[int, Sequence[int]] = 256, + mask_channels: int = 64, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + act_cfg: Optional[dict] = None, + flow_loss: Optional[dict] = None, + ) -> None: + super().__init__() + assert net_type in ['Basic', 'Small'] + assert type(feat_channels) in (int, tuple, list) + self.attn_block_x = Attention1D( + in_channels=feat_channels, + y_attention=False, + double_cross_attn=True) + self.attn_block_y = Attention1D( + in_channels=feat_channels, + y_attention=True, + double_cross_attn=True) + self.corr_block = Correlation1D() + + feat_channels = feat_channels if isinstance( + feat_channels, (tuple, list)) else [feat_channels] + self.feat_channels = feat_channels + self.net_type = net_type + self.radius = radius + self.h_channels = self._h_channels.get(net_type) + self.cxt_channels = self._cxt_channels.get(net_type) + self.iters = iters + self.mask_channels = mask_channels * 9 + corr_op_cfg['radius'] = radius + self.corr_lookup = build_operators(corr_op_cfg) + self.encoder = MotionEncoderFlow1D( + radius=radius, + net_type=net_type, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gru_type = gru_type + self.gru = self.make_gru_block() + self.flow_pred = XHead(self.h_channels, feat_channels, 2, x='flow') + + if net_type == 'Basic': + self.mask_pred = XHead( + self.h_channels, feat_channels, self.mask_channels, x='mask') + + if flow_loss is not None: + self.flow_loss = build_loss(flow_loss) + + def make_gru_block(self): + return ConvGRU( + self.h_channels, + self.encoder.out_channels[0] + 2 + self.cxt_channels, + net_type=self.gru_type) + + def _upsample(self, + flow: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex + combination. + + Args: + flow (Tensor): The optical flow with the shape [N, 2, H/8, W/8]. + mask (Tensor, optional): The learnable mask with shape + [N, grid_size x scale x scale, H/8, W/8]. + + Returns: + Tensor: The output optical flow with the shape [N, 2, H, W].
+ """ + scale = 8 + grid_size = 9 + grid_side = int(math.sqrt(grid_size)) + N, _, H, W = flow.shape + if mask is None: + new_size = (scale * H, scale * W) + return scale * F.interpolate( + flow, size=new_size, mode='bilinear', align_corners=True) + # reshape the predicted mask to (N, 1, 9, scale, scale, H, W) + mask = mask.view(N, 1, grid_size, scale, scale, H, W) + mask = torch.softmax(mask, dim=2) + + # extract 3x3 local grids around each pixel, padding = grid_side // 2 + upflow = F.unfold(scale * flow, [grid_side, grid_side], padding=1) + # upflow with shape N, 2, 9, 1, 1, H, W + upflow = upflow.view(N, 2, grid_size, 1, 1, H, W) + + # take a weighted combination over the neighborhood grid 3x3 + # upflow with shape N, 2, scale, scale, H, W + upflow = torch.sum(mask * upflow, dim=2) + upflow = upflow.permute(0, 1, 4, 2, 5, 3) + return upflow.reshape(N, 2, scale * H, scale * W) + + def forward(self, feat1: torch.Tensor, feat2: torch.Tensor, + flow: torch.Tensor, h: torch.Tensor, + cxt_feat: torch.Tensor) -> Sequence[torch.Tensor]: + """Forward function for Flow1D. + + Args: + feat1 (Tensor): The feature from the first input image. + feat2 (Tensor): The feature from the second input image. + flow (Tensor): The initialized flow when warm start. + h (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + + Returns: + Sequence[Tensor]: The list of predicted optical flow. + """ + pos_encoding = PositionEmbeddingSine(self.feat_channels[0] // 2) + position = pos_encoding(feat1) + + # attention + feat2_x, attn_x = self.attn_block_x(feat1, feat2, position, None) + feat2_y, attn_y = self.attn_block_y(feat1, feat2, position, None) + correlation_x = self.corr_block(feat1, feat2_x, False) + correlation_y = self.corr_block(feat1, feat2_y, True) + correlation1d = [correlation_x, correlation_y] + upflow_preds = [] + delta_flow = torch.zeros_like(flow) + for _ in range(self.iters): + flow = flow.detach() + corr = self.corr_lookup(correlation1d, flow) + motion_feat = self.encoder(corr, flow) + x = torch.cat([cxt_feat, motion_feat], dim=1) + h = self.gru(h, x) + + delta_flow = self.flow_pred(h) + flow = flow + delta_flow + + if hasattr(self, 'mask_pred'): + mask = .25 * self.mask_pred(h) + else: + mask = None + + upflow = self._upsample(flow, mask) + upflow_preds.append(upflow) + + return upflow_preds + + def forward_train( + self, + feat1: torch.Tensor, + feat2: torch.Tensor, + flow: torch.Tensor, + h_feat: torch.Tensor, + cxt_feat: torch.Tensor, + flow_gt: torch.Tensor, + valid: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """Forward function when model training. + + Args: + feat1 (Tensor): The feature from the first input image. + feat2 (Tensor): The feature from the second input image. + flow (Tensor): The last estimated flow from GRU cell. + h_feat (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + flow_gt (Tensor): The ground truth of optical flow. + Defaults to None. + valid (Tensor, optional): The valid mask. Defaults to None. + + Returns: + Dict[str, Tensor]: The losses of model. + """ + + flow_pred = self.forward(feat1, feat2, flow, h_feat, cxt_feat) + + return self.losses(flow_pred, flow_gt, valid=valid) + + def forward_test(self, + feat1: torch.Tensor, + feat2: torch.Tensor, + flow: torch.Tensor, + h_feat: torch.Tensor, + cxt_feat: torch.Tensor, + img_metas=None) -> Sequence[Dict[str, np.ndarray]]: + """Forward function when model testing. + + Args: + feat1 (Tensor): The feature from the first input image.
+ feat2 (Tensor): The feature from the second input image. + flow (Tensor): The last estimated flow from GRU cell. + h_feat (Tensor): The hidden state for GRU cell. + cxt_feat (Tensor): The contextual feature from the first image. + img_metas (Sequence[dict], optional): meta data of image to revert + the flow to original ground truth size. Defaults to None. + + Returns: + Sequence[Dict[str, ndarray]]: The batch of predicted optical flow + with the same size of images before augmentation. + """ + flow_pred = self.forward(feat1, feat2, flow, h_feat, cxt_feat) + + flow_result = flow_pred[-1] + # flow maps with the shape [H, W, 2] + flow_result = flow_result.permute(0, 2, 3, 1).cpu().data.numpy() + # unravel batch dim + flow_result = list(flow_result) + flow_result = [dict(flow=f) for f in flow_result] + return self.get_flow(flow_result, img_metas=img_metas) + + def losses(self, + flow_pred: Sequence[torch.Tensor], + flow_gt: torch.Tensor, + valid: torch.Tensor = None) -> Dict[str, torch.Tensor]: + """Compute optical flow loss. + + Args: + flow_pred (Sequence[Tensor]): The list of predicted optical flow. + flow_gt (Tensor): The ground truth of optical flow. + valid (Tensor, optional): The valid mask. Defaults to None. + + Returns: + Dict[str, Tensor]: The dict of losses. + """ + + loss = dict() + loss['loss_flow'] = self.flow_loss(flow_pred, flow_gt, valid) + return loss diff --git a/mmflow/models/flow_estimators/__init__.py b/mmflow/models/flow_estimators/__init__.py index 7020cff2..7e8052d2 100644 --- a/mmflow/models/flow_estimators/__init__.py +++ b/mmflow/models/flow_estimators/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .flow1d import Flow1D from .flownet import FlowNetC, FlowNetS from .flownet2 import FlowNet2, FlowNetCSS from .irrpwc import IRRPWC @@ -9,5 +10,5 @@ __all__ = [ 'FlowNetC', 'FlowNetS', 'LiteFlowNet', 'PWCNet', 'MaskFlowNetS', 'RAFT', - 'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet' + 'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet', 'Flow1D' ] diff --git a/mmflow/models/flow_estimators/flow1d.py b/mmflow/models/flow_estimators/flow1d.py new file mode 100644 index 00000000..49b287cd --- /dev/null +++ b/mmflow/models/flow_estimators/flow1d.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from numpy import ndarray + +from ..builder import FLOW_ESTIMATORS, build_encoder +from .pwcnet import PWCNet + + +@FLOW_ESTIMATORS.register_module() +class Flow1D(PWCNet): + """Flow1D model. + + Args: + radius (int): Number of radius in . + cxt_channels (int): Number of channels of context feature. + h_channels (int): Number of channels of hidden feature in . + cxt_encoder (dict): Config dict for building context encoder. + freeze_bn (bool, optional): Whether to freeze batchnorm layer or not. + Default: False. 
+ """ + + def __init__(self, + radius: int, + cxt_channels: int, + h_channels: int, + cxt_encoder: dict, + freeze_bn: bool = False, + **kwargs) -> None: + super().__init__(**kwargs) + self.radius = radius + self.context = build_encoder(cxt_encoder) + self.h_channels = h_channels + self.cxt_channels = cxt_channels + + assert self.radius == self.decoder.radius + assert self.h_channels == self.decoder.h_channels + assert self.cxt_channels == self.decoder.cxt_channels + assert self.h_channels + self.cxt_channels == self.context.out_channels + if freeze_bn: + self.freeze_bn() + + def freeze_bn(self) -> None: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def extract_feat( + self, imgs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Extract features from images. + + Args: + imgs (Tensor): The concatenated input images. + + Returns: + Tuple[Tensor, Tensor, Tensor, Tensor]: The feature from the first + image, the feature from the second image, the hidden state + feature for GRU cell and the contextual feature. + """ + in_channels = self.encoder.in_channels + img1 = imgs[:, :in_channels, ...] + img2 = imgs[:, in_channels:, ...] + + feat1 = self.encoder(img1) + feat2 = self.encoder(img2) + cxt_feat = self.context(img1) + + h_feat, cxt_feat = torch.split( + cxt_feat, [self.h_channels, self.cxt_channels], dim=1) + h_feat = torch.tanh(h_feat) + cxt_feat = torch.relu(cxt_feat) + + return feat1, feat2, h_feat, cxt_feat + + def forward_train( + self, + imgs: torch.Tensor, + flow_gt: torch.Tensor, + valid: torch.Tensor, + flow_init: Optional[torch.Tensor] = None, + img_metas: Optional[Sequence[dict]] = None + ) -> Dict[str, torch.Tensor]: + """Forward function for Flow1D when model training. + + Args: + imgs (Tensor): The concatenated input images. + flow_gt (Tensor): The ground truth of optical flow. + Defaults to None. + valid (Tensor, optional): The valid mask. Defaults to None. + flow_init (Tensor, optional): The initialized flow when warm start. + Default to None. + img_metas (Sequence[dict], optional): meta data of image to revert + the flow to original ground truth size. Defaults to None. + + Returns: + Dict[str, Tensor]: The losses of output. + """ + + feat1, feat2, h_feat, cxt_feat = self.extract_feat(imgs) + B, _, H, W = feat1.shape + + if flow_init is None: + flow_init = torch.zeros((B, 2, H, W), device=feat1.device) + + return self.decoder.forward_train( + feat1, + feat2, + flow=flow_init, + h_feat=h_feat, + cxt_feat=cxt_feat, + flow_gt=flow_gt, + valid=valid) + + def forward_test( + self, + imgs: torch.Tensor, + flow_init: Optional[torch.Tensor] = None, + img_metas: Optional[Sequence[dict]] = None) -> Sequence[ndarray]: + """Forward function for Flow1D when model testing. + + Args: + imgs (Tensor): The concatenated input images. + flow_init (Tensor, optional): The initialized flow when warm start. + Default to None. + img_metas (Sequence[dict], optional): meta data of image to revert + the flow to original ground truth size. Defaults to None. + + Returns: + Sequence[Dict[str, ndarray]]: the batch of predicted optical flow + with the same size of images after augmentation. 
+ """ + train_iter = self.decoder.iters + if self.test_cfg is not None and self.test_cfg.get( + 'iters') is not None: + self.decoder.iters = self.test_cfg.get('iters') + + feat1, feat2, h_feat, cxt_feat = self.extract_feat(imgs) + B, _, H, W = feat1.shape + + if flow_init is None: + flow_init = torch.zeros((B, 2, H, W), device=feat1.device) + + results = self.decoder.forward_test( + feat1=feat1, + feat2=feat2, + flow=flow_init, + h_feat=h_feat, + cxt_feat=cxt_feat, + img_metas=img_metas) + # recover iter in train + self.decoder.iters = train_iter + + return results diff --git a/mmflow/models/utils/correlation1d.py b/mmflow/models/utils/correlation1d.py new file mode 100644 index 00000000..1a170255 --- /dev/null +++ b/mmflow/models/utils/correlation1d.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from mmcv.runner import BaseModule +from torch import Tensor + + +class Correlation1D(BaseModule): + """Correlation1D Module. + + The neck of Flow1D, which calculates correlation tensor of input features + with the method of 3D cost volume. + """ + + def __init__(self): + super().__init__() + + def forward(self, + feat1: Tensor, + feat2: Tensor, + y_direction: bool = False) -> Tensor: + """Forward function for Correlation1D. + + Args: + feat1 (Tensor): The feature from first input image. + feat2 (Tensor): The 1D cross attention feat2 on x or y direction. + y_direction (bool): whether y direction or not. + Returns: + Tensor: Correlation of x correlation or y correlation. + """ + b, c, h, w = feat1.shape + scale_factor = c**0.5 + if y_direction: + # y direction, corr shape is [B, W, H, H] + feat1 = feat1.permute(0, 3, 2, 1) + feat2 = feat2.permute(0, 3, 1, 2) + else: + # x direction, corr shape is [B, H, W, W] + feat1 = feat1.permute(0, 2, 3, 1) + feat2 = feat2.permute(0, 2, 1, 3) + corr = torch.matmul(feat1, feat2) / scale_factor + return corr diff --git a/mmflow/models/utils/res_layer.py b/mmflow/models/utils/res_layer.py index 14b60d46..6e8ff86a 100644 --- a/mmflow/models/utils/res_layer.py +++ b/mmflow/models/utils/res_layer.py @@ -78,7 +78,7 @@ def _inner_forward(x): if self.downsample is not None: identity = self.downsample(x) - out += identity + out = out + identity return out @@ -288,7 +288,7 @@ def _inner_forward(x): if self.downsample is not None: identity = self.downsample(x) - out += identity + out = out + identity return out diff --git a/mmflow/ops/corr_lookup.py b/mmflow/ops/corr_lookup.py index db8d375b..5b555908 100644 --- a/mmflow/ops/corr_lookup.py +++ b/mmflow/ops/corr_lookup.py @@ -135,3 +135,90 @@ def forward(self, corr_pyramid: Sequence[Tensor], flow: Tensor) -> Tensor: out = torch.cat(out_corr_pyramid, dim=-1) return out.permute(0, 3, 1, 2).contiguous().float() + + +@OPERATORS.register_module() +class CorrLookupFlow1D(nn.Module): + """Correlation lookup operator for Flow1D. + + This operator is used in `Flow1D`_ + + Args: + radius (int): the radius of the local neighborhood of the pixels. + Default to 4. + mode (str): interpolation mode to calculate output values 'bilinear' + | 'nearest' | 'bicubic'. Default: 'bilinear' Note: mode='bicubic' + supports only 4-D input. + padding_mode (str): padding mode for outside grid values 'zeros' | + 'border' | 'reflection'. Default: 'zeros' + align_corners (bool): If set to True, the extrema (-1 and 1) are + considered as referring to the center points of the input’s corner + pixels. 
If set to False, they are instead considered as referring + to the corner points of the input’s corner pixels, making the + sampling more resolution agnostic. Default to True. + """ + + def __init__(self, + radius: int = 4, + mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = True) -> None: + super().__init__() + self.r = radius + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + + def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor: + """Forward function of Correlation lookup for Flow1D. + + Args: + corr (Sequence[Tensor]): Correlation on x and y direction. + flow (Tensor): Current estimated optical flow. + + Returns: + Tensor: lookup cost volume on the correlation of x and y directions + concatenate together. + """ + B, _, H, W = flow.shape + # reshape corr_x from [B, H, W, W] to [B*H*W, 1, 1, W] + corr_x = corr[0].view(-1, 1, 1, W) + # reshape corr_y from [B, W, H, H]to [B*H*W, 1, H, 1] + corr_y = corr[1].permute(0, 2, 1, 3).contiguous().view(-1, 1, H, 1) + + # reshape flow to [B, H, W, 2] + flow = flow.permute(0, 2, 3, 1) + coords_x = flow[:, :, :, 0] + coords_y = flow[:, :, :, 1] + coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1) + coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1) + centroid_x = coords_x.view(B * H * W, 1, 1, 2) + centroid_y = coords_y.view(B * H * W, 1, 1, 2) + + dx = torch.linspace( + -self.r, self.r, 2 * self.r + 1, device=flow.device) + dy = torch.linspace( + -self.r, self.r, 2 * self.r + 1, device=flow.device) + + delta_x = torch.stack((dx, torch.zeros_like(dx)), dim=-1) + delta_y = torch.stack((torch.zeros_like(dy), dy), dim=-1) + # [1, 2r+1, 1, 2] + delta_y = delta_y.view(1, 2 * self.r + 1, 1, 2) + + coords_x = centroid_x + delta_x + coords_y = centroid_y + delta_y + + corr_x = bilinear_sample(corr_x, coords_x, self.mode, + self.padding_mode, self.align_corners) + corr_y = bilinear_sample(corr_y, coords_y, self.mode, + self.padding_mode, self.align_corners) + + # shape is [B, 2r+1, H, W] + corr_x = corr_x.view(B, H, W, -1) + corr_x = corr_x.permute(0, 3, 1, 2).contiguous() + corr_y = corr_y.view(B, H, W, -1) + corr_y = corr_y.permute(0, 3, 1, 2).contiguous() + + correlation = torch.cat((corr_x, corr_y), dim=1) + + return correlation diff --git a/tests/test_models/test_decoder/test_flow1d_decoder.py b/tests/test_models/test_decoder/test_flow1d_decoder.py new file mode 100644 index 00000000..bded217f --- /dev/null +++ b/tests/test_models/test_decoder/test_flow1d_decoder.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
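Before the unit tests, a small shape sketch (illustrative tensors and sizes, not part of the patch) of how Correlation1D and CorrLookupFlow1D fit together. Instead of RAFT's 4D cost volume, Flow1D builds two 3D volumes, one over the width and one over the height, and the lookup concatenates a (2r + 1)-wide window from each, so the motion encoder sees 2 * (2r + 1) correlation channels per pixel:

    import torch

    from mmflow.models.utils.correlation1d import Correlation1D
    from mmflow.ops import build_operators

    b, c, h, w = 2, 256, 8, 16
    feat1 = torch.randn(b, c, h, w)
    feat2 = torch.randn(b, c, h, w)

    corr = Correlation1D()
    corr_x = corr(feat1, feat2, False)  # [B, H, W, W], horizontal candidates
    corr_y = corr(feat1, feat2, True)   # [B, W, H, H], vertical candidates

    # The Flow1D configs use radius=32; 4 keeps the sketch small.
    lookup = build_operators(dict(type='CorrLookupFlow1D', radius=4))
    flow = torch.zeros(b, 2, h, w)
    cost = lookup([corr_x, corr_y], flow)
    assert cost.shape == (b, 2 * (2 * 4 + 1), h, w)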
+import pytest +import torch + +from mmflow.models.decoders.flow1d_decoder import (Flow1DDecoder, + MotionEncoderFlow1D) + + +@pytest.mark.parametrize('net_type', ['Basic', 'Small']) +def test_motion_encoder(net_type): + + # test invalid net_type + with pytest.raises(AssertionError): + MotionEncoderFlow1D(net_type='invalid value') + + module = MotionEncoderFlow1D( + net_type=net_type, conv_cfg=None, norm_cfg=None, act_cfg=None) + radius = 4 + + input_corr = torch.randn((1, 2 * (2 * radius + 1), 56, 56)) + input_flow = torch.randn((1, 2, 56, 56)) + + corr_feat = module.corr_net(input_corr) + flow_feat = module.flow_net(input_flow) + our_feat = module.out_net(torch.cat([corr_feat, flow_feat], dim=1)) + + if net_type == 'Basic': + assert corr_feat.shape == torch.Size((1, 192, 56, 56)) + assert flow_feat.shape == torch.Size((1, 64, 56, 56)) + assert our_feat.shape == torch.Size((1, 126, 56, 56)) + elif net_type == 'Small': + assert corr_feat.shape == torch.Size((1, 96, 56, 56)) + assert flow_feat.shape == torch.Size((1, 32, 56, 56)) + assert our_feat.shape == torch.Size((1, 80, 56, 56)) + + +def test_flow1d_decoder(): + model = Flow1DDecoder( + net_type='Basic', + radius=4, + iters=12, + flow_loss=dict(type='SequenceLoss')) + mask = torch.ones((1, 64 * 9, 10, 10)) + flow = torch.randn((1, 2, 10, 10)) + assert model._upsample(flow, mask).shape == torch.Size((1, 2, 80, 80)) + + feat1 = torch.randn(1, 256, 8, 8) + feat2 = torch.randn(1, 256, 8, 8) + h_feat = torch.randn(1, 128, 8, 8) + cxt_feat = torch.randn(1, 128, 8, 8) + flow = torch.zeros((1, 2, 8, 8)) + + flow_gt = torch.randn(1, 2, 64, 64) + # test forward function + out = model(feat1, feat2, flow, h_feat, cxt_feat) + assert isinstance(out, list) + assert out[0].shape == torch.Size((1, 2, 64, 64)) + + # test forward train + loss = model.forward_train( + feat1, feat2, flow, h_feat, cxt_feat, flow_gt=flow_gt) + assert float(loss['loss_flow']) > 0. + + # test forward test + out = model.forward_test(feat1, feat2, flow, h_feat, cxt_feat) + assert out[0]['flow'].shape == (64, 64, 2) diff --git a/tests/test_models/test_flow_estimator.py b/tests/test_models/test_flow_estimator.py index dc3016dd..eb1db15e 100644 --- a/tests/test_models/test_flow_estimator.py +++ b/tests/test_models/test_flow_estimator.py @@ -44,6 +44,7 @@ def test_flow_estimator(cfg_file): @pytest.mark.parametrize('cfg_file', [ '../../configs/_base_/models/raft.py', + '../../configs/_base_/models/flow1d.py', '../../configs/_base_/models/flownets.py', '../../configs/_base_/models/flownet2/flownet2sd.py', '../../configs/_base_/models/gma/gma.py', @@ -57,7 +58,7 @@ def test_flow_estimator_without_cuda(cfg_file): cfg_file = osp.join(osp.dirname(__file__), cfg_file) cfg = Config.fromfile(cfg_file) - if cfg.model.type == 'RAFT': + if cfg.model.type == 'RAFT' or cfg.model.type == 'Flow1D': # Replace SyncBN with BN to inference on CPU cfg.model.cxt_encoder.norm_cfg = dict(type='BN', requires_grad=True) diff --git a/tests/test_models/test_utils/test_correlation1d.py b/tests/test_models/test_utils/test_correlation1d.py new file mode 100644 index 00000000..fa9f4ffd --- /dev/null +++ b/tests/test_models/test_utils/test_correlation1d.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
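The ground-truth tensors in the correlation test below look opaque, but they follow directly from the definition corr = <feat1, feat2> / sqrt(C). A hand check of the first entry (pure arithmetic, not part of the patch):

    import math

    # feat1 = torch.arange(0, 24).view(1, 2, 3, 4), feat2 = feat1 + 1, C = 2
    feat1_00 = (0.0, 12.0)  # feat1[0, :, 0, 0]
    feat2_00 = (1.0, 13.0)  # feat2[0, :, 0, 0]
    dot = sum(a * b for a, b in zip(feat1_00, feat2_00))  # 0*1 + 12*13 = 156
    print(dot / math.sqrt(2))  # 110.3087..., i.e. gt_corr_x[0, 0, 0, 0]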
+import torch +from torch import Tensor + +from mmflow.models.utils.correlation1d import Correlation1D + +_feat1 = torch.arange(0, 24).view(1, 2, 3, 4) +_feat2 = _feat1 + 1 +b, c, h, w = _feat1.size() + + +def test_correlation(): + gt_corr_x = Tensor([[[[110.3087, 118.7939, 127.2792, 135.7645], + [120.2082, 130.1077, 140.0071, 149.9066], + [130.1077, 141.4214, 152.7351, 164.0488], + [140.0071, 152.7351, 165.4630, 178.1909]], + [[206.4752, 220.6173, 234.7595, 248.9016], + [222.0315, 237.5879, 253.1442, 268.7006], + [237.5879, 254.5584, 271.5290, 288.4996], + [253.1442, 271.5290, 289.9138, 308.2986]], + [[347.8965, 367.6955, 387.4945, 407.2935], + [369.1097, 390.3229, 411.5362, 432.7494], + [390.3229, 412.9504, 435.5778, 458.2052], + [411.5362, 435.5778, 459.6194, 483.6610]]]]) + gt_corr_y = Tensor([[[[110.3087, 144.2498, 178.1909], + [149.9066, 206.4752, 263.0437], + [189.5046, 268.7006, 347.8965]], + [[130.1077, 169.7056, 209.3036], + [175.3625, 237.5879, 299.8133], + [220.6173, 305.4701, 390.3229]], + [[152.7351, 197.9899, 243.2447], + [203.6468, 271.5290, 339.4113], + [254.5584, 345.0681, 435.5778]], + [[178.1909, 229.1026, 280.0143], + [234.7595, 308.2986, 381.8377], + [291.3280, 387.4945, 483.6610]]]]) + corr = Correlation1D() + corr_x = corr(_feat1, _feat2, False) + corr_y = corr(_feat1, _feat2, True) + assert corr_x.size() == (b, h, w, w) + assert corr_y.size() == (b, w, h, h) + assert torch.allclose(corr_x, gt_corr_x, atol=1e-4) + assert torch.allclose(corr_y, gt_corr_y, atol=1e-4) diff --git a/tests/test_op/test_corr_lookup.py b/tests/test_op/test_corr_lookup.py index f31f4a2b..e5cb99a9 100644 --- a/tests/test_op/test_corr_lookup.py +++ b/tests/test_op/test_corr_lookup.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch +from torch import Tensor from mmflow.models.decoders.raft_decoder import CorrelationPyramid +from mmflow.models.utils.correlation1d import Correlation1D from mmflow.ops.builder import build_operators from mmflow.ops.corr_lookup import bilinear_sample, coords_grid @@ -17,7 +20,6 @@ def test_coords_grid(): assert grid.shape == torch.Size((2, 2, H, W)) for i in range(H): for j in range(W): - assert torch.all(grid[0, :, i, j] == torch.Tensor((j, i))) @@ -56,3 +58,51 @@ def test_corr_lookup(): corr_lpt = corr_lookup_op(corr_pyramid, torch.randn(1, 2, H, W)) assert corr_lpt.shape == torch.Size((1, 81 * 4, H, W)) + + +@pytest.mark.parametrize('mode', ['bilinear', 'nearest', 'bicubic']) +@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection']) +@pytest.mark.parametrize('align_corners', [True, False]) +def test_corr_lookup_flow1d(mode, padding_mode, align_corners): + corr_block = Correlation1D() + feat1 = torch.arange(0, 24) + feat1 = feat1.view(1, 2, 3, 4) + feat2 = feat1 + 1 + flow = torch.ones_like(feat1) + b, c, h, w = feat1.size() + radius = 1 + + # gronud truth + gt_corr_x = Tensor([[[[110.3087, 120.2082, 130.1077, 140.0071], + [206.4752, 222.0315, 237.5879, 253.1442], + [347.8965, 369.1097, 390.3229, 411.5362]], + [[118.7939, 130.1077, 141.4214, 152.7351], + [220.6173, 237.5879, 254.5584, 271.5290], + [367.6955, 390.3229, 412.9504, 435.5778]], + [[127.2792, 140.0071, 152.7351, 165.4630], + [234.7595, 253.1442, 271.5290, 289.9138], + [387.4945, 411.5362, 435.5778, 459.6194]]]]) + gt_corr_y = Tensor([[[[110.3087, 130.1077, 152.7351, 178.1909], + [149.9066, 175.3625, 203.6468, 234.7595], + [189.5046, 220.6173, 254.5584, 291.3280]], + [[144.2498, 169.7056, 197.9899, 229.1026], + [206.4752, 237.5879, 271.5290, 
308.2986], + [268.7006, 305.4701, 345.0681, 387.4945]], + [[178.1909, 209.3036, 243.2447, 280.0143], + [263.0437, 299.8133, 339.4113, 381.8377], + [347.8965, 390.3229, 435.5778, 483.6610]]]]) + gt_corr = torch.cat((gt_corr_x, gt_corr_y), dim=1) + correlation_x = corr_block(feat1, feat2, False) + correlation_y = corr_block(feat1, feat2, True) + correlation = [correlation_x, correlation_y] + corr_lookup_cfg = dict( + type='CorrLookupFlow1D', + radius=radius, + mode=mode, + padding_mode=padding_mode, + align_corners=True) + corr_lookup_op = build_operators(corr_lookup_cfg) + + corr_xy = corr_lookup_op(correlation, flow) + assert corr_xy.size() == (b, 2 * (2 * radius + 1), h, w) + assert torch.allclose(gt_corr, corr_xy, atol=1e-4)
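Taken together, the four configs encode the usual staged schedule: FlyingChairs, then FlyingThings3D, then a Sintel finetune, then the mixed Sintel/KITTI-2015/HD1K finetune, with each stage's load_from pointing at the previous stage's checkpoint. Once a stage is trained, the checkpoint can be used through the regular MMFlow inference helpers; a rough sketch (the mmflow.apis functions and demo frames come from the main MMFlow repo, not this patch, and the checkpoint path is whatever work-dir was used for training):

    from mmflow.apis import inference_model, init_model

    config = 'configs/flow1d/flow1d_8x2_100k_flyingthings3d_sintel_368x768.py'
    checkpoint = 'work-dir/flow1d/sintel/latest.pth'  # local path, as in load_from
    model = init_model(config, checkpoint, device='cuda:0')

    # Two consecutive frames; any image pair of the same size works.
    flow = inference_model(model, 'demo/frame_0001.png', 'demo/frame_0002.png')
    print(flow.shape)  # (H, W, 2) flow vectors in pixels

Note that the finetune configs set test_cfg=dict(iters=32), so inference runs more GRU refinement steps than the 12 used during training.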