msvsr.py 42.5 KB
Newer Older
W
wangna11BD 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305
#  Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.ops import DeformConv2D

from .basicvsr import PixelShufflePack, flow_warp, SPyNet, ResidualBlocksWithInputConv
from ...utils.download import get_path_from_url
from ...modules.init import kaiming_normal_, constant_
from .builder import GENERATORS


@GENERATORS.register()
class MSVSR(nn.Layer):
    """PP-MSVSR network structure for video super-resolution.

    Supports only x4 upsampling.

    Paper:
        PP-MSVSR: Multi-Stage Video Super-Resolution, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 32.
        num_init_blocks (int): Number of residual blocks in feat_extract.
            Default: 2.
        num_blocks (int): Number of residual blocks in each propagation branch.
            Default: 3.
        num_reconstruction_blocks (int): Number of residual blocks in
            reconstruction. Default: 2.
        only_last (bool): If True, the HR feature goes through only the final
            convolution; otherwise an extra conv + LeakyReLU is applied first.
            Default: True.
        use_tiny_spynet (bool): Whether to use the tiny SPyNet variant for
            optical flow. Default: True.
        deform_groups (int): Number of deformable groups of DeformConv2D in
            stage2 and stage3. Default: 4.
        stage1_groups (int): Number of deformable groups of DeformConv2D in
            stage1. Default: 8.
        auxiliary_loss (bool): Whether to compute the auxiliary output (and
            return it from forward). Default: True.
        use_refine_align (bool): Whether stage3 uses ReAlignmentModule
            (otherwise AlignmentModule). Default: True.
        aux_reconstruction_blocks (int): Number of residual blocks in the
            auxiliary reconstruction. Default: 1.
        use_local_connnect (bool): Whether to add the upsampled stage1 feature
            to the output. Default: True.
    """
    def __init__(self,
                 mid_channels=32,
                 num_init_blocks=2,
                 num_blocks=3,
                 num_reconstruction_blocks=2,
                 only_last=True,
                 use_tiny_spynet=True,
                 deform_groups=4,
                 stage1_groups=8,
                 auxiliary_loss=True,
                 use_refine_align=True,
                 aux_reconstruction_blocks=1,
                 use_local_connnect=True):

        super().__init__()

        self.mid_channels = mid_channels
        self.only_last = only_last
        self.deform_groups = deform_groups
        self.auxiliary_loss = auxiliary_loss
        self.use_refine_align = use_refine_align
        self.use_local_connnect = use_local_connnect
        # Updated by check_if_mirror_extended(); the default lets
        # compute_flow() be called on its own.
        self.is_mirror_extended = False

        # optical flow module; pretrained weights are downloaded and loaded
        if use_tiny_spynet:
            self.spynet = ModifiedSPyNet(num_blocks=3, use_tiny_block=True)
            spynet_url = 'https://paddlegan.bj.bcebos.com/models/modified_spynet_tiny.pdparams'
        else:
            self.spynet = ModifiedSPyNet(num_blocks=6, use_tiny_block=False)
            spynet_url = 'https://paddlegan.bj.bcebos.com/models/modified_spynet.pdparams'
        weight_path = get_path_from_url(spynet_url)
        self.spynet.set_state_dict(paddle.load(weight_path))

        # feature extraction module
        self.feat_extract = ResidualBlocksWithInputConv(3, mid_channels,
                                                        num_init_blocks)

        # propagation branches module for stage2 and stage3
        self.deform_align = nn.LayerDict()
        self.backbone = nn.LayerDict()

        prop_names = [
            'stage2_backward', 'stage2_forward', 'stage3_backward',
            'stage3_forward'
        ]

        for i, layer in enumerate(prop_names):
            # the stage3 branches (i > 1) may use the refine alignment module
            if i > 1 and self.use_refine_align:
                self.deform_align[layer] = ReAlignmentModule(
                    mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deformable_groups=deform_groups)
            else:
                self.deform_align[layer] = AlignmentModule(
                    mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deformable_groups=deform_groups)

            # branch i sees the current feature, the stage1 feature, every
            # earlier branch and its own propagated feature: (3 + i) inputs
            self.backbone[layer] = ResidualBlocksWithInputConv(
                (3 + i) * mid_channels, mid_channels, num_blocks)

        # stage1
        self.stage1_align = AlignmentModule(mid_channels,
                                            mid_channels,
                                            3,
                                            padding=1,
                                            deformable_groups=stage1_groups)
        self.stage1_blocks = ResidualBlocksWithInputConv(
            3 * mid_channels, mid_channels, 3)

        # upsampling module
        self.reconstruction = ResidualBlocksWithInputConv(
            6 * mid_channels, mid_channels, num_reconstruction_blocks)

        self.upsample1 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        self.upsample2 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        if self.only_last:
            self.conv_last = nn.Conv2D(mid_channels, 3, 3, 1, 1)
        else:
            self.conv_hr = nn.Conv2D(mid_channels, mid_channels, 3, 1, 1)
            self.conv_last = nn.Conv2D(mid_channels, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4,
                                        mode='bilinear',
                                        align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)

        # auxiliary loss
        if self.auxiliary_loss:
            self.aux_fusion = nn.Conv2D(mid_channels * 2, mid_channels, 3, 1, 1)

            self.aux_reconstruction = ResidualBlocksWithInputConv(
                4 * mid_channels, mid_channels, aux_reconstruction_blocks)

            self.aux_block_down1 = nn.Sequential(
                nn.Conv2D(3 + mid_channels, mid_channels, 3, 2, 1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2D(mid_channels, mid_channels, 3, 1, 1))
            self.aux_block_down2 = nn.Sequential(
                nn.Conv2D(mid_channels * 2, mid_channels, 3, 2, 1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2D(mid_channels, mid_channels, 3, 1, 1))

            self.aux_conv_last = nn.Conv2D(mid_channels, 3, 3, 1, 1)

        # These layers are created regardless of auxiliary_loss; upsample()
        # uses them for the local connection.
        self.aux_upsample1 = PixelShufflePack(mid_channels,
                                              mid_channels,
                                              2,
                                              upsample_kernel=3)
        self.aux_upsample2 = PixelShufflePack(mid_channels,
                                              mid_channels,
                                              2,
                                              upsample_kernel=3)
        self.hybrid_conv_last = nn.Conv2D(mid_channels, 3, 3, 1, 1)

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame. The result is stored in ``self.is_mirror_extended``
        (read by ``compute_flow``) and also returned.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            bool: Whether the input is a mirror-extended sequence.
        """

        with paddle.no_grad():
            self.is_mirror_extended = False
            # only an even number of frames can be split into two mirrored halves
            if lrs.shape[1] % 2 == 0:
                lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
                lrs_2 = paddle.flip(lrs_2, [1])
                if paddle.norm(lrs_1 - lrs_2) == 0:
                    self.is_mirror_extended = True
        return self.is_mirror_extended

    def compute_flow(self, lrs):
        """Compute optical flow with the pretrained flow network.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            tuple: (flows_forward, flows_backward), each with shape
                (n, t-1, 2, h, w).
        """
        n, t, c, h, w = lrs.shape

        lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
        lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])

        flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])

        if self.is_mirror_extended:
            # mirrored sequence: forward flow is the reversed backward flow
            flows_forward = flows_backward.flip(1)
        else:
            flows_forward = self.spynet(lrs_2,
                                        lrs_1).reshape([n, t - 1, 2, h, w])

        return flows_forward, flows_backward

    def stage1(self, feats, flows, flows_forward=None):
        """Stage1 of PP-MSVSR network (Local Fusion Module).

        Each frame is fused with its aligned left and right neighbours. Zeros
        are used where a neighbour does not exist.

        Args:
            feats (dict): Dict with key 'spatial', a list of extracted features
                with shape (n, c, h, w), and key 'feat_stage1' holding an
                (empty) list that results are appended to.
            flows (tensor): Backward optical flow with shape (n, t-1, 2, h, w).
            flows_forward (tensor): Forward optical flow with shape
                (n, t-1, 2, h, w). Required in practice despite the default.

        Returns:
            dict: ``feats`` with 'feat_stage1' filled, in frame order, with
                features of shape (n, c, h, w).
        """

        n, t, _, h, w = flows.shape

        # iterate frames from last to first
        frame_idx = range(t, -1, -1)
        flow_idx = range(t, -1, -1)
        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        for i, idx in enumerate(frame_idx):
            feat_current = feats['spatial'][mapping_idx[idx]]

            # aligned right neighbour (frame idx + 1); absent for the last frame
            if i > 0:
                feat_prop = feats['spatial'][mapping_idx[idx + 1]]
                flow_n1 = flows[:, flow_idx[i], :, :, :]
                cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1]))
                cond = paddle.concat([cond_n1, feat_current], axis=1)
                feat_prop, _, _ = self.stage1_align(feat_prop, cond, flow_n1)
            else:
                feat_prop = paddle.zeros([n, self.mid_channels, h, w])

            # aligned left neighbour (frame idx - 1); absent for the first frame
            if i < t:
                feat_back = feats['spatial'][mapping_idx[idx - 1]]
                flow_n1_ = flows_forward[:, flow_idx[i] - 1, :, :, :]
                cond_n1_ = flow_warp(feat_back, flow_n1_.transpose([0, 2, 3,
                                                                    1]))
                cond_ = paddle.concat([cond_n1_, feat_current], axis=1)
                feat_back, _, _ = self.stage1_align(feat_back, cond_, flow_n1_)
            else:
                feat_back = paddle.zeros([n, self.mid_channels, h, w])

            # concatenate and residual blocks
            feat = paddle.concat([feat_current, feat_prop, feat_back], axis=1)
            feat = self.stage1_blocks(feat)

            feats['feat_stage1'].append(feat)

        # results were produced last-to-first; restore frame order
        feats['feat_stage1'] = feats['feat_stage1'][::-1]

        return feats

    def stage2(self, feats, flows):
        """Stage2 of PP-MSVSR network (backward then forward propagation).

        Args:
            feats (dict): Dict with keys 'spatial' and 'feat_stage1' after
                stage1.
            flows (tuple): (flows_backward, flows_forward), each with shape
                (n, t-1, 2, h, w).

        Returns:
            feats (dict): The input dict with new keys 'stage2_backward' and
                'stage2_forward', each a list of features with shape
                (n, c, h, w) in frame order.
            pre_offset (dict): Keys 'stage2_backward' and 'stage2_forward';
                each value is a list of offsets with shape
                (n, 18*deform_groups, h, w), indexed by flow index.
            pre_mask (dict): Keys 'stage2_backward' and 'stage2_forward';
                each value is a list of masks with shape
                (n, 9*deform_groups, h, w), indexed by flow index.
        """
        flows_backward, flows_forward = flows
        n, t, _, h, w = flows_backward.shape

        pre_offset = {}
        pre_mask = {}

        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        # propagation branches module; the order matters because each branch
        # consumes the results of the branches before it
        for prop_name in ['stage2_backward', 'stage2_forward']:
            pre_offset[prop_name] = [0 for _ in range(t)]
            pre_mask[prop_name] = [0 for _ in range(t)]
            feats[prop_name] = []

            frame_idx = range(0, t + 1)
            flow_idx = range(-1, t)
            if 'backward' in prop_name:
                frame_idx = frame_idx[::-1]
                flow_idx = frame_idx
                branch_flows = flows_backward
            else:
                branch_flows = flows_forward

            feat_prop = paddle.zeros([n, self.mid_channels, h, w])
            for i, idx in enumerate(frame_idx):
                feat_current = feats['spatial'][mapping_idx[idx]]

                if i > 0:
                    flow_n1 = branch_flows[:, flow_idx[i], :, :, :]

                    cond_n1 = flow_warp(feat_prop,
                                        flow_n1.transpose([0, 2, 3, 1]))
                    cond = paddle.concat([cond_n1, feat_current], axis=1)

                    feat_prop, offset, mask = self.deform_align[prop_name](
                        feat_prop, cond, flow_n1)
                    # kept for the refine alignment in stage3
                    pre_offset[prop_name][flow_idx[i]] = offset
                    pre_mask[prop_name][flow_idx[i]] = mask

                # concatenate the current feature, every other existing branch
                # (in dict insertion order) and the propagated feature
                feat = [feat_current] + [
                    feats[k][idx]
                    for k in feats if k not in ['spatial', prop_name]
                ] + [feat_prop]

                feat = paddle.concat(feat, axis=1)
                feat_prop = feat_prop + self.backbone[prop_name](feat)

                feats[prop_name].append(feat_prop)

            if 'backward' in prop_name:
                feats[prop_name] = feats[prop_name][::-1]

        return feats, pre_offset, pre_mask

    def stage3(self,
               feats,
               flows,
               aux_feats=None,
               pre_offset=None,
               pre_mask=None):
        """Stage3 of PP-MSVSR network (backward then forward propagation).

        Args:
            feats (dict): Dict of features after stage2.
            flows (tuple): (flows_backward, flows_forward), each with shape
                (n, t-1, 2, h, w).
            aux_feats (dict): Optional dict with keys 'outs' and 'feats' from
                auxiliary_stage. When it has 'feats', those replace the
                spatial features as the current-frame input.
            pre_offset (dict): Stage2 offsets, keyed by 'stage2_backward' and
                'stage2_forward'. Required when use_refine_align is True.
            pre_mask (dict): Stage2 masks, keyed like ``pre_offset``.
                Required when use_refine_align is True.

        Returns:
            feats (dict): The input dict with new keys 'stage3_backward' and
                'stage3_forward', each a list of features with shape
                (n, c, h, w) in frame order.
        """
        flows_backward, flows_forward = flows
        n, t, _, h, w = flows_backward.shape

        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        # propagation branches module
        for prop_name in ['stage3_backward', 'stage3_forward']:
            feats[prop_name] = []

            frame_idx = range(0, t + 1)
            flow_idx = range(-1, t)
            if 'backward' in prop_name:
                frame_idx = frame_idx[::-1]
                flow_idx = frame_idx
                branch_flows = flows_backward
                pre_stage_name = 'stage2_backward'
            else:
                branch_flows = flows_forward
                pre_stage_name = 'stage2_forward'

            feat_prop = paddle.zeros([n, self.mid_channels, h, w])
            for i, idx in enumerate(frame_idx):
                feat_current = feats['spatial'][mapping_idx[idx]]
                if aux_feats is not None and 'feats' in aux_feats:
                    feat_current = aux_feats['feats'][mapping_idx[idx]]

                if i > 0:
                    flow_n1 = branch_flows[:, flow_idx[i], :, :, :]

                    cond_n1 = flow_warp(feat_prop,
                                        flow_n1.transpose([0, 2, 3, 1]))
                    cond = paddle.concat([cond_n1, feat_current], axis=1)

                    if self.use_refine_align:
                        # ReAlignmentModule refines the stage2 offsets/masks
                        feat_prop = self.deform_align[prop_name](
                            feat_prop, cond, flow_n1, feat_current,
                            pre_offset[pre_stage_name][flow_idx[i]],
                            pre_mask[pre_stage_name][flow_idx[i]])
                    else:
                        # AlignmentModule takes (x, extra_feat, flow) and
                        # also returns its offset and mask
                        feat_prop, _, _ = self.deform_align[prop_name](
                            feat_prop, cond, flow_n1)

                # concatenate and residual blocks
                feat = [feat_current] + [
                    feats[k][idx]
                    for k in feats if k not in ['spatial', prop_name]
                ] + [feat_prop]

                feat = paddle.concat(feat, axis=1)
                feat_prop = feat_prop + self.backbone[prop_name](feat)

                feats[prop_name].append(feat_prop)

            if 'backward' in prop_name:
                feats[prop_name] = feats[prop_name][::-1]

        return feats

    def auxiliary_stage(self, feats, lqs):
        """Compute the auxiliary output image and feature after stage2.

        Args:
            feats (dict): Dict of features after stage2.
            lqs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            dict: Keys 'outs' (list of images with shape (n, 3, 4*h, 4*w)) and
                'feats' (list of features with shape (n, c, h, w)).
        """
        aux_feats = {}
        aux_feats['outs'] = []
        aux_feats['feats'] = []
        num_outputs = len(feats['spatial'])

        mapping_idx = list(range(0, num_outputs))
        mapping_idx += mapping_idx[::-1]

        t = lqs.shape[1]
        for i in range(0, t):
            hr = [feats[k][i] for k in feats if (k != 'spatial')]
            feat_current = feats['spatial'][mapping_idx[i]]
            hr.insert(0, feat_current)
            hr = paddle.concat(hr, axis=1)

            hr_low = self.aux_reconstruction(hr)
            hr_mid = self.lrelu(self.aux_upsample1(hr_low))
            hr_high = self.lrelu(self.aux_upsample2(hr_mid))

            hr = self.aux_conv_last(hr_high)
            hr += self.img_upsample(lqs[:, i, :, :, :])

            # output tensor of auxiliary_stage with shape (n, 3, 4*h, 4*w)
            aux_feats['outs'].append(hr)

            # downsample back to LR resolution, fusing each pyramid level
            aux_feat = self.aux_block_down1(paddle.concat([hr, hr_high],
                                                          axis=1))
            aux_feat = self.aux_block_down2(
                paddle.concat([aux_feat, hr_mid], axis=1))
            aux_feat = self.aux_fusion(paddle.concat([aux_feat, hr_low],
                                                     axis=1))

            # out feature of auxiliary_stage with shape (n, c, h, w)
            aux_feats['feats'].append(aux_feat)

        return aux_feats

    def upsample(self, lqs, feats, aux_feats=None):
        """Compute the output images from the features.

        Note: the propagated feature lists in ``feats`` are consumed (popped).

        Args:
            lqs (tensor): Input LR images with shape (n, t, c, h, w).
            feats (dict): Dict of features after stage3.
            aux_feats (dict): Output of auxiliary_stage. Required when
                auxiliary_loss is True.

        Returns:
            Tensor: Output HR sequence with shape (n, t, 3, 4*h, 4*w). When
                auxiliary_loss is True, a tuple (aux_outputs, outputs) of two
                such tensors is returned instead.
        """

        outputs = []
        num_outputs = len(feats['spatial'])

        mapping_idx = list(range(0, num_outputs))
        mapping_idx += mapping_idx[::-1]

        t = lqs.shape[1]
        for i in range(0, t):
            hr = [
                feats[k].pop(0) for k in feats
                if (k != 'spatial' and k != 'feat_stage1')
            ]
            if 'feat_stage1' in feats:
                local_feat = feats['feat_stage1'].pop(0)
                hr.insert(0, local_feat)
            hr.insert(0, feats['spatial'][mapping_idx[i]])
            hr = paddle.concat(hr, axis=1)

            hr = self.reconstruction(hr)

            hr = self.lrelu(self.upsample1(hr))
            hr = self.lrelu(self.upsample2(hr))
            if self.only_last:
                hr = self.conv_last(hr)
            else:
                hr = self.lrelu(self.conv_hr(hr))
                hr = self.conv_last(hr)

            hr += self.img_upsample(lqs[:, i, :, :, :])
            if self.use_local_connnect:
                local_head = self.lrelu(self.aux_upsample1(local_feat))
                local_head = self.lrelu(self.aux_upsample2(local_head))
                hr = self.hybrid_conv_last(local_head) + hr

            outputs.append(hr)

        if self.auxiliary_loss:
            return paddle.stack(aux_feats['outs'],
                                axis=1), paddle.stack(outputs, axis=1)
        return paddle.stack(outputs, axis=1)

    def forward(self, lqs):
        """Forward function for PP-MSVSR.

        Args:
            lqs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, 3, 4*h, 4*w). When
                auxiliary_loss is True, a tuple (aux_outputs, outputs) is
                returned.
        """

        n, t, c, h, w = lqs.shape

        # optical flow is computed on the low-res inputs; validate their size
        # up front, before any feature extraction work
        assert h >= 64 and w >= 64, (
            'The height and width of low-res inputs must be at least 64, '
            f'but got {h} and {w}.')

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lqs)

        feats = {}
        spatial_feats = self.feat_extract(lqs.reshape([-1, c, h, w]))

        h, w = spatial_feats.shape[2:]
        spatial_feats = spatial_feats.reshape([n, t, -1, h, w])
        feats['spatial'] = [spatial_feats[:, i, :, :, :] for i in range(0, t)]

        flows_forward, flows_backward = self.compute_flow(lqs)

        # feature propagation; key insertion order of `feats` determines the
        # channel order fed to every backbone, so keep it unchanged
        feats['feat_stage1'] = []
        feats = self.stage1(feats, flows_backward, flows_forward)

        feats, pre_offset, pre_mask = self.stage2(
            feats, (flows_backward, flows_forward))

        # bind aux_feats unconditionally so auxiliary_loss=False does not
        # raise NameError below
        aux_feats = None
        if self.auxiliary_loss:
            aux_feats = self.auxiliary_stage(feats, lqs)

        feats = self.stage3(feats, (flows_backward, flows_forward), aux_feats,
                            pre_offset, pre_mask)

        return self.upsample(lqs, feats, aux_feats=aux_feats)


class AlignmentModule(nn.Layer):
    """Flow-guided deformable alignment module.

    Predicts residual offsets and modulation masks from ``extra_feat`` and the
    optical flow, adds the (channel-flipped) flow to the offsets, and applies
    a modulated deformable convolution to ``x``.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d; passed to the deformable convolution.
        deformable_groups (int): Number of deformable_groups in DeformConv2D.
    """
    def __init__(self,
                 in_channels=128,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 deformable_groups=16):
        super().__init__()

        # Input is extra_feat (2 * out_channels channels) concatenated with a
        # 2-channel flow. Output holds two offset halves and one mask part,
        # 9 * deformable_groups channels each (27 * deformable_groups total).
        self.conv_offset = nn.Sequential(
            nn.Conv2D(2 * out_channels + 2, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, 27 * deformable_groups, 3, 1, 1),
        )
        self.dcn = DeformConv2D(in_channels,
                                out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                deformable_groups=deformable_groups,
                                groups=groups)

        self.init_offset()

    def init_offset(self):
        """Zero the last offset conv.

        With zero weight and bias, the offsets initially equal the flipped
        flow and the masks equal sigmoid(0) = 0.5.
        """
        constant_(self.conv_offset[-1].weight, 0)
        constant_(self.conv_offset[-1].bias, 0)

    def forward(self, x, extra_feat, flow_1):
        """Align ``x`` using offsets predicted from ``extra_feat`` and the flow.

        Args:
            x (Tensor): Feature to align.
            extra_feat (Tensor): Conditioning feature with 2 * out_channels
                channels.
            flow_1 (Tensor): Optical flow with 2 channels on axis 1.

        Returns:
            tuple: (aligned feature, offset, mask).
        """
        extra_feat = paddle.concat([extra_feat, flow_1], axis=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = paddle.chunk(out, 3, axis=1)

        # Residual offsets are bounded to (-10, 10) by tanh, then the flow is
        # added to every sampling point. The flow's two channels are flipped,
        # presumably to match DeformConv2D's offset ordering (TODO confirm).
        offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
        offset = offset + flow_1.flip(1).tile([1, offset.shape[1] // 2, 1, 1])

        # modulation mask in (0, 1)
        mask = F.sigmoid(mask)
        out = self.dcn(x, offset, mask)
        return out, offset, mask


class ReAlignmentModule(nn.Layer):
    """Refined deformable alignment module.

    Offsets and modulation masks are predicted by ``conv_offset`` and
    consumed by the deformable convolution ``dcn``.

    When ``pre_stage_flow`` is supplied:

    * ``x`` is first aligned with it through ``mdconv``.
    * That result, together with ``feat_current`` and ``flow_1``, conditions
      the new prediction.
    * The predicted offsets are added to ``pre_stage_flow``.

    When ``pre_stage_mask`` is supplied, the predicted mask is averaged with it.

    Args:
        in_channels (int): Same as nn.Conv2D.
        out_channels (int): Same as nn.Conv2D.
        kernel_size (int or tuple[int]): Same as nn.Conv2D.
        stride (int or tuple[int]): Same as nn.Conv2D.
        padding (int or tuple[int]): Same as nn.Conv2D.
        dilation (int or tuple[int]): Same as nn.Conv2D.
        groups (int): Same as nn.Conv2D; forwarded to both deformable convs.
        deformable_groups (int): Number of deformable_groups in DeformConv2D.
    """
    def __init__(self,
                 in_channels=128,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 deformable_groups=16):
        super(ReAlignmentModule, self).__init__()

        # Sampling points per deformable group (kh * kw); 9 for a 3x3 kernel.
        if isinstance(kernel_size, int):
            kernel_points = kernel_size * kernel_size
        else:
            kernel_points = kernel_size[0] * kernel_size[1]

        # Pre-aligns x with the previous stage's offsets/masks (forward only
        # uses it when pre_stage_flow is given).
        self.mdconv = DeformConv2D(in_channels,
                                   out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   dilation=dilation,
                                   deformable_groups=deformable_groups,
                                   groups=groups)
        # Predicts two offset values and one mask logit per deformable group
        # and kernel point: 3 * kh * kw * deformable_groups channels
        # (27 * deformable_groups for a 3x3 kernel).
        self.conv_offset = nn.Sequential(
            nn.Conv2D(2 * out_channels + 2, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, 3 * kernel_points * deformable_groups, 3,
                      1, 1),
        )
        self.dcn = DeformConv2D(in_channels,
                                out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                deformable_groups=deformable_groups,
                                groups=groups)

        self.init_offset()

    def init_offset(self):
        """Zero-fill the weight and bias of the last ``conv_offset`` layer.

        ``constant_`` is a helper defined elsewhere in this module and is
        assumed to fill the parameter in place.
        """
        constant_(self.conv_offset[-1].weight, 0)
        constant_(self.conv_offset[-1].bias, 0)

    def forward(self,
                x,
                extra_feat,
                flow_1,
                feat_current,
                pre_stage_flow=None,
                pre_stage_mask=None):
        """Align ``x`` with (optionally refined) deformable offsets.

        Args:
            x (Tensor): Feature to be aligned.
            extra_feat (Tensor): Conditioning features for offset prediction.
                Ignored when ``pre_stage_flow`` is given.
            flow_1 (Tensor): Flow appended to the conditioning features. When
                ``pre_stage_flow`` is None it is also the base of the offsets.
            feat_current (Tensor): Only used when ``pre_stage_flow`` is given,
                as part of the conditioning features.
            pre_stage_flow (Tensor, optional): Offsets from a previous stage.
            pre_stage_mask (Tensor, optional): Mask from a previous stage.

        Returns:
            Tensor: Output of ``self.dcn``.
        """
        if pre_stage_flow is not None:
            # Re-condition on x pre-aligned with the previous stage's offsets.
            pre_feat = self.mdconv(x, pre_stage_flow, pre_stage_mask)
            extra_feat = paddle.concat([pre_feat, feat_current, flow_1], axis=1)
        else:
            extra_feat = paddle.concat([extra_feat, flow_1], axis=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = paddle.chunk(out, 3, axis=1)

        # offset: residual bounded to (-10, 10), added to the base offsets.
        offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
        if pre_stage_flow is not None:
            offset = offset + pre_stage_flow
        else:
            # Base offsets: flow_1 with channel order reversed, repeated
            # once per offset pair.
            offset = offset + flow_1.flip(1).tile(
                [1, offset.shape[1] // 2, 1, 1])

        # mask: averaged with the previous stage's mask when available.
        if pre_stage_mask is not None:
            mask = (F.sigmoid(mask) + pre_stage_mask) / 2.0
        else:
            mask = F.sigmoid(mask)
        out = self.dcn(x, offset, mask)
        return out


class ModifiedSPyNet(nn.Layer):
    """Modified SPyNet network structure.

    The differences to the SPyNet in the paper are:
        1. convolutions with kernel_size=7 are replaced by convolutions with
           kernel_size=3,
        2. fewer SPyNetBlock modules are used,
        3. no BN is used.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

    Args:
        act_cfg (dict): Activation function.
            Default: dict(name='LeakyReLU').
        num_blocks (int): Number of SPyNetBlock, which is also the number of
            pyramid levels. Default: 6.
        use_tiny_block (bool): Whether to use the tiny SPyNetBlock variant.
            Default: False.
    """
    def __init__(self,
                 act_cfg=dict(name='LeakyReLU'),
                 num_blocks=6,
                 use_tiny_block=False):
        super().__init__()
        self.num_blocks = num_blocks
        self.basic_module = nn.LayerList([
            SPyNetBlock(act_cfg=act_cfg, use_tiny_block=use_tiny_block)
            for _ in range(num_blocks)
        ])

        self.register_buffer(
            'mean',
            paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
        self.register_buffer(
            'std',
            paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))

    def _build_pyramids(self, ref, supp):
        """Normalize both images and return coarse-to-fine avg-pool pyramids."""
        ref = [(ref - self.mean) / self.std]
        supp = [(supp - self.mean) / self.std]
        for _ in range(self.num_blocks - 1):
            ref.append(F.avg_pool2d(ref[-1], kernel_size=2, stride=2))
            supp.append(F.avg_pool2d(supp[-1], kernel_size=2, stride=2))
        # Built fine-to-coarse; reverse so index 0 is the coarsest level.
        return ref[::-1], supp[::-1]

    def compute_flow_list(self, ref, supp):
        """Compute the flow from ref to supp at every pyramid level.

        Note that in this function, the images are expected to already be
        resized to a multiple of 32.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            list[Tensor]: One (n, 2, h_l, w_l) flow per level, ordered from
            the coarsest to the finest (full-resolution) level.
        """
        n, _, h, w = ref.shape
        ref, supp = self._build_pyramids(ref, supp)

        # Start from zero flow at the coarsest resolution.
        scale = 2**(self.num_blocks - 1)
        flow = paddle.zeros([n, 2, h // scale, w // scale], dtype='float32')

        flow_list = []
        for level, (ref_level, supp_level) in enumerate(zip(ref, supp)):
            if level == 0:
                flow_up = flow
            else:
                # Doubling the resolution doubles the displacement magnitude.
                flow_up = F.interpolate(
                    flow, scale_factor=2, mode='bilinear',
                    align_corners=True) * 2.0

            # Add the predicted residue to the upsampled flow.
            warped_supp = flow_warp(supp_level,
                                    flow_up.transpose([0, 2, 3, 1]),
                                    padding_mode='border')
            residue = self.basic_module[level](paddle.concat(
                [ref_level, warped_supp, flow_up], axis=1))
            flow = flow_up + residue
            flow_list.append(flow)
        return flow_list

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.

        Note that in this function, the images are already resized to a
        multiple of 32.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        # The finest level of the pyramid is the full-resolution flow.
        return self.compute_flow_list(ref, supp)[-1]

    def forward(self, ref, supp):
        """Forward function of Modified SPyNet.

        This function computes the optical flow from ref to supp.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """

        # upsize to a multiple of 32
        h, w = ref.shape[2:4]
        w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
        h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
        ref = F.interpolate(ref,
                            size=(h_up, w_up),
                            mode='bilinear',
                            align_corners=False)

        supp = F.interpolate(supp,
                             size=(h_up, w_up),
                             mode='bilinear',
                             align_corners=False)

        # NOTE(review): explicitly enables gradients on the resized inputs;
        # kept as-is since callers may rely on it.
        ref.stop_gradient = False
        supp.stop_gradient = False

        # compute flow, and resize back to the original resolution
        flow = F.interpolate(self.compute_flow(ref, supp),
                             size=(h, w),
                             mode='bilinear',
                             align_corners=False)

        # Rescale the flow values to account for the resize: channel 0 by
        # the width ratio, channel 1 by the height ratio.
        flow[:, 0, :, :] *= float(w) / float(w_up)
        flow[:, 1, :, :] *= float(h) / float(h_up)

        return flow


class SPyNetBlock(nn.Layer):
    """Basic Block of Modified SPyNet.

    Refer to "Optical Flow Estimation using a Spatial Pyramid Network",
    CVPR, 2017.

    The block is a plain stack of 3x3, stride-1, padding-1 ConvLayers mapping
    8 input channels to 2 output channels. Every layer uses ``act_cfg``
    except the last, which has no activation.

    Args:
        act_cfg (dict): Activation config for all but the last ConvLayer.
            Default: dict(name='LeakyReLU').
        use_tiny_block (bool): Use the narrower 12-layer stack instead of the
            15-layer one. Default: False.
    """
    def __init__(self, act_cfg=dict(name='LeakyReLU'), use_tiny_block=False):
        super().__init__()
        # Channel widths between consecutive layers: layer i maps
        # widths[i] -> widths[i + 1].
        if use_tiny_block:
            widths = (8, 16, 16, 32, 32, 32, 32, 16, 16, 16, 8, 8, 2)
        else:
            widths = (8, 16, 16, 32, 32, 32, 64, 32, 32, 32, 16, 16, 16, 16,
                      16, 2)
        final_idx = len(widths) - 2

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(
                ConvLayer(in_channels=c_in,
                          out_channels=c_out,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          act_cfg=None if idx == final_idx else act_cfg))
        # Positional Sequential keeps the sublayer names "0", "1", ...
        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        """Forward function of SPyNetBlock.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)


class ConvLayer(nn.Layer):
    """2D convolution optionally followed by an activation.

    Args:
        in_channels (int): Same as nn.Conv2D.
        out_channels (int): Same as nn.Conv2D.
        kernel_size (int or tuple[int]): Same as nn.Conv2D.
        stride (int or tuple[int]): Same as nn.Conv2D. Default: 1.
        padding (int or tuple[int]): Same as nn.Conv2D. Default: 0.
        dilation (int or tuple[int]): Same as nn.Conv2D. Default: 1.
        groups (int): Same as nn.Conv2D. Default: 1.
        act_cfg (dict | None): Activation config, or None for no activation.
            ``act_cfg['name']`` must be 'ReLU' or 'LeakyReLU'. For
            'LeakyReLU' an optional 'negative_slope' key (default 0.1) is
            honoured. Default: dict(name='ReLU').

    Raises:
        ValueError: If ``act_cfg['name']`` is not a supported activation.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 act_cfg=dict(name='ReLU')):
        super(ConvLayer, self).__init__()
        self.act_cfg = act_cfg
        self.with_activation = act_cfg is not None

        self.conv = nn.Conv2D(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)

        if self.with_activation:
            act_name = act_cfg['name']
            if act_name == 'ReLU':
                self.act = nn.ReLU()
            elif act_name == 'LeakyReLU':
                self.act = nn.LeakyReLU(
                    negative_slope=act_cfg.get('negative_slope', 0.1))
            else:
                # Fail at construction instead of with an AttributeError on
                # the first forward() call.
                raise ValueError(
                    "Unsupported activation '{}' in act_cfg; expected 'ReLU' "
                    "or 'LeakyReLU'.".format(act_name))

    def forward(self, tensor_input):
        """Apply the convolution, then the activation if one is configured."""
        out = self.conv(tensor_input)
        if self.with_activation:
            out = self.act(out)
        return out