/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Threads per block used to launch both temporal-shift kernels.
static constexpr int kTemporalShiftThreads = 512;

// Upper bound on the number of blocks launched. Both kernels advance their
// index by blockDim.x * gridDim.x, so any grid size gives correct results.
// NOTE(review): 8 blocks x 512 threads is far below what a GPU can keep
// resident; consider deriving this bound from the device's SM count (keeping
// the int index in the kernels overflow-safe when doing so).
static constexpr int kTemporalShiftMaxBlocks = 8;

// Returns the number of blocks needed to cover `num` elements with
// kTemporalShiftThreads threads per block, clamped to kTemporalShiftMaxBlocks.
static inline int TemporalShiftGridDim(const int num) {
  const int blocks = (num + kTemporalShiftThreads - 1) / kTemporalShiftThreads;
  return blocks > kTemporalShiftMaxBlocks ? kTemporalShiftMaxBlocks : blocks;
}

// Maps time step `it` of channel `ic` (out of `c` channels) to the time step
// it is shifted from: channels [0, c/4) take it - 1, channels [c/4, c/2) take
// it + 1, and the remaining channels are not shifted. The result may fall
// outside [0, t); callers must range-check it.
static __device__ __forceinline__ int TemporalShiftSrcTime(const int it,
                                                           const int ic,
                                                           const int c) {
  if (ic < c / 4) return it - 1;
  if (ic < c / 2) return it + 1;
  return it;
}

// Forward temporal shift.
//
// The flat index of each element is decomposed as (in, it, ic, ih, iw), with
// the leading dimension of size nt treated as (nt / t) segments of t time
// steps each. Every output element copies the input element with the same
// (in, ic, ih, iw) at time TemporalShiftSrcTime(it, ic, c), or is set to 0
// when that time step lies outside [0, t).
//
// Arguments: ntchw = total element count, tchw = t*c*h*w, chw = c*h*w,
// hw = h*w, w = width, t = time steps per segment, c = channel count.
// Launch: 1-D grid of 1-D blocks of any size; each thread starts at its global
// index and advances by blockDim.x * gridDim.x until it reaches ntchw.
// `input` and `output` must not alias.
template <typename T>
__global__ void KeTemporalShiftFw(const T* __restrict__ input,
                                  T* __restrict__ output, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c) {
  const int stride = blockDim.x * gridDim.x;
  for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ntchw;
       tid += stride) {
    const int in = tid / tchw;
    const int it = (tid % tchw) / chw;
    const int ic = (tid % chw) / hw;
    const int ih = (tid % hw) / w;
    const int iw = tid % w;

    const int src_it = TemporalShiftSrcTime(it, ic, c);
    if (src_it < 0 || src_it >= t) {
      // The source frame lies outside this segment: zero padding.
      output[tid] = static_cast<T>(0);
    } else {
      const int src_idx =
          GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      output[tid] = input[src_idx];
    }
  }
}

// Backward of KeTemporalShiftFw.
//
// For every output-gradient element whose forward source time step lies in
// [0, t), writes it to that source position of input_grad. Output elements
// that were zero padding in the forward pass contribute nothing. For a fixed
// (in, ic, ih, iw), distinct `it` map to distinct source time steps, so no two
// threads write the same input_grad element. input_grad must be zero-filled
// before launch: elements that no output position read from are never written.
//
// Arguments and launch layout are the same as KeTemporalShiftFw.
// `output_grad` and `input_grad` must not alias.
template <typename T>
__global__ void KeTemporalShiftBw(const T* __restrict__ output_grad,
                                  T* __restrict__ input_grad, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c) {
  const int stride = blockDim.x * gridDim.x;
  for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ntchw;
       tid += stride) {
    const int in = tid / tchw;
    const int it = (tid % tchw) / chw;
    const int ic = (tid % chw) / hw;
    const int ih = (tid % hw) / w;
    const int iw = tid % w;

    const int src_it = TemporalShiftSrcTime(it, ic, c);
    if (src_it >= 0 && src_it < t) {
      const int src_idx =
          GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      input_grad[src_idx] = output_grad[tid];
    }
  }
}

// GPU forward kernel of the temporal_shift op: reads X (4-D, indexed as
// nt x c x h x w) and the "seg_num" attribute, and writes Out with the same
// shape.
template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    const int t = ctx.Attr<int>("seg_num");

    const int nt = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    // Nothing to compute for an empty tensor, and a launch with zero blocks
    // is an invalid CUDA launch configuration.
    if (ntchw == 0) return;

    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();
    const int grid_dim = TemporalShiftGridDim(ntchw);
    KeTemporalShiftFw<T><<<grid_dim, kTemporalShiftThreads, 0, stream>>>(
        input_data, output_data, ntchw, tchw, chw, hw, w, t, c);
  }
};

// GPU backward kernel of the temporal_shift op: reads Out@GRAD and the
// "seg_num" attribute, and writes X@GRAD with the same shape.
template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    const int t = ctx.Attr<int>("seg_num");

    const int nt = output_grad->dims()[0];
    const int c = output_grad->dims()[1];
    const int h = output_grad->dims()[2];
    const int w = output_grad->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    const auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    // The backward kernel only scatters into positions some output element
    // read from; every other input position must end up with zero gradient.
    math::SetConstant<platform::CUDADeviceContext, T>()(dev_ctx, input_grad,
                                                         static_cast<T>(0));

    // Nothing to compute for an empty tensor, and a launch with zero blocks
    // is an invalid CUDA launch configuration.
    if (ntchw == 0) return;

    auto stream = dev_ctx.stream();
    const int grid_dim = TemporalShiftGridDim(ntchw);
    KeTemporalShiftBw<T><<<grid_dim, kTemporalShiftThreads, 0, stream>>>(
        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
                        ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
                        ops::TemporalShiftGradOpCUDAKernel<float>,
                        ops::TemporalShiftGradOpCUDAKernel<double>);