// temporal_shift_op.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;

/// Backward pass of the temporal shift for NCHW-laid-out data.
///
/// Every element of `input_grad` is overwritten. Channels [0, c1) take their
/// gradient from the next temporal segment (it + 1), channels [c1, c2) from
/// the previous segment (it - 1), and the remaining channels from the same
/// segment. When the source segment lies outside [0, t) the gradient is 0.
///
/// @param output_grad  gradient w.r.t. the shifted output (ntchw elements).
/// @param input_grad   gradient w.r.t. the input (ntchw elements, written).
/// @param ntchw        total number of elements to process.
/// @param tchw         elements per group of t segments (presumably t * chw).
/// @param chw          elements per temporal segment.
/// @param hw           elements per channel plane.
/// @param t            number of temporal segments.
/// @param c1           end (exclusive) of the channel group read from it + 1.
/// @param c2           end (exclusive) of the channel group read from it - 1.
template <typename T>
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
                         const int tchw, const int chw, const int hw,
                         const int t, const int c1, const int c2) {
  for (int idx = 0; idx < ntchw; ++idx) {
    const int cur_t = (idx % tchw) / chw;
    const int cur_c = (idx % chw) / hw;

    // Temporal offset of the segment this element's gradient comes from.
    int offset = 0;
    if (cur_c < c1) {
      offset = 1;
    } else if (cur_c < c2) {
      offset = -1;
    }

    const int from_t = cur_t + offset;
    if (from_t < 0 || from_t >= t) {
      // Shifted past the sequence boundary: no gradient flows back.
      input_grad[idx] = 0;
      continue;
    }
    input_grad[idx] = output_grad[idx + offset * chw];
  }
}

/// Backward pass of the temporal shift for NHWC-laid-out data.
///
/// Every element of `input_grad` is overwritten. Channels [0, c1) take their
/// gradient from the next temporal segment (it + 1), channels [c1, c2) from
/// the previous segment (it - 1), and the remaining channels from the same
/// segment. When the source segment lies outside [0, t) the gradient is 0.
///
/// @param output_grad  gradient w.r.t. the shifted output (nthwc elements).
/// @param input_grad   gradient w.r.t. the input (nthwc elements, written).
/// @param nthwc        total number of elements to process.
/// @param thwc         elements per group of t segments (presumably t * hwc).
/// @param hwc          elements per temporal segment.
/// @param t            number of temporal segments.
/// @param c            number of channels; the channel index is i % c, i.e.
///                     channels are the innermost dimension.
/// @param c1           end (exclusive) of the channel group read from it + 1.
/// @param c2           end (exclusive) of the channel group read from it - 1.
template <typename T>
void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
                         const int thwc, const int hwc, const int t,
                         const int c, const int c1, const int c2) {
  for (int i = 0; i < nthwc; i++) {
    const int it = (i % thwc) / hwc;
    const int ic = i % c;

    // Segment whose output gradient flows back to this input element.
    int src_it = it;
    if (ic < c1) {
      src_it = it + 1;
    } else if (ic < c2) {
      src_it = it - 1;
    }

    if (src_it >= 0 && src_it < t) {
      input_grad[i] = output_grad[i + (src_it - it) * hwc];
    } else {
      // Shifted past the sequence boundary: no gradient flows back.
      input_grad[i] = 0;
    }
  }
}

template <typename T>
D
dengkaipeng 已提交
73
class TemporalShiftKernel : public framework::OpKernel<T> {
74
 public:
P
update  
phlrain 已提交
75
  void Compute(const framework::ExecutionContext& ctx) const override {}
76 77 78 79 80 81 82 83 84
};

/// Backward kernel for the temporal_shift operator.
///
/// Reads Out@GRAD and the "seg_num", "shift_ratio" and "data_format"
/// attributes, allocates X@GRAD with the same 4-D shape as Out@GRAD, and
/// fills it via TemporalShiftBwNCHW / TemporalShiftBwNHWC depending on the
/// data layout.
template <typename T>
class TemporalShiftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);
    const bool is_nchw = (data_layout == DataLayout::kNCHW);

    // Dimensions are read as int64_t and narrowed explicitly to int, which is
    // what the index helpers above operate on.
    const int nt = static_cast<int>(output_grad->dims()[0]);
    const int c = static_cast<int>(is_nchw ? output_grad->dims()[1]
                                           : output_grad->dims()[3]);
    const int h = static_cast<int>(is_nchw ? output_grad->dims()[2]
                                           : output_grad->dims()[1]);
    const int w = static_cast<int>(is_nchw ? output_grad->dims()[3]
                                           : output_grad->dims()[2]);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    // Channel boundaries of the two shifted groups.
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim in_grad_dims =
        (is_nchw ? phi::make_ddim({nt, c, h, w})
                 : phi::make_ddim({nt, h, w, c}));
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());

    if (is_nchw) {
      TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, hw, t, c1, c2);
    } else {
      // The products are layout-independent (n*t*h*w*c == n*t*c*h*w), so the
      // NCHW-named totals serve as nthwc / thwc / hwc here.
      TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, t, c, c1, c2);
    }
  }
};

}  // namespace operators
}  // namespace paddle