// temporal_shift_op.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Forward kernel of the temporal shift operator.
//
// Each flat element index `tid` in [0, ntchw) is decomposed into the
// coordinates (in, it, ic, ih, iw) using the strides tchw / chw / hw / w.
// Depending on the channel index, the element is read from a neighbouring
// temporal position of the same (in, ic, ih, iw):
//   * ic <  c1        : read from it - 1
//   * c1 <= ic <  c2  : read from it + 1
//   * otherwise       : read from it (plain copy)
// where c1 = int(c * shift_ratio) and c2 = int(c * 2 * shift_ratio).
// When the source temporal position falls outside [0, t) the output element
// is set to zero.
//
// Launch layout: 1-D grid of 1-D blocks; the kernel uses a grid-stride loop,
// so any grid size >= 1 covers all `ntchw` elements.
//
// Args:
//   input       : source buffer, read at the index returned by GetEntryIndex.
//   output      : destination buffer, written at every tid in [0, ntchw).
//   ntchw       : total number of elements to process.
//   tchw/chw/hw/w : strides used to split `tid` into coordinates.
//   t           : number of temporal positions; valid src_it is [0, t).
//   c           : channel count used to derive the shift boundaries.
//   shift_ratio : fraction of channels shifted in each direction.
//
// NOTE(review): GetEntryIndex is defined outside this file; it presumably
// flattens (in, src_it, ic, ih, iw) back into a linear offset - confirm.
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c,
                                  const float shift_ratio) {
  // Channel boundaries are integer indices and do not depend on the element
  // being processed: compute them once and cast straight to int instead of
  // round-tripping through the element type T.
  const int c1 = static_cast<int>(c * shift_ratio);
  const int c2 = static_cast<int>(c * 2 * shift_ratio);

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  for (; tid < ntchw; tid += stride) {
    // Split the flat index into its coordinates.
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    // Select the temporal position this output element is read from.
    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it < 0 || src_it >= t) {
      // Shifted past either end of the segment: pad with zero.
      output[tid] = 0;
    } else {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      output[tid] = input[src_idx];
    }
  }
}

// Backward kernel of the temporal shift operator.
//
// Each flat element index `tid` in [0, ntchw) of `output_grad` is decomposed
// into (in, it, ic, ih, iw) using the strides tchw / chw / hw / w. The same
// channel-dependent temporal offset as in the forward kernel is applied:
//   * ic <  c1        : src_it = it - 1
//   * c1 <= ic <  c2  : src_it = it + 1
//   * otherwise       : src_it = it
// and, when src_it lies in [0, t), output_grad[tid] is stored into
// input_grad at the index returned by GetEntryIndex for src_it.
//
// Elements of `input_grad` whose index is never produced here are left
// untouched by this kernel, so the caller must initialise `input_grad`
// before the launch.
//
// Launch layout: 1-D grid of 1-D blocks; the kernel uses a grid-stride loop,
// so any grid size >= 1 covers all `ntchw` elements.
//
// Args:
//   output_grad : gradient buffer read at every tid in [0, ntchw).
//   input_grad  : gradient buffer written at the shifted index.
//   ntchw       : total number of elements to process.
//   tchw/chw/hw/w : strides used to split `tid` into coordinates.
//   t           : number of temporal positions; valid src_it is [0, t).
//   c           : channel count used to derive the shift boundaries.
//   shift_ratio : fraction of channels shifted in each direction.
//
// NOTE(review): GetEntryIndex is defined outside this file; it presumably
// flattens (in, src_it, ic, ih, iw) back into a linear offset - confirm.
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
                                  const int ntchw, const int tchw,
                                  const int chw, const int hw, const int w,
                                  const int t, const int c,
                                  const float shift_ratio) {
  // Channel boundaries are integer indices and do not depend on the element
  // being processed: compute them once and cast straight to int instead of
  // round-tripping through the element type T.
  const int c1 = static_cast<int>(c * shift_ratio);
  const int c2 = static_cast<int>(c * 2 * shift_ratio);

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  for (; tid < ntchw; tid += stride) {
    // Split the flat index into its coordinates.
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    // Same temporal offset selection as the forward kernel.
    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    // Positions shifted past either end of the segment received zero in the
    // forward pass and therefore contribute no gradient.
    if (src_it >= 0 && src_it < t) {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      input_grad[src_idx] = output_grad[tid];
    }
  }
}

// CUDA implementation of the forward temporal_shift operator.
//
// Reads input "X", whose first four dimensions are taken as (nt, c, h, w),
// together with the attributes "seg_num" (t) and "shift_ratio", allocates
// output "Out" with the same (nt, c, h, w) shape and launches
// KeTemporalShiftFw on the CUDA stream of the execution context.
//
// NOTE(review): the kernel indexing treats dimension 0 as n * seg_num; this
// presumably requires nt to be a multiple of seg_num and is expected to be
// validated elsewhere (e.g. shape inference) - confirm.
template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    // Strides used by the kernel to split a flat index into coordinates.
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    // An empty tensor would yield a grid of 0 blocks, which is not a valid
    // launch configuration; there is nothing to compute in that case.
    if (ntchw <= 0) {
      return;
    }

    // Launch configuration: ceil(ntchw / 512) blocks, capped at 8. The
    // kernel's grid-stride loop covers the remaining elements.
    const int kThreadsPerBlock = 512;
    const int kMaxBlocks = 8;
    int grid_dim = (ntchw + kThreadsPerBlock - 1) / kThreadsPerBlock;
    grid_dim = grid_dim > kMaxBlocks ? kMaxBlocks : grid_dim;

    KeTemporalShiftFw<T><<<grid_dim, kThreadsPerBlock, 0,
                           ctx.cuda_device_context().stream()>>>(
        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
  }
};

// CUDA implementation of the backward temporal_shift operator.
//
// Reads the gradient of "Out", whose first four dimensions are taken as
// (nt, c, h, w), together with the attributes "seg_num" (t) and
// "shift_ratio". Allocates the gradient of "X" with the same shape, fills it
// with zeros and then launches KeTemporalShiftBw on the CUDA stream of the
// execution context. The zero fill is required because the kernel only
// writes the positions that have an in-range temporal source.
//
// NOTE(review): the kernel indexing treats dimension 0 as n * seg_num; this
// presumably requires nt to be a multiple of seg_num and is expected to be
// validated elsewhere (e.g. shape inference) - confirm.
template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Same device check as the forward kernel.
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = output_grad->dims()[0];
    const int c = output_grad->dims()[1];
    const int h = output_grad->dims()[2];
    const int w = output_grad->dims()[3];

    // Strides used by the kernel to split a flat index into coordinates.
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
    // Positions the kernel never writes must read back as zero gradient.
    math::SetConstant<platform::CUDADeviceContext, T>()(
        ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
        static_cast<T>(0));

    // An empty tensor would yield a grid of 0 blocks, which is not a valid
    // launch configuration; there is nothing to compute in that case.
    if (ntchw <= 0) {
      return;
    }

    // Launch configuration: ceil(ntchw / 512) blocks, capped at 8. The
    // kernel's grid-stride loop covers the remaining elements.
    const int kThreadsPerBlock = 512;
    const int kMaxBlocks = 8;
    int grid_dim = (ntchw + kThreadsPerBlock - 1) / kThreadsPerBlock;
    grid_dim = grid_dim > kMaxBlocks ? kMaxBlocks : grid_dim;

    KeTemporalShiftBw<T><<<grid_dim, kThreadsPerBlock, 0,
                           ctx.cuda_device_context().stream()>>>(
        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
        shift_ratio);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA kernels of the temporal_shift operator and of its
// gradient operator, each instantiated for float and double element types.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
                        ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
                        ops::TemporalShiftGradOpCUDAKernel<float>,
                        ops::TemporalShiftGradOpCUDAKernel<double>);