/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/nearest_neighbor_interp_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Threads per block for both interpolation kernel launches in this file.
// The kernels use grid-stride loops, so results do not depend on this value;
// it only has to be a legal block size (multiple of 32, <= 1024).
constexpr int kNearestInterpThreadsPerBlock = 512;

// Returns the flat index into `in` of the nearest source element for the flat
// output index `tid`.
//
// Both buffers are addressed as 2-D row-major arrays: `out` has rows of
// `output_w` elements and `in` has rows of `input_w` elements. Each row holds
// `num_channels` consecutive images of `output_w / num_channels` (resp.
// `input_w / num_channels`) elements. Within an image, pixels are row-major
// with row length `out_img_w` (resp. `in_img_w`).
//
// The source coordinate along each axis is round(ratio * output coordinate).
template <typename T>
__device__ __forceinline__ size_t NearestNeighborSrcIndex(
    int tid, size_t in_img_w, size_t input_w, size_t out_img_w,
    size_t output_w, size_t num_channels, T ratio_h, T ratio_w) {
  int out_id_h = tid / output_w;  // row (the caller passes the batch index)
  int out_id_w = tid % output_w;  // offset inside the row
  int in_img_size = input_w / num_channels;
  int out_img_size = output_w / num_channels;
  int channel_id = out_id_w / out_img_size;

  int out_img_idy = (out_id_w % out_img_size) / out_img_w;
  int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));

  // Equal to tid % out_img_w because output_w is a multiple of out_img_w.
  int out_img_idx = out_id_w % out_img_w;
  int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));

  return out_id_h * input_w + channel_id * in_img_size +
         in_img_idy * in_img_w + in_img_idx;
}

// Forward nearest-neighbor resize: out[i] = in[nearest source of i] for every
// i in [0, output_h * output_w).
//
// Launch: 1-D grid of 1-D blocks. Any grid/block size is valid because each
// thread walks the output with a grid-stride loop. No shared memory is used.
// in_img_h, input_h and out_img_h are accepted for interface symmetry but
// are not read.
template <typename T>
__global__ void KeNearestNeighborInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w) {
  int nthreads = output_h * output_w;
  int stride = blockDim.x * gridDim.x;
  for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < nthreads;
       tid += stride) {
    out[tid] = in[NearestNeighborSrcIndex(tid, in_img_w, input_w, out_img_w,
                                          output_w, num_channels, ratio_h,
                                          ratio_w)];
  }
}

// Backward nearest-neighbor resize: adds each out[i] into the source element
// it was read from in the forward pass.
//
// `in` must be zero-filled before launch. Several output positions can round
// to the same source element, so accumulation uses atomicAdd (float atomicAdd
// is available on all architectures; double would require SM60+).
// Launch requirements are the same as for KeNearestNeighborInterpFw.
template <typename T>
__global__ void KeNearestNeighborInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w) {
  int nthreads = output_h * output_w;
  int stride = blockDim.x * gridDim.x;
  for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < nthreads;
       tid += stride) {
    T* in_pos = &in[NearestNeighborSrcIndex(tid, in_img_w, input_w, out_img_w,
                                            output_w, num_channels, ratio_h,
                                            ratio_w)];
    atomicAdd(in_pos, out[tid]);
  }
}

// GPU forward kernel for the nearest_neighbor_interp op.
//
// Resizes the 4-D input X [n, c, in_h, in_w] to Out [n, c, out_h, out_w].
// out_h/out_w come from the attributes, or from the OutSize tensor when that
// input is present (OutSize[0] = out_h, OutSize[1] = out_w).
template <typename T>
class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* input_data = input->data<T>();

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // The sizes are read on the host right away, so the copy must have
      // completed before returning (OutSize may be a GPU tensor).
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    int n = input->dims()[0];
    int c = input->dims()[1];
    int in_h = input->dims()[2];
    int in_w = input->dims()[3];

    auto* output_data =
        output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());

    if (in_h == out_h && in_w == out_w) {
      // Identity resize. Both buffers are device memory, so the copy has to
      // be a device copy (a host memcpy on these pointers is invalid).
      framework::TensorCopy(*input, ctx.GetPlace(), ctx.device_context(),
                            output);
      return;
    }

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    // Source/destination scale so that the first and last pixels line up.
    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    int thread_num = n * out_chw;
    if (thread_num <= 0) {
      return;  // empty output; a 0-block launch would be an invalid config
    }
    int blocks = (thread_num + kNearestInterpThreadsPerBlock - 1) /
                 kNearestInterpThreadsPerBlock;

    KeNearestNeighborInterpFw<T><<<blocks, kNearestInterpThreadsPerBlock, 0,
                                   ctx.cuda_device_context().stream()>>>(
        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
        out_chw, c, ratio_h, ratio_w);
  }
};

// GPU backward kernel for the nearest_neighbor_interp op.
//
// Produces X@GRAD [n, c, in_h, in_w] from Out@GRAD [n, c, out_h, out_w] by
// scattering each output gradient back to its nearest source pixel.
template <typename T>
class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* output_grad_data = output_grad->data<T>();

    auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // Synchronous so the host can read the sizes immediately.
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    int n = input_grad->dims()[0];
    int c = input_grad->dims()[1];
    int in_h = input_grad->dims()[2];
    int in_w = input_grad->dims()[3];

    if (in_h == out_h && in_w == out_w) {
      // Identity resize: the gradient passes through unchanged. Copy the
      // tensor *data* on the device (not the host-side Tensor objects).
      framework::TensorCopy(*output_grad, ctx.GetPlace(), device_ctx,
                            input_grad);
      return;
    }

    // The backward kernel accumulates, so start from zero.
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, input_grad, static_cast<T>(0.0));

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    int thread_num = n * out_chw;
    if (thread_num <= 0) {
      return;  // no output gradient to scatter; X@GRAD stays zero
    }
    int blocks = (thread_num + kNearestInterpThreadsPerBlock - 1) /
                 kNearestInterpThreadsPerBlock;

    KeNearestNeighborInterpBw<T><<<blocks, kNearestInterpThreadsPerBlock, 0,
                                   device_ctx.stream()>>>(
        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
        n, out_chw, c, ratio_h, ratio_w);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp,
                        ops::NearestNeighborInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad,
                        ops::NearestNeighborInterpGradOpCUDAKernel<float>);