/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Convenience aliases for the framework types used throughout this file.
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
// Forward temporal shift for NCHW-layout data.
//
// The flattened input of shape (N, T, C, H, W) is traversed linearly; for
// each element the channel index decides which neighbouring frame supplies
// the value: channels [0, c1) read from the previous frame, channels
// [c1, c2) read from the next frame, and the remaining channels are copied
// in place. Reads that would fall outside [0, t) produce zeros.
//
// ntchw/tchw/chw/hw are the precomputed strides of the flattened tensor;
// t is the segment (frame) count, c1/c2 the channel split points.
template <typename T>
void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
                         const int tchw, const int chw, const int hw,
                         const int t, const int c1, const int c2) {
  for (int i = 0; i < ntchw; ++i) {
    const int it = (i % tchw) / chw;  // temporal index of element i
    const int ic = (i % chw) / hw;    // channel index of element i

    int src_it;
    if (ic < c1) {
      src_it = it - 1;  // take value from the previous frame
    } else if (ic < c2) {
      src_it = it + 1;  // take value from the next frame
    } else {
      src_it = it;  // unshifted channels
    }

    const bool in_range = (src_it >= 0 && src_it < t);
    output[i] = in_range ? input[i + (src_it - it) * chw] : static_cast<T>(0);
  }
}

// Forward temporal shift for NHWC-layout data.
//
// Same contract as TemporalShiftFwNCHW, but the flattened input has shape
// (N, T, H, W, C), so the channel index is simply i % c and the per-frame
// stride is hwc. Channels [0, c1) read from the previous frame, channels
// [c1, c2) from the next frame, the rest are copied; out-of-range frame
// reads yield zero.
template <typename T>
void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
                         const int thwc, const int hwc, const int t,
                         const int c, const int c1, const int c2) {
  for (int i = 0; i < nthwc; ++i) {
    const int it = (i % thwc) / hwc;  // temporal index of element i
    const int ic = i % c;             // channel index (innermost dim)

    int src_it;
    if (ic < c1) {
      src_it = it - 1;  // take value from the previous frame
    } else if (ic < c2) {
      src_it = it + 1;  // take value from the next frame
    } else {
      src_it = it;  // unshifted channels
    }

    const bool in_range = (src_it >= 0 && src_it < t);
    output[i] = in_range ? input[i + (src_it - it) * hwc] : static_cast<T>(0);
  }
}

// Backward (gradient) temporal shift for NCHW-layout data.
//
// Mirror of TemporalShiftFwNCHW: since the forward pass moved channels
// [0, c1) backward in time and [c1, c2) forward, the gradient flows in the
// opposite direction — those channel groups gather from it + 1 and it - 1
// respectively. Positions whose forward output was zero-filled receive a
// zero gradient.
template <typename T>
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
                         const int tchw, const int chw, const int hw,
                         const int t, const int c1, const int c2) {
  for (int i = 0; i < ntchw; ++i) {
    const int it = (i % tchw) / chw;  // temporal index of element i
    const int ic = (i % chw) / hw;    // channel index of element i

    int src_it;
    if (ic < c1) {
      src_it = it + 1;  // forward shifted these back, so gather from ahead
    } else if (ic < c2) {
      src_it = it - 1;  // forward shifted these ahead, so gather from behind
    } else {
      src_it = it;  // unshifted channels
    }

    const bool in_range = (src_it >= 0 && src_it < t);
    input_grad[i] =
        in_range ? output_grad[i + (src_it - it) * chw] : static_cast<T>(0);
  }
}

// Backward (gradient) temporal shift for NHWC-layout data.
//
// Mirror of TemporalShiftFwNHWC with the shift directions reversed:
// channels [0, c1) gather their gradient from frame it + 1 and channels
// [c1, c2) from frame it - 1; the rest copy in place. Out-of-range frame
// reads yield a zero gradient.
template <typename T>
void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
                         const int thwc, const int hwc, const int t,
                         const int c, const int c1, const int c2) {
  for (int i = 0; i < nthwc; ++i) {
    const int it = (i % thwc) / hwc;  // temporal index of element i
    const int ic = i % c;             // channel index (innermost dim)

    int src_it;
    if (ic < c1) {
      src_it = it + 1;  // forward shifted these back, so gather from ahead
    } else if (ic < c2) {
      src_it = it - 1;  // forward shifted these ahead, so gather from behind
    } else {
      src_it = it;  // unshifted channels
    }

    const bool in_range = (src_it >= 0 && src_it < t);
    input_grad[i] =
        in_range ? output_grad[i + (src_it - it) * hwc] : static_cast<T>(0);
  }
}

template <typename T>
D
dengkaipeng 已提交
123
class TemporalShiftKernel : public framework::OpKernel<T> {
124 125 126 127 128
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");
D
dengkaipeng 已提交
129
    float shift_ratio = ctx.Attr<float>("shift_ratio");
130 131 132
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);
133 134

    const int nt = input->dims()[0];
135 136 137 138 139 140
    const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
                                                    : input->dims()[3]);
    const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
                                                    : input->dims()[1]);
    const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
                                                    : input->dims()[2]);
D
dengkaipeng 已提交
141

142 143 144
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
145
    const int ntchw = nt * chw;
146

147 148 149 150 151 152
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim out_dims = (data_layout == DataLayout::kNCHW
                                    ? framework::make_ddim({nt, c, h, w})
                                    : framework::make_ddim({nt, h, w, c}));
153
    const T* input_data = input->data<T>();
154 155 156 157 158 159 160 161
    T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());

    if (data_layout == DataLayout::kNCHW) {
      TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
                             c1, c2);
    } else {
      TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
                             c1, c2);
162 163 164 165 166 167 168 169 170 171 172
    }
  }
};

// CPU backward kernel for the temporal_shift op.
//
// Takes the upstream gradient Out@GRAD and produces X@GRAD by applying the
// temporal shift with directions reversed. Dimensions and channel split
// points are derived exactly as in the forward kernel, but from the
// output-gradient tensor.
template <typename T>
class TemporalShiftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);
    const bool is_nchw = (data_layout == DataLayout::kNCHW);

    // Dimension order depends on the layout; nt is the fused batch*time dim.
    const int nt = output_grad->dims()[0];
    const int c = is_nchw ? output_grad->dims()[1] : output_grad->dims()[3];
    const int h = is_nchw ? output_grad->dims()[2] : output_grad->dims()[1];
    const int w = is_nchw ? output_grad->dims()[3] : output_grad->dims()[2];

    // Precomputed strides of the flattened (N, T, C, H, W) view.
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    // Same channel split points as the forward pass.
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim in_grad_dims = is_nchw
                                       ? framework::make_ddim({nt, c, h, w})
                                       : framework::make_ddim({nt, h, w, c});
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());

    if (is_nchw) {
      TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, hw, t, c1, c2);
    } else {
      TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, t, c, c1, c2);
    }
  }
};

}  // namespace operators
}  // namespace paddle