/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

// NOTE(review): every `<...>` template argument list in this file appears to
// have been stripped (e.g. `template static __global__`,
// `reinterpret_cast(dout)`, `ctx.Attr("axis")`, `dout->data()`, and the kernel
// launch configuration `<<().stream()>>>`), so the code below does not compile
// as shown. Restore the argument lists from version control before building;
// the comments below describe only what the remaining tokens show.

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Copies the first `size` elements of `dout` into both `dx` and `dy`.
//
// Parameters:
//   dout     - source buffer holding `size` elements of type T.
//   size     - total number of T elements to copy.
//   vec_size - number of T elements moved per float4 word. The bulk of the
//              copy moves whole float4 words, so the index arithmetic is only
//              consistent when vec_size * sizeof(T) == sizeof(float4).
//   dx, dy   - destination buffers with room for `size` elements each.
//
// The first `size / vec_size` float4 words are copied with a grid-stride loop
// (start at the global thread index, step by gridDim.x * blockDim.x). The
// trailing `size % vec_size` scalar elements are copied serially by the single
// thread whose global index equals `size / vec_size`; the launch must therefore
// provide at least `size / vec_size + 1` threads whenever that tail is
// non-empty, or the tail is never written.
//
// NOTE(review): viewing dout/dx/dy as float4 assumes the buffers are 16-byte
// aligned -- presumably guaranteed by the tensor allocator; confirm.
template static __global__ void SimpleElemwiseAddGradCUDAKernel(
    const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  // Number of whole float4 words, and the scalar elements left over after them.
  int loop = size / vec_size;
  int remainder = size % vec_size;
  const float4* dout_vec = reinterpret_cast(dout);
  float4* dx_vec = reinterpret_cast(dx);
  float4* dy_vec = reinterpret_cast(dy);
  float4 tmp_loop;
  for (int i = tid; i < loop; i += stride) {
    tmp_loop = dout_vec[i];
    dx_vec[i] = tmp_loop;
    dy_vec[i] = tmp_loop;
  }
  // Tail: exactly one thread (tid == loop) copies the last `remainder`
  // elements, walking idx from size - remainder up to size - 1.
  if (tid == loop && remainder != 0) {
    T tmp_rem;
    while (remainder) {
      int idx = size - remainder;
      remainder--;
      tmp_rem = dout[idx];
      dx[idx] = tmp_rem;
      dy[idx] = tmp_rem;
    }
  }
}

// General CUDA-path gradient of elementwise add: fills dx and dy from dout.
// This overload is selected through std::enable_if / std::is_same (its type
// arguments are among the stripped template lists). Either output may be
// nullptr, in which case it is skipped.
//
// For each non-null output:
//   * if its dims equal dout's dims, dout is copied into it with
//     framework::TensorCopy, unless the output already points at dout's data;
//   * otherwise dout is reduced over the axes returned by
//     GetReduceDim(<matching input>->dims(), out->dims(), axis) via
//     TensorReduceFunctorImpl with kps::IdentityFunctor on the context's CUDA
//     stream. The reduction operator itself is in a stripped template argument
//     -- presumably a sum; TODO confirm.
// `axis` is read from the op's "axis" attribute.
template typename std::enable_if<
    std::is_same::value>::type
default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr("axis");
  auto* dout_data = dout->data();

  // dx
  if (dx != nullptr) {
    auto* dx_data = dx->mutable_data(ctx.GetPlace());
    if (dx->dims() == dout->dims()) {
      // Same shape: a plain copy, skipped when dx already is dout's buffer.
      if (dx_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context(), dx);
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong.
      // Reallocating dx with x's dims keeps the reduction below from writing
      // into dout, which the dy branch still reads afterwards.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data(x->dims(), ctx.GetPlace());
      }
      std::vector reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl>(
          *dout, dx, kps::IdentityFunctor(), reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    auto* dy_data = dy->mutable_data(ctx.GetPlace());
    if (dy->dims() == dout->dims()) {
      // Same shape: a plain copy, skipped when dy already is dout's buffer.
      if (dy_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context(), dy);
      }
    } else {
      // NOTE(review): unlike the dx branch there is no
      // IsSharedBufferWith(*dout) guard here; if the inplace strategy can also
      // alias dy with dout, this reduction would read and write the same
      // buffer -- confirm that case cannot occur.
      std::vector reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl>(
          *dout, dy, kps::IdentityFunctor(), reduce_dims, stream);
    }
  }
}

// CUDA-path gradient of elementwise add that fills both dx and dy with copies
// of dout. Both dx and dy are dereferenced unconditionally, so callers must
// pass non-null outputs. Dispatches on whether dx/dy already point at dout's
// data:
//   * only dx aliases dout -> copy dout into dy;
//   * only dy aliases dout -> copy dout into dx;
//   * neither aliases dout -> one SimpleElemwiseAddGradCUDAKernel launch
//                             copies dout into both;
//   * both alias dout      -> nothing to do.
// NOTE(review): the kernel path copies x->numel() elements into dx and dy, so
// it assumes dx, dy and dout all hold x->numel() elements (presumably the
// caller only takes this path when shapes match -- confirm). That count is
// also passed to the kernel's `int size` parameter; if numel() is 64-bit,
// tensors with more than INT_MAX elements would be truncated.
template typename std::enable_if<
    std::is_same::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  auto* dx_data = dx->mutable_data(ctx.GetPlace());
  auto* dy_data = dy->mutable_data(ctx.GetPlace());
  auto* dout_data = dout->data();
  if (dx_data == dout_data && dy_data != dout_data) {
    VLOG(4) << "Special case when dx_data is the same as dout_data, "
               "only need copy dout to dy";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context(), dy);
  } else if (dx_data != dout_data && dy_data == dout_data) {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "only need copy dout to dx";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context(), dx);
  } else if (dx_data != dout_data && dy_data != dout_data) {
    auto size = x->numel();
    // How many T fit in one float4, clamped to at least 1.
    int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1);
    // One thread per float4 word, rounding up so that a non-empty scalar tail
    // gets the extra thread (tid == size / vec_size) the kernel expects,
    // grouped into 1-D blocks of PREDEFINED_BLOCK_SIZE threads.
    dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
    dim3 grid_size =
        dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
                 PREDEFINED_BLOCK_SIZE,
             1);
    SimpleElemwiseAddGradCUDAKernel<
        T><<().stream()>>>(
        dout->data(), size, vec_size, dx->mutable_data(ctx.GetPlace()),
        dy->mutable_data(ctx.GetPlace()));
  } else {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "and dx_data is the same as dout_data, do not need "
               "any operator";
  }
}

}  // namespace operators
}  // namespace paddle

// CUDA kernel registrations. Each op below is registered with seven kernel
// instantiations; their template argument lists are among those stripped from
// this file. Note that grad_add reuses ops::ElementwiseAddKernel.
REGISTER_OP_CUDA_KERNEL(
    elementwise_add, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel,
    ops::ElementwiseAddKernel, ops::ElementwiseAddKernel,
    ops::ElementwiseAddKernel, ops::ElementwiseAddKernel>,
    ops::ElementwiseAddKernel>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad, ops::ElementwiseAddGradKernel,
    ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel,
    ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel,
    ops::ElementwiseAddGradKernel>, ops::ElementwiseAddGradKernel>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad_grad, ops::ElementwiseAddDoubleGradKernel,
    ops::ElementwiseAddDoubleGradKernel, ops::ElementwiseAddDoubleGradKernel,
    ops::ElementwiseAddDoubleGradKernel, ops::ElementwiseAddDoubleGradKernel,
    ops::ElementwiseAddDoubleGradKernel>,
    ops::ElementwiseAddDoubleGradKernel>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_triple_grad, ops::ElementwiseAddTripleGradKernel,
    ops::ElementwiseAddTripleGradKernel, ops::ElementwiseAddTripleGradKernel,
    ops::ElementwiseAddTripleGradKernel, ops::ElementwiseAddTripleGradKernel,
    ops::ElementwiseAddTripleGradKernel>,
    ops::ElementwiseAddTripleGradKernel>);
REGISTER_OP_CUDA_KERNEL(
    grad_add, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel,
    ops::ElementwiseAddKernel, ops::ElementwiseAddKernel,
    ops::ElementwiseAddKernel, ops::ElementwiseAddKernel>,
    ops::ElementwiseAddKernel>);