/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <type_traits>

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Per-element gradient of out = x / y for real-valued T:
//   dx = dout / y
//   dy = -dout * out / y   (equals -dout * x / y^2, assuming out == x / y)
// All inputs arrive by value, so they are read before either output is
// written.
template <typename T>
static __device__ __forceinline__ void SimpleElemwiseDivGradElement(
    T y, T out, T dout, T* dx, T* dy) {
  *dx = dout / y;
  *dy = -dout * out / y;
}

// Complex overload (chosen over the generic template by partial ordering).
// Uses conjugates, matching the complex gradient convention of this op:
//   dx = dout / conj(y)
//   dy = -dout * conj(out / y)
template <typename T>
static __device__ __forceinline__ void SimpleElemwiseDivGradElement(
    platform::complex<T> y, platform::complex<T> out,
    platform::complex<T> dout, platform::complex<T>* dx,
    platform::complex<T>* dy) {
  platform::complex<T> y_conj(y.real, -y.imag);
  // Compute out / y once instead of twice.
  platform::complex<T> quot = out / y;
  platform::complex<T> quot_conj(quot.real, -quot.imag);
  *dx = dout / y_conj;
  *dy = -dout * quot_conj;
}

// Computes dx and dy for out = x / y over `size` contiguous elements.
//
// Launch layout: 1-D grid of 1-D blocks. The loop strides by
// blockDim.x * gridDim.x, so any grid size >= 1 covers all elements.
// No shared memory. Runs on whatever stream the caller launches it on.
//
// Indices are 64-bit so tensors with >= 2^31 elements do not overflow.
// `x` is unused (the gradient is expressed via `out`); it is kept for
// interface stability.
// Every input element is read before dx[col]/dy[col] are written, so
// writing dx in place over dout is safe.
// NOTE(review): presumably the framework may alias dx with dout — confirm.
template <typename T>
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
                                                       const T* out,
                                                       const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t col =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       col < size; col += stride) {
    SimpleElemwiseDivGradElement(y[col], out[col], dout[col], &dx[col],
                                 &dy[col]);
  }
}

// CUDA fast path for the elementwise_div gradient. It launches
// SimpleElemwiseDivGradCUDAKernel over x->numel() elements on the context's
// CUDA stream.
//
// Both dx and dy are allocated and written, so neither may be null.
// NOTE(review): assumes x, y, out and dout all hold x->numel() elements
// (no broadcasting). The caller is expected to guarantee this — confirm
// against ElementwiseDivGradKernel.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  const int64_t size = x->numel();
  // Allocate the outputs before the empty check so they exist even when
  // there is no work to do.
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
  // A grid with zero blocks is an invalid launch configuration.
  if (size == 0) {
    return;
  }

  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
  dim3 grid_size = dim3(static_cast<unsigned int>(
                            (size + ELEMENTWISE_BLOCK_SIZE - 1) /
                            ELEMENTWISE_BLOCK_SIZE),
                        1);
  auto stream =
      ctx.template device_context<plat::CUDADeviceContext>().stream();
  SimpleElemwiseDivGradCUDAKernel<T><<<grid_size, block_size, 0, stream>>>(
      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
      dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
    elementwise_div,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<float>>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  double>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  int64_t>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<float>>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad_grad,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::float16>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<double>>);