diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.kps b/paddle/fluid/operators/elementwise/elementwise_add_op.kps index d6e0749318e901947b46b4b1d6ff8bbdb16bef36..3b7457d72e15d733a45bc10ea433db1937dbac89 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.kps +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.kps @@ -39,7 +39,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #else #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" -#include "paddle/phi/kernels/gpu/elementwise.h" +#include "paddle/phi/kernels/gpu/elementwise_grad.h" #endif namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 418779c32e8bc216be1532bf714bc21d91c452aa..102127e6ffe4ea60b8305c718e645a3695557ae4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -16,9 +16,6 @@ #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -// only can include the headers in paddle/top/api dirs -#include "paddle/phi/kernels/gpu/elementwise.h" - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index a1a7f8310986616d0a9f7db572ed31ca44399027..61862aa9f87408048c5d31a13c0be8a013046902 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -31,6 +31,7 @@ limitations under the License. */ #include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/cpu/elementwise_grad.h" #if defined(__NVCC__) || defined(__HIPCC__) #ifdef __NVCC__ @@ -133,7 +134,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, inline framework::DDim trim_trailing_singular_dims( const framework::DDim &dims) { - return phi::funcs::trim_trailing_singular_dims(dims); + return phi::funcs::TrimTrailingSingularDims(dims); } template ( dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } else { - phi::ElemwiseGradComputeWithBroadcast( + phi::funcs::ElemwiseGradComputeWithBroadcast( dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } } @@ -173,19 +174,9 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, const framework::Tensor *y, int axis, Functor func, framework::Tensor *z) { z->mutable_data(ctx.GetPlace()); - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - const auto &dev_ctx = - ctx.template device_context(); - phi::ElementwiseCompute(dev_ctx, *x, *y, axis, func, - z); - -#endif - return; - } - const auto &dev_ctx = - ctx.template device_context(); - phi::ElementwiseCompute(dev_ctx, *x, *y, axis, func, z); + const auto &dev_ctx = ctx.template device_context(); + phi::funcs::ElementwiseCompute(dev_ctx, *x, *y, axis, + func, z); } // FusedElemwiseAndAct @@ -443,8 +434,8 @@ void FusedElemwiseAndActComputeWithBroadcast( axis = (y_dim.size() == 0) ? 
x_dim.size() : axis; int pre, n, post, is_run_common_broadcast; - phi::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, - &is_run_common_broadcast); + phi::funcs::GetMidDims(x_dim, y_dim, axis, &pre, &n, &post, + &is_run_common_broadcast); if (post == 1) { int h = pre; int w = n; @@ -991,8 +982,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast( axis = (y_dim.size() == 0) ? x_dim.size() : axis; int pre, n, post, is_run_common_broadcast; - phi::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, - &is_run_common_broadcast); + phi::funcs::GetMidDims(x_dim, y_dim, axis, &pre, &n, &post, + &is_run_common_broadcast); const T *x_data = nullptr; const T *y_data = nullptr; if (x->IsInitialized()) x_data = x->data(); diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 7d7bb4f26fcf42ec63cd1fab7ec2667a03c8ba4c..f49e2ab4e173efbd2cb8a33ec3e7471faff11154 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -19,7 +19,7 @@ limitations under the License. */ // only can include the headers in paddle/top/api dirs #include "paddle/phi/api/lib/utils/tensor_utils.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/viterbi_decode_op.h b/paddle/fluid/operators/viterbi_decode_op.h index 8f01a0c36043b7a12f77d09c4aab0b70cdc0eccb..bf12a03e7b4dc13d0e1da2be96b4bbd35efb31af 100644 --- a/paddle/fluid/operators/viterbi_decode_op.h +++ b/paddle/fluid/operators/viterbi_decode_op.h @@ -151,12 +151,12 @@ struct GetInputIndex { const std::vector& output_strides, int output_idx, int* index_array, int* lhs_idx, int* rhs_idx) { int out_dims_size = output_strides.size(); - *lhs_idx = - phi::GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array); - *rhs_idx = - phi::GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array); - phi::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, - index_array); + *lhs_idx = phi::funcs::GetElementwiseIndex(lhs_dims.data(), out_dims_size, + index_array); + *rhs_idx = phi::funcs::GetElementwiseIndex(rhs_dims.data(), out_dims_size, + index_array); + phi::funcs::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, + index_array); } }; diff --git a/paddle/phi/kernels/cpu/elementwise.h b/paddle/phi/kernels/cpu/elementwise.h index 28bf5ab743f6d5d0608fe65c00d5a0de2af3415b..0f67df661136dc659c28da3855b661e4a7df2af0 100644 --- a/paddle/phi/kernels/cpu/elementwise.h +++ b/paddle/phi/kernels/cpu/elementwise.h @@ -16,8 +16,8 @@ limitations under the License. 
*/ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" @@ -189,250 +189,6 @@ struct SameDimsMultiplyFunctor< } }; -inline void UpdateElementwiseIndexArray(const int* out_dims_array, - const int max_dim, - int* index_array) { - for (int i = max_dim - 1; i >= 0; --i) { - ++index_array[i]; - if (index_array[i] >= out_dims_array[i]) { - index_array[i] -= out_dims_array[i]; - } else { - break; - } - } -} - -inline int GetElementwiseIndex(const int* x_dims_array, - const int max_dim, - const int* index_array) { - int index_ = 0; - for (int i = 0; i < max_dim; i++) { - if (x_dims_array[i] > 1) { - index_ = index_ * x_dims_array[i] + index_array[i]; - } - } - return index_; -} - -template -void CommonGradBroadcastCPU(const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - DenseTensor* dx, - DenseTensor* dy, - int* x_dims_array, - int* y_dims_array, - int* out_dims_array, - int max_dim, - const CPUContext& ctx, - DX_OP dx_op, - DY_OP dy_op) { - std::vector index_array(max_dim, 0); - const T* x_data = x.data(); - const T* y_data = y.data(); - const Tout* out_data = out.data(); - const Tout* dout_data = dout.data(); - T* dx_data = dx == nullptr ? nullptr : ctx.Alloc(dx); - T* dy_data = dy == nullptr ? nullptr : ctx.Alloc(dy); - if (dx_data != nullptr) { - memset(dx_data, 0, dx->numel() * sizeof(T)); - } - if (dy_data != nullptr) { - memset(dy_data, 0, dy->numel() * sizeof(T)); - } - const int out_size = std::accumulate( - out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); - int x_index, y_index; - for (int out_index = 0; out_index < out_size; ++out_index) { - x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); - y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); - if (dx_data != nullptr) { - dx_data[x_index] += dx_op(x_data[x_index], - y_data[y_index], - out_data[out_index], - dout_data[out_index]); - } - if (dy_data != nullptr) { - dy_data[y_index] += dy_op(x_data[x_index], - y_data[y_index], - out_data[out_index], - dout_data[out_index]); - } - - UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); - } -} - -template -void CommonForwardBroadcastCPU(const DenseTensor& x, - const DenseTensor& y, - DenseTensor* z, - int* x_dims_array, - int* y_dims_array, - int* out_dims_array, - int max_dim, - const CPUContext& ctx, - Functor func, - const bool is_xsize_larger = true) { - std::vector index_array(max_dim, 0); - const T* x_data = x.data(); - const T* y_data = y.data(); - PADDLE_ENFORCE_NOT_NULL( - x_data, phi::errors::InvalidArgument("The input X should not be empty.")); - PADDLE_ENFORCE_NOT_NULL( - y_data, phi::errors::InvalidArgument("The input Y should not be empty.")); - OutType* out_data = ctx.Alloc(z); - - const int out_size = std::accumulate( - out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); - int x_index, y_index; - for (int out_index = 0; out_index < out_size; ++out_index) { - x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); - y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); - if (is_xsize_larger) { - out_data[out_index] = func(x_data[x_index], y_data[y_index]); - } else { - out_data[out_index] = 
func(y_data[y_index], x_data[x_index]); - } - - UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); - } -} - -template -void CommonElementwiseBroadcastForward(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* z, - const DDim& x_dims, - const DDim& y_dims, - Functor func, - int axis, - const bool is_xsize_larger = true) { - int max_dim = (std::max)(x_dims.size(), y_dims.size()); - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - PADDLE_ENFORCE_GE( - axis, - 0, - phi::errors::InvalidArgument( - "Axis should be great than or equal to 0, but received axis is %d.", - axis)); - PADDLE_ENFORCE_LT(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - funcs::GetBroadcastDimsArrays(x_dims, - y_dims, - x_dims_array.data(), - y_dims_array.data(), - out_dims_array.data(), - max_dim, - axis); - - CommonForwardBroadcastCPU(x, - y, - z, - x_dims_array.data(), - y_dims_array.data(), - out_dims_array.data(), - max_dim, - dev_ctx, - func, - is_xsize_larger); -} - -// It is a common CPU implementation to compute binary calculation with the -// support of broadcast. Note: -// 1. CPU implementation cannot support the case when x needs broadcast, thus -// this function need to be called with XxxFunctor and XxxInverseFunctor, -// like AddFunctor and InverseAddFunctor. -// 2. The corresponding GPU implementation supports all the broadcast cases, -// thus there is no need to define and call with XxxInverseFunctor. -// TODO(liuyiqun): optimize the CPU implementation to support all broadcast -// cases and avoid the need of XxxInverseFunctor. -template -void ElementwiseCompute(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int axis, - Functor func, - DenseTensor* z) { - dev_ctx.Alloc(z); - auto x_dims = x.dims(); - auto y_dims = y.dims(); - bool is_xsize_larger = true; - int max_dim = x_dims.size(); - if (x_dims.size() < y_dims.size()) { - is_xsize_larger = false; - max_dim = y_dims.size(); - } - funcs::TransformFunctor functor( - x, y, z, dev_ctx, func, is_xsize_larger); - if (x_dims == y_dims) { - functor.Run(); - return; - } - - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - PADDLE_ENFORCE_GE( - axis, - 0, - phi::errors::InvalidArgument( - "Axis should be great than or equal to 0, but received axis is %d.", - axis)); - PADDLE_ENFORCE_LT(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); - - int pre, n, post, is_run_common_broadcast, axis_trim = 0; - if (is_xsize_larger) { - auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims); - axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; - funcs::get_mid_dims(x_dims, - y_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); - } else { - auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims); - axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; - funcs::get_mid_dims(y_dims, - x_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); - } - // special case for common implementation. 
- // case 1: x=[2,3,1,5], y=[2,1,4,1] - // case 2: x=[2,3,4], y=[1,1,4] - if (is_run_common_broadcast == 1) { - CommonElementwiseBroadcastForward( - dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); - return; - } - - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } -} - template struct SameDimsElementwiseCompute { void operator()(const CPUContext& dev_ctx, @@ -443,377 +199,4 @@ struct SameDimsElementwiseCompute { } }; -// BACKWARD CODE - -template -static void ElemwiseGradBroadcast1CPU(const T* x, - const T* y, - const Tout* out, - const Tout* dout, - int h, - int w, - bool is_xsize_larger, - DX_OP dx_op, - DY_OP dy_op, - T* dx, - T* dy) { - if (is_xsize_larger) { - for (int i = 0; i < h; ++i) { - for (int j = 0; j < w; ++j) { - int x_offset = i * w + j; - if (dx != nullptr) { - dx[x_offset] = - dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy != nullptr) { - T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - if (i == 0) { - dy[j] = tmp; - } else { - dy[j] += tmp; - } - } - } - } - } else { // x.dims < y.dims, broadcast for x. - for (int i = 0; i < h; ++i) { - for (int j = 0; j < w; ++j) { - int y_offset = i * w + j; - if (dy != nullptr) { - dy[y_offset] = - dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx != nullptr) { - T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - if (i == 0) { - dx[j] = tmp; - } else { - dx[j] += tmp; - } - } - } - } - } -} - -template -static void ElemwiseGradBroadcast2CPU(const T* x, - const T* y, - const Tout* out, - const Tout* dout, - int pre, - int n, - int post, - bool is_xsize_larger, - DX_OP dx_op, - DY_OP dy_op, - T* dx, - T* dy) { - if (is_xsize_larger) { - for (int i = 0; i < pre; ++i) { - for (int j = 0; j < n; ++j) { - for (int k = 0; k < post; ++k) { - int x_offset = i * n * post + j * post + k; - if (dx != nullptr) { - dx[x_offset] = - dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy != nullptr) { - T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - if (i == 0 && k == 0) { - dy[j] = tmp; - } else { - dy[j] += tmp; - } - } - } - } - } - } else { // x.dims < y.dims, broadcast for x. - for (int i = 0; i < pre; ++i) { - for (int j = 0; j < n; ++j) { - for (int k = 0; k < post; ++k) { - int y_offset = i * n * post + j * post + k; - if (dy != nullptr) { - dy[y_offset] = - dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx != nullptr) { - T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - if (i == 0 && k == 0) { - dx[j] = tmp; - } else { - dx[j] += tmp; - } - } - } - } - } - } -} - -template -void CommonElementwiseBroadcastBackward(const CPUContext& ctx, - const DDim& x_dims, - const DDim& y_dims, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - int axis, - DenseTensor* dx, - DenseTensor* dy, - DX_OP dx_op, - DY_OP dy_op) { - int max_dim = std::max(x_dims.size(), y_dims.size()); - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - funcs::GetBroadcastDimsArrays(x_dims, - y_dims, - x_dims_array.data(), - y_dims_array.data(), - out_dims_array.data(), - max_dim, - axis); - // for inplace strategy. memset will make dx and dout clear and get wrong - // result. 
- if (dx && dx->IsSharedBufferWith(dout)) { - dx->clear(); - dx->mutable_data(x_dims, ctx.GetPlace()); - } - - VLOG(3) << "CommonElementwiseBroadcastBackward xdims:" - << phi::make_ddim(x_dims_array) - << " ydim:" << phi::make_ddim(y_dims_array); - - CommonGradBroadcastCPU(x, - y, - out, - dout, - dx, - dy, - x_dims_array.data(), - y_dims_array.data(), - out_dims_array.data(), - max_dim, - ctx, - dx_op, - dy_op); -} - -template -void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx, - const DDim& x_dims, - const DDim& y_dims, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - int axis, - DenseTensor* dx, - DenseTensor* dy, - DX_OP dx_op, - DY_OP dy_op) { - bool is_xsize_larger = true; - - int max_dim = x_dims.size(); - if (x_dims.size() < y_dims.size()) { - is_xsize_larger = false; - max_dim = y_dims.size(); - } - - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - PADDLE_ENFORCE_GE( - axis, - 0, - phi::errors::InvalidArgument( - "Axis should be great than or equal to 0, but received axis is %d.", - axis)); - PADDLE_ENFORCE_LT(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); - - int pre, n, post, is_run_common_broadcast, axis_trim = 0; - if (is_xsize_larger) { - auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims); - axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; - funcs::get_mid_dims(x_dims, - y_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); - } else { - auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims); - axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; - funcs::get_mid_dims(y_dims, - x_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); - } - // special case for common backward implementation. - if (is_run_common_broadcast) { - CommonElementwiseBroadcastBackward( - ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); - return; - } - if (post == 1) { - ElemwiseGradBroadcast1CPU(x.data(), - y.data(), - out.data(), - dout.data(), - pre, - n, - is_xsize_larger, - dx_op, - dy_op, - dx == nullptr ? nullptr : ctx.Alloc(dx), - dy == nullptr ? nullptr : ctx.Alloc(dy)); - } else { - ElemwiseGradBroadcast2CPU(x.data(), - y.data(), - out.data(), - dout.data(), - pre, - n, - post, - is_xsize_larger, - dx_op, - dy_op, - dx == nullptr ? nullptr : ctx.Alloc(dx), - dy == nullptr ? nullptr : ctx.Alloc(dy)); - } -} - -// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub. -// explicit gradient can cut off X, Y, Out from gradient op -// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse -// elementwise code. 
-template -void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - int axis, - DenseTensor* dx, - DenseTensor* dy, - DX_OP dx_op, - DY_OP dy_op) { - const DDim& x_dim = x.dims(); - const DDim& y_dim = y.dims(); - if (x.dims() == y.dims()) { - phi::funcs::ElemwiseGradComputeNoBroadcast( - dev_ctx, - x_dim, - y_dim, - dout, - dout, - out, - dout, - axis, - dx, - dy, - dx_op, - dy_op); - } else { - ElemwiseGradComputeWithBroadcast(dev_ctx, - x_dim, - y_dim, - dout, - dout, - out, - dout, - axis, - dx, - dy, - dx_op, - dy_op); - } -} - -/* -****************************** - Add Grad -****************************** -*/ -template -struct IdentityGrad { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } -}; - -template -typename std::enable_if::value>::type -elementwise_add_grad(const CPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - DenseTensor* dx, - DenseTensor* dy, - int axis = -1) { - auto blas = phi::funcs::GetBlas(ctx); - if (dx) { - blas.VCOPY( - dout.numel(), dout.data(), dx->mutable_data(ctx.GetPlace())); - } - - if (dy) { - blas.VCOPY( - dout.numel(), dout.data(), dy->mutable_data(ctx.GetPlace())); - } -} - -template -typename std::enable_if::value>::type -elementwise_add_grad(const CPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - DenseTensor* dx, - DenseTensor* dy, - int axis = -1) { - ElemwiseExplicitGradCompute, IdentityGrad>( - ctx, x, y, out, dout, axis, dx, dy, IdentityGrad(), IdentityGrad()); -} - -/* -****************************** - Sub Grad -****************************** -*/ - -template -struct SubGradDX { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } -}; - -template -struct SubGradDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; } -}; - -template -void elementwise_sub_grad(const CPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out, - const DenseTensor& dout, - DenseTensor* dx, - DenseTensor* dy, - int axis = -1) { - ElemwiseExplicitGradCompute, SubGradDY>( - ctx, x, y, out, dout, axis, dx, dy, SubGradDX(), SubGradDY()); -} - } // namespace phi diff --git a/paddle/phi/kernels/cpu/elementwise_grad.h b/paddle/phi/kernels/cpu/elementwise_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..92587566eb87591c5c35572fcba8a39af8445f5a --- /dev/null +++ b/paddle/phi/kernels/cpu/elementwise_grad.h @@ -0,0 +1,146 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/elementwise_grad_base.h" + +namespace phi { + +// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub. +// explicit gradient can cut off X, Y, Out from gradient op +// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse +// elementwise code. +template +void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DX_OP dx_op, + DY_OP dy_op) { + const DDim& x_dim = x.dims(); + const DDim& y_dim = y.dims(); + if (x.dims() == y.dims()) { + funcs::ElemwiseGradComputeNoBroadcast(dev_ctx, + x_dim, + y_dim, + dout, + dout, + out, + dout, + axis, + dx, + dy, + dx_op, + dy_op); + } else { + funcs::ElemwiseGradComputeWithBroadcast(dev_ctx, + x_dim, + y_dim, + dout, + dout, + out, + dout, + axis, + dx, + dy, + dx_op, + dy_op); + } +} + +/* +****************************** + Add Grad +****************************** +*/ +template +struct IdentityGrad { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } +}; + +template +typename std::enable_if::value>::type +ElementwiseAddGrad(const CPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + auto blas = phi::funcs::GetBlas(ctx); + if (dx) { + blas.VCOPY( + dout.numel(), dout.data(), dx->mutable_data(ctx.GetPlace())); + } + + if (dy) { + blas.VCOPY( + dout.numel(), dout.data(), dy->mutable_data(ctx.GetPlace())); + } +} + +template +typename std::enable_if::value>::type +ElementwiseAddGrad(const CPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + ElemwiseExplicitGradCompute, IdentityGrad>( + ctx, x, y, out, dout, axis, dx, dy, IdentityGrad(), IdentityGrad()); +} + +/* +****************************** + Sub Grad +****************************** +*/ + +template +struct SubGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } +}; + +template +struct SubGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; } +}; + +template +void ElementwiseSubGrad(const CPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + ElemwiseExplicitGradCompute, SubGradDY>( + ctx, x, y, out, dout, axis, dx, dy, SubGradDX(), SubGradDY()); +} + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index c878e8133ffc0dc0c5e4992b315af48bc6cdaf03..e48ee805959088aa8a6a3da7fcb8fa02b642c1c8 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -17,7 +17,8 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/copy_kernel.h" -#include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/cpu/elementwise_grad.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include 
"paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" @@ -33,7 +34,7 @@ void AddGradFunc(const CPUContext& dev_ctx, DenseTensor* dy, int axis = -1) { if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { - elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy); + ElementwiseAddGrad(dev_ctx, x, y, out, dout, dx, dy); } else { ElemwiseExplicitGradCompute, IdentityGrad>( dev_ctx, @@ -68,15 +69,7 @@ void AddDoubleGradKernel(const Context& dev_ctx, const DenseTensor& dout, int axis, DenseTensor* ddout) { - phi::AddDoubleGradImpl(dev_ctx, - y, - ddx, - ddy, - dout, - axis, - ddout, - ElementwiseCompute, T>, - ElementwiseCompute, T>); + phi::AddDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); } template @@ -101,7 +94,7 @@ void SubtractGradKernel(const Context& dev_ctx, DenseTensor* dy) { // skip out auto* out = &dout; - elementwise_sub_grad(dev_ctx, x, y, *out, dout, dx, dy, axis); + ElementwiseSubGrad(dev_ctx, x, y, *out, dout, dx, dy, axis); } template @@ -112,15 +105,7 @@ void SubtractDoubleGradKernel(const Context& dev_ctx, const DenseTensor& dout, int axis, DenseTensor* ddout) { - phi::SubtractDoubleGradImpl( - dev_ctx, - y, - ddx, - ddy, - dout, - axis, - ddout, - ElementwiseCompute, T>); + phi::SubtractDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/logical_kernel.cc b/paddle/phi/kernels/cpu/logical_kernel.cc index 3d179e1e75f4fa98057f32737f09025ce1d6b2fb..a0747b128e53899b77767298cab4fa37f31e495a 100644 --- a/paddle/phi/kernels/cpu/logical_kernel.cc +++ b/paddle/phi/kernels/cpu/logical_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/logical_functor.h" // See Note [ Why still include the fluid headers? 
] @@ -24,15 +24,15 @@ namespace phi { -#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ - template \ - void Logical##type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - funcs::Logical##type##Functor binary_func; \ - ElementwiseCompute, T, bool>( \ - dev_ctx, x, y, -1, binary_func, out); \ +#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ + template \ + void Logical##type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + funcs::Logical##type##Functor binary_func; \ + funcs::ElementwiseCompute, T, bool>( \ + dev_ctx, x, y, -1, binary_func, out); \ } DEFINE_LOGICAL_BINARY_KERNEL(And) diff --git a/paddle/phi/kernels/cpu/math_kernel.cc b/paddle/phi/kernels/cpu/math_kernel.cc index 5cfcfe62c7816c84a4f2876942b4d9b30dfad167..250f656926c0536f71e5724eb9df779c1502a673 100644 --- a/paddle/phi/kernels/cpu/math_kernel.cc +++ b/paddle/phi/kernels/cpu/math_kernel.cc @@ -20,6 +20,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/elementwise.h" #include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/reduce_functor.h" @@ -45,10 +46,10 @@ namespace phi { auto x_dims = x.dims(); \ auto y_dims = y.dims(); \ if (x_dims.size() >= y_dims.size()) { \ - ElementwiseCompute, T>( \ + funcs::ElementwiseCompute, T>( \ dev_ctx, x, y, axis, funcs::name##Functor(), out); \ } else { \ - ElementwiseCompute, T>( \ + funcs::ElementwiseCompute, T>( \ dev_ctx, x, y, axis, funcs::Inverse##name##Functor(), out); \ } \ } \ @@ -93,10 +94,10 @@ void DivideRawKernel(const Context& dev_ctx, auto x_dims = x.dims(); auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { - ElementwiseCompute, T>( + funcs::ElementwiseCompute, T>( dev_ctx, x, y, axis, funcs::DivideFunctor(), out); } else { - ElementwiseCompute, T>( + funcs::ElementwiseCompute, T>( dev_ctx, x, y, axis, funcs::InverseDivideFunctor(), out); } } diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 84a36b849afa1c4cdcc1a0f4d4ada598944a1faa..e9fd4cf47b834775c03e9b48ff1e3a5096228fb2 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -25,6 +25,8 @@ namespace kps = phi::kps; namespace phi { namespace funcs { +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) + struct DimensionsTransform { using DimVector = std::vector; typedef void (*MergeFunctor)( @@ -183,8 +185,6 @@ struct DimensionsTransform { } }; -#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) - template __device__ __forceinline__ void LoadData( T *dst, @@ -578,6 +578,20 @@ void BroadcastKernel(const KPDevice &ctx, } } +template +void ElementwiseCompute(const GPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + int axis, + Functor func, + DenseTensor *z) { + std::vector ins = {&x, &y}; + std::vector outs = {z}; + z->mutable_data(dev_ctx.GetPlace()); + BroadcastKernel( + dev_ctx, ins, &outs, axis, func); +} + #endif } // namespace funcs diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index d369781f845eb0887817f83be761b1027fc0bab0..235dbdd40f6b7db5524251aec80b92cdc22aa819 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -18,7 +18,8 @@ limitations under the 
License. */ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/elementwise_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) @@ -44,28 +45,6 @@ using ConditionalT = namespace funcs { using DDim = phi::DDim; -template -struct ElemwiseGradNoBroadcast { - const T *x_; - const T *y_; - const Tout *out_; - const Tout *dout_; - - HOSTDEVICE void operator()(size_t i) { - if (dx_ != nullptr) { - dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); - } - if (dy_ != nullptr) { - dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); - } - } - - DX_OP dx_op_; - DY_OP dy_op_; - T *dx_; - T *dy_; -}; - template class RowwiseTransformIterator; @@ -293,73 +272,172 @@ class TransformFunctor { bool is_xsize_larger_; }; -inline DDim trim_trailing_singular_dims(const DDim &dims) { - // Remove trailing dimensions of size 1 for y - auto actual_dims_size = dims.size(); - for (; actual_dims_size != 0; --actual_dims_size) { - if (dims[actual_dims_size - 1] != 1) break; - } - if (actual_dims_size == dims.size()) return dims; - std::vector trim_dims; - trim_dims.resize(actual_dims_size); - for (int i = 0; i < actual_dims_size; ++i) { - trim_dims[i] = dims[i]; - } - if (trim_dims.size() == 0) { - return DDim(phi::make_dim()); +template +void CommonForwardBroadcastCPU(const DenseTensor &x, + const DenseTensor &y, + DenseTensor *z, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + int max_dim, + const CPUContext &ctx, + Functor func, + const bool is_xsize_larger = true) { + std::vector index_array(max_dim, 0); + const T *x_data = x.data(); + const T *y_data = y.data(); + PADDLE_ENFORCE_NOT_NULL( + x_data, errors::InvalidArgument("The input X should not be empty.")); + PADDLE_ENFORCE_NOT_NULL( + y_data, errors::InvalidArgument("The input Y should not be empty.")); + OutType *out_data = ctx.Alloc(z); + + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_index, y_index; + for (int out_index = 0; out_index < out_size; ++out_index) { + x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); + y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); + if (is_xsize_larger) { + out_data[out_index] = func(x_data[x_index], y_data[y_index]); + } else { + out_data[out_index] = func(y_data[y_index], x_data[x_index]); + } + + UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); } - DDim actual_dims = phi::make_ddim(trim_dims); - return actual_dims; } -/* - * Out = X ⊙ Y - * If Y's shape does not match X' shape, they will be reshaped. - * For example: - * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 - * pre=2, n=3*4, post=5 - * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) - * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) - * pre=2*3, n=4*5, post=1 - * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) - * - * New parameter: *is_run_common_broadcast* is a flag to record whether to run - * common broadcast code. 
- */ -inline void get_mid_dims(const DDim &x_dims, - const DDim &y_dims, - const int axis, - int *pre, - int *n, - int *post, - int *is_run_common_broadcast) { - *pre = 1; - *n = 1; - *post = 1; - *is_run_common_broadcast = 0; - for (int i = 0; i < axis; ++i) { - (*pre) *= x_dims[i]; - } - for (int i = 0; i < y_dims.size(); ++i) { - if (x_dims[i + axis] != y_dims[i]) { - PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, - true, - phi::errors::InvalidArgument( - "Broadcast dimension mismatch. Operands " - "could not be broadcast together with the shape of " - "X = [%s] and the shape of Y = [%s]. Received [%d] " - "in X is not equal to [%d] in Y.", - x_dims, - y_dims, - x_dims[i + axis], - y_dims[i])); - *is_run_common_broadcast = 1; - return; - } - (*n) *= y_dims[i]; - } - for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { - (*post) *= x_dims[i]; +template +void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + DenseTensor *z, + const DDim &x_dims, + const DDim &y_dims, + Functor func, + int axis, + const bool is_xsize_larger = true) { + int max_dim = (std::max)(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, + 0, + phi::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + max_dim, + phi::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, + axis)); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + + CommonForwardBroadcastCPU(x, + y, + z, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + dev_ctx, + func, + is_xsize_larger); +} + +// It is a common CPU implementation to compute binary calculation with the +// support of broadcast. Note: +// 1. CPU implementation cannot support the case when x needs broadcast, thus +// this function need to be called with XxxFunctor and XxxInverseFunctor, +// like AddFunctor and InverseAddFunctor. +// 2. The corresponding GPU implementation supports all the broadcast cases, +// thus there is no need to define and call with XxxInverseFunctor. +// TODO(liuyiqun): optimize the CPU implementation to support all broadcast +// cases and avoid the need of XxxInverseFunctor. +template +void ElementwiseCompute(const CPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + int axis, + Functor func, + DenseTensor *z) { + dev_ctx.Alloc(z); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + bool is_xsize_larger = true; + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + TransformFunctor functor( + x, y, z, dev_ctx, func, is_xsize_larger); + if (x_dims == y_dims) { + functor.Run(); + return; + } + + axis = (axis == -1 ? 
std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, + 0, + errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, + axis)); + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = TrimTrailingSingularDims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + GetMidDims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = TrimTrailingSingularDims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + GetMidDims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common implementation. + // case 1: x=[2,3,1,5], y=[2,1,4,1] + // case 2: x=[2,3,4], y=[1,1,4] + if (is_run_common_broadcast == 1) { + CommonElementwiseBroadcastForward( + dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); + return; + } + + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; } } @@ -395,41 +473,11 @@ static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx, auto meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); *ddx_safe = phi::Empty(dev_ctx, std::move(meta)); ddx_safe->mutable_data(dev_ctx.GetPlace()); - phi::funcs::SetConstant set_zero; + SetConstant set_zero; set_zero(dev_ctx, ddx_safe, static_cast(0)); } } -template -void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx, - const DDim &x_dim, - const DDim &y_dim, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &out, - const DenseTensor &dout, - int axis, - DenseTensor *dx, - DenseTensor *dy, - DX_OP dx_op, - DY_OP dy_op) { - size_t N = static_cast(phi::product(x_dim)); - phi::funcs::ForRange for_range(dev_ctx, N); - for_range(ElemwiseGradNoBroadcast{ - x.data(), - y.data(), - out.data(), - dout.data(), - dx_op, - dy_op, - dx == nullptr ? nullptr : dev_ctx.template Alloc(dx), - dy == nullptr ? nullptr : dev_ctx.template Alloc(dy)}); -} - inline void ElementwiseGradPreProcess(const DenseTensor &dout, DenseTensor *dx) { if (dx != nullptr) { @@ -806,6 +854,7 @@ void ElementwiseKernel(const KPDevice &ctx, } } } + #endif } // namespace funcs diff --git a/paddle/phi/kernels/gpu/elementwise.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h similarity index 78% rename from paddle/phi/kernels/gpu/elementwise.h rename to paddle/phi/kernels/funcs/elementwise_grad_base.h index 12cafc7023bb5100d5f619aeec29a357a13e4935..dff0cfe5b8b90155f62aa6e465c1fd6450d64807 100644 --- a/paddle/phi/kernels/gpu/elementwise.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -14,16 +14,25 @@ limitations under the License. */ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/common_shape.h" -#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/funcs/elementwise_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" + +#endif #ifdef __HIPCC__ constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; #else constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #endif + #define BLOCK_X 32 #define BLOCK_Y 32 @@ -36,21 +45,361 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; namespace phi { -// General binary elementwise comutaion with the support of broadcast. -template -void ElementwiseCompute(const GPUContext &dev_ctx, - const DenseTensor &x, - const DenseTensor &y, - int axis, - Functor func, - DenseTensor *z) { - std::vector ins = {&x, &y}; - std::vector outs = {z}; - z->mutable_data(dev_ctx.GetPlace()); - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, axis, func); +namespace funcs { +using DDim = phi::DDim; + +template +void CommonGradBroadcastCPU(const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + int max_dim, + const CPUContext &ctx, + DX_OP dx_op, + DY_OP dy_op) { + std::vector index_array(max_dim, 0); + const T *x_data = x.data(); + const T *y_data = y.data(); + const Tout *out_data = out.data(); + const Tout *dout_data = dout.data(); + T *dx_data = dx == nullptr ? nullptr : ctx.Alloc(dx); + T *dy_data = dy == nullptr ? nullptr : ctx.Alloc(dy); + if (dx_data != nullptr) { + memset(dx_data, 0, dx->numel() * sizeof(T)); + } + if (dy_data != nullptr) { + memset(dy_data, 0, dy->numel() * sizeof(T)); + } + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_index, y_index; + for (int out_index = 0; out_index < out_size; ++out_index) { + x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); + y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); + if (dx_data != nullptr) { + dx_data[x_index] += dx_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index]); + } + if (dy_data != nullptr) { + dy_data[y_index] += dy_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index]); + } + + UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); + } +} + +template +static void ElemwiseGradBroadcast1CPU(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + if (is_xsize_larger) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int x_offset = i * w + j; + if (dx != nullptr) { + dx[x_offset] = + dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } + } else { // x.dims < y.dims, broadcast for x. 
+ for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int y_offset = i * w + j; + if (dy != nullptr) { + dy[y_offset] = + dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx != nullptr) { + T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + if (i == 0) { + dx[j] = tmp; + } else { + dx[j] += tmp; + } + } + } + } + } +} + +template +static void ElemwiseGradBroadcast2CPU(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int pre, + int n, + int post, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + if (is_xsize_larger) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int x_offset = i * n * post + j * post + k; + if (dx != nullptr) { + dx[x_offset] = + dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0 && k == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } + } + } else { // x.dims < y.dims, broadcast for x. + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int y_offset = i * n * post + j * post + k; + if (dy != nullptr) { + dy[y_offset] = + dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx != nullptr) { + T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + if (i == 0 && k == 0) { + dx[j] = tmp; + } else { + dx[j] += tmp; + } + } + } + } + } + } +} + +template +void CommonElementwiseBroadcastBackward(const CPUContext &ctx, + const DDim &x_dims, + const DDim &y_dims, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + // for inplace strategy. memset will make dx and dout clear and get wrong + // result. + if (dx && dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x_dims, ctx.GetPlace()); + } + + VLOG(3) << "CommonElementwiseBroadcastBackward xdims:" + << phi::make_ddim(x_dims_array) + << " ydim:" << phi::make_ddim(y_dims_array); + + CommonGradBroadcastCPU(x, + y, + out, + dout, + dx, + dy, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + ctx, + dx_op, + dy_op); +} + +template +void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, + const DDim &x_dims, + const DDim &y_dims, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + bool is_xsize_larger = true; + + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + + axis = (axis == -1 ? 
std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, + 0, + errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, + axis)); + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = TrimTrailingSingularDims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + GetMidDims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = TrimTrailingSingularDims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + GetMidDims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common backward implementation. + if (is_run_common_broadcast) { + CommonElementwiseBroadcastBackward( + ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + return; + } + if (post == 1) { + ElemwiseGradBroadcast1CPU(x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : ctx.Alloc(dx), + dy == nullptr ? nullptr : ctx.Alloc(dy)); + } else { + ElemwiseGradBroadcast2CPU(x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + post, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : ctx.Alloc(dx), + dy == nullptr ? nullptr : ctx.Alloc(dy)); + } +} + +template +struct ElemwiseGradNoBroadcast { + const T *x_; + const T *y_; + const Tout *out_; + const Tout *dout_; + + HOSTDEVICE void operator()(size_t i) { + if (dx_ != nullptr) { + dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); + } + if (dy_ != nullptr) { + dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); + } + } + + DX_OP dx_op_; + DY_OP dy_op_; + T *dx_; + T *dy_; +}; + +template +void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx, + const DDim &x_dim, + const DDim &y_dim, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + size_t N = static_cast(phi::product(x_dim)); + phi::funcs::ForRange for_range(dev_ctx, N); + for_range(ElemwiseGradNoBroadcast{ + x.data(), + y.data(), + out.data(), + dout.data(), + dx_op, + dy_op, + dx == nullptr ? nullptr : dev_ctx.template Alloc(dx), + dy == nullptr ? nullptr : dev_ctx.template Alloc(dy)}); } +#if defined(__NVCC__) || defined(__HIPCC__) // Suppose only has contiguous dims static inline bool CheckContiguousDims(const std::vector &broadcast_pos) { for (int i = 1; i < broadcast_pos.size(); ++i) { @@ -114,7 +463,6 @@ inline void ComputeBroadcastKernelSize(int *x_dims_array, } } -#ifndef __xpu__ template static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x, const T *y, @@ -1282,13 +1630,13 @@ void CommonElementwiseBroadcastBackward(const GPUContext &ctx, std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); std::vector out_dims_array(max_dim); - funcs::GetBroadcastDimsArrays(x_dims, - y_dims, - x_dims_array.data(), - y_dims_array.data(), - out_dims_array.data(), - max_dim, - axis); + GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); // for inplace strategy. memset will make dx and dout clear and get wrong // result. 
if (dx && dx->IsSharedBufferWith(dout)) { @@ -1340,37 +1688,37 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, PADDLE_ENFORCE_GE( axis, 0, - phi::errors::InvalidArgument( + errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - phi::errors::InvalidArgument( + errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { - auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims); + auto y_dims_trimed = TrimTrailingSingularDims(y_dims); axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; - funcs::get_mid_dims(x_dims, - y_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); + GetMidDims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); } else { - auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims); + auto x_dims_trimed = TrimTrailingSingularDims(x_dims); axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; - funcs::get_mid_dims(y_dims, - x_dims_trimed, - axis_trim, - &pre, - &n, - &post, - &is_run_common_broadcast); + GetMidDims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); } // special case for common backward implementation. if (is_run_common_broadcast) { @@ -1408,228 +1756,7 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, } } -/* -****************************** - Add Grad -****************************** -*/ - -template -static __global__ void SimpleElemwiseAddGradCUDAKernel( - const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) { - int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X; - int stride = GRID_NUM_X * BLOCK_NUM_X; - int loop = size / vec_size; - int remainder = size % vec_size; - const float4 *dout_vec = reinterpret_cast(dout); - float4 *dx_vec = reinterpret_cast(dx); - float4 *dy_vec = reinterpret_cast(dy); - float4 tmp_loop; - - for (int i = tid; i < loop; i += stride) { - tmp_loop = dout_vec[i]; - dx_vec[i] = tmp_loop; - dy_vec[i] = tmp_loop; - } - - if (tid == loop && remainder != 0) { - T tmp_rem; - while (remainder) { - int idx = size - remainder; - remainder--; - tmp_rem = dout[idx]; - dx[idx] = tmp_rem; - dy[idx] = tmp_rem; - } - } -} - -template -void default_elementwise_add_grad(const GPUContext &ctx, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &out, - const DenseTensor &dout, - DenseTensor *dx, - DenseTensor *dy, - int axis = -1) { - auto *dout_data = dout.data(); - - // dx - if (dx != nullptr) { - auto *dx_data = dx->mutable_data(ctx.GetPlace()); - if (dx->dims() == dout.dims()) { - if (dx_data != dout_data) { - phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); - } - } else { - // For inplace strategy, dx will be stored in addr of dout, which makes - // the result of dy wrong. 
- if (dx->IsSharedBufferWith(dout)) { - dx->clear(); - dx->mutable_data(x.dims(), ctx.GetPlace()); - } - std::vector reduce_dims = - funcs::GetReduceDim(x.dims(), out.dims(), axis); - gpuStream_t stream = ctx.stream(); - kernels::TensorReduceImpl>( - ctx, dout, dx, kps::IdentityFunctor(), reduce_dims, stream); - } - } - // dy - if (dy != nullptr) { - auto *dy_data = dy->mutable_data(ctx.GetPlace()); - if (dy->dims() == dout.dims()) { - if (dy_data != dout_data) { - phi::Copy(ctx, dout, ctx.GetPlace(), false, dy); - } - } else { - std::vector reduce_dims = - funcs::GetReduceDim(y.dims(), out.dims(), axis); - gpuStream_t stream = ctx.stream(); - kernels::TensorReduceImpl>( - ctx, dout, dy, kps::IdentityFunctor(), reduce_dims, stream); - } - } -} - -template -void elementwise_add_grad(const GPUContext &ctx, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &out, - const DenseTensor &dout, - DenseTensor *dx, - DenseTensor *dy) { - auto *dx_data = dx->mutable_data(ctx.GetPlace()); - auto *dy_data = dy->mutable_data(ctx.GetPlace()); - auto *dout_data = dout.data(); - if (dx_data == dout_data && dy_data != dout_data) { - VLOG(4) << "Special case when dx_data is the same as dout_data, " - "only need copy dout to dy"; - phi::Copy(ctx, dout, ctx.GetPlace(), false, dy); - } else if (dx_data != dout_data && dy_data == dout_data) { - VLOG(4) << "Special case when dy_data is the same as dout_data, " - "only need copy dout to dx"; - phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); - } else if (dx_data != dout_data && dy_data != dout_data) { - auto size = x.numel(); - int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); - dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); - dim3 grid_size = - dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) / - PREDEFINED_BLOCK_SIZE, - 1); - SimpleElemwiseAddGradCUDAKernel< - T><<>>( - dout.data(), - size, - vec_size, - dx->mutable_data(ctx.GetPlace()), - dy->mutable_data(ctx.GetPlace())); - } else { - VLOG(4) << "Special case when dy_data is the same as dout_data, " - "and dx_data is the same as dout_data, do not need " - "any operator"; - } -} - -/* -****************************** - Sub Grad -****************************** -*/ - -template -static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout, - int64_t size, - T *dx, - T *dy) { - int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X; - - while (col < size) { - if (dx != nullptr) { - dx[col] = dout[col]; - } - dy[col] = -dout[col]; - col += BLOCK_NUM_X * GRID_NUM_X; - } -} - -template -void default_elementwise_sub_grad(const GPUContext &ctx, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &out, - const DenseTensor &dout, - DenseTensor *dx, - DenseTensor *dy, - int axis = -1) { - auto *dout_data = dout.data(); - // dx - if (dx != nullptr) { - auto *dx_data = dx->mutable_data(ctx.GetPlace()); - if (dx->dims() == dout.dims()) { - if (dx_data != dout_data) { - phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); - } - } else { - // For inplace strategy, dx will be stored in addr of dout, which makes - // the result of dy wrong. 
- if (dx->IsSharedBufferWith(dout)) { - dx->clear(); - dx->mutable_data(x.dims(), ctx.GetPlace()); - } - std::vector reduce_dims = - funcs::GetReduceDim(x.dims(), out.dims(), axis); - gpuStream_t stream = ctx.stream(); - kernels::TensorReduceImpl>( - ctx, dout, dx, kps::IdentityFunctor(), reduce_dims, stream); - } - } - // dy - if (dy != nullptr) { - auto *dy_data = dy->mutable_data(ctx.GetPlace()); - if (dy->dims() == dout.dims()) { - if (dy_data != dout_data) { - dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); - auto size = dy->numel(); - dim3 grid_size = - dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1); - SimpleElemwiseSubGradCUDAKernel< - T><<>>( - dout.data(), size, nullptr, dy->mutable_data(ctx.GetPlace())); - } - } else { - std::vector reduce_dims = - funcs::GetReduceDim(y.dims(), out.dims(), axis); - gpuStream_t stream = ctx.stream(); - kernels::TensorReduceImpl>( - ctx, dout, dy, kps::InverseFunctor(), reduce_dims, stream); - } - } -} - -template -void elementwise_sub_grad(const GPUContext &ctx, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &out, - const DenseTensor &dout, - DenseTensor *dx, - DenseTensor *dy) { - dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); - auto size = x.numel(); - dim3 grid_size = - dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1); - SimpleElemwiseSubGradCUDAKernel< - T><<>>( - dout.data(), - size, - dx->mutable_data(ctx.GetPlace()), - dy->mutable_data(ctx.GetPlace())); -} - #endif +} // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_utils.h b/paddle/phi/kernels/funcs/elementwise_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3790044346dc42ac8772505dbafdb076db615491 --- /dev/null +++ b/paddle/phi/kernels/funcs/elementwise_utils.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { + +namespace funcs { + +using DDim = phi::DDim; + +/* + * Out = X ⊙ Y + * If Y's shape does not match X' shape, they will be reshaped. + * For example: + * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + * pre=2, n=3*4, post=5 + * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) + * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) + * pre=2*3, n=4*5, post=1 + * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) + * + * New parameter: *is_run_common_broadcast* is a flag to record whether to run + * common broadcast code. 
+ */ +inline void GetMidDims(const DDim &x_dims, + const DDim &y_dims, + const int axis, + int *pre, + int *n, + int *post, + int *is_run_common_broadcast) { + *pre = 1; + *n = 1; + *post = 1; + *is_run_common_broadcast = 0; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + if (x_dims[i + axis] != y_dims[i]) { + PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, + true, + phi::errors::InvalidArgument( + "Broadcast dimension mismatch. Operands " + "could not be broadcast together with the shape of " + "X = [%s] and the shape of Y = [%s]. Received [%d] " + "in X is not equal to [%d] in Y.", + x_dims, + y_dims, + x_dims[i + axis], + y_dims[i])); + *is_run_common_broadcast = 1; + return; + } + (*n) *= y_dims[i]; + } + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } +} + +inline DDim TrimTrailingSingularDims(const DDim &dims) { + // Remove trailing dimensions of size 1 for y + auto actual_dims_size = dims.size(); + for (; actual_dims_size != 0; --actual_dims_size) { + if (dims[actual_dims_size - 1] != 1) break; + } + if (actual_dims_size == dims.size()) return dims; + std::vector trim_dims; + trim_dims.resize(actual_dims_size); + for (int i = 0; i < actual_dims_size; ++i) { + trim_dims[i] = dims[i]; + } + if (trim_dims.size() == 0) { + return DDim(phi::make_dim()); + } + DDim actual_dims = phi::make_ddim(trim_dims); + return actual_dims; +} + +inline int GetElementwiseIndex(const int *x_dims_array, + const int max_dim, + const int *index_array) { + int index_ = 0; + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] > 1) { + index_ = index_ * x_dims_array[i] + index_array[i]; + } + } + return index_; +} + +inline void UpdateElementwiseIndexArray(const int *out_dims_array, + const int max_dim, + int *index_array) { + for (int i = max_dim - 1; i >= 0; --i) { + ++index_array[i]; + if (index_array[i] >= out_dims_array[i]) { + index_array[i] -= out_dims_array[i]; + } else { + break; + } + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..b17196b6b11566927a02b81aa7447ba77c1884ff --- /dev/null +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -0,0 +1,246 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_grad_base.h" +#include "paddle/phi/kernels/gpu/reduce.h" + +namespace phi { + +/* +****************************** + Add Grad +****************************** +*/ + +template +static __global__ void SimpleElemwiseAddGradCUDAKernel( + const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) { + int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X; + int stride = GRID_NUM_X * BLOCK_NUM_X; + int loop = size / vec_size; + int remainder = size % vec_size; + const float4 *dout_vec = reinterpret_cast(dout); + float4 *dx_vec = reinterpret_cast(dx); + float4 *dy_vec = reinterpret_cast(dy); + float4 tmp_loop; + + for (int i = tid; i < loop; i += stride) { + tmp_loop = dout_vec[i]; + dx_vec[i] = tmp_loop; + dy_vec[i] = tmp_loop; + } + + if (tid == loop && remainder != 0) { + T tmp_rem; + while (remainder) { + int idx = size - remainder; + remainder--; + tmp_rem = dout[idx]; + dx[idx] = tmp_rem; + dy[idx] = tmp_rem; + } + } +} + +template +void DefaultElementwiseAddGrad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis = -1) { + auto *dout_data = dout.data(); + + // dx + if (dx != nullptr) { + auto *dx_data = dx->mutable_data(ctx.GetPlace()); + if (dx->dims() == dout.dims()) { + if (dx_data != dout_data) { + phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); + } + } else { + // For inplace strategy, dx will be stored in addr of dout, which makes + // the result of dy wrong. + if (dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x.dims(), ctx.GetPlace()); + } + std::vector reduce_dims = + funcs::GetReduceDim(x.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceImpl>( + ctx, dout, dx, kps::IdentityFunctor(), reduce_dims, stream); + } + } + // dy + if (dy != nullptr) { + auto *dy_data = dy->mutable_data(ctx.GetPlace()); + if (dy->dims() == dout.dims()) { + if (dy_data != dout_data) { + phi::Copy(ctx, dout, ctx.GetPlace(), false, dy); + } + } else { + std::vector reduce_dims = + funcs::GetReduceDim(y.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceImpl>( + ctx, dout, dy, kps::IdentityFunctor(), reduce_dims, stream); + } + } +} + +template +void ElementwiseAddGrad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy) { + ctx.template Alloc(dx); + ctx.template Alloc(dy); + auto *dx_data = dx->data(); + auto *dy_data = dy->data(); + auto *dout_data = dout.data(); + if (dx_data == dout_data && dy_data != dout_data) { + VLOG(4) << "Special case when dx_data is the same as dout_data, " + "only need copy dout to dy"; + phi::Copy(ctx, dout, ctx.GetPlace(), false, dy); + } else if (dx_data != dout_data && dy_data == dout_data) { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "only need copy dout to dx"; + phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); + } else if (dx_data != dout_data && dy_data != dout_data) { + auto size = x.numel(); + int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); + dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); + dim3 grid_size = + dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) / + PREDEFINED_BLOCK_SIZE, + 1); + SimpleElemwiseAddGradCUDAKernel< + T><<>>( + dout.data(), + size, + vec_size, + 
dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); + } else { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "and dx_data is the same as dout_data, do not need " + "any operator"; + } +} + +/* +****************************** + Sub Grad +****************************** +*/ + +template +static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout, + int64_t size, + T *dx, + T *dy) { + int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X; + + while (col < size) { + if (dx != nullptr) { + dx[col] = dout[col]; + } + dy[col] = -dout[col]; + col += BLOCK_NUM_X * GRID_NUM_X; + } +} + +template +void default_elementwise_sub_grad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis = -1) { + auto *dout_data = dout.data(); + // dx + if (dx != nullptr) { + auto *dx_data = dx->mutable_data(ctx.GetPlace()); + if (dx->dims() == dout.dims()) { + if (dx_data != dout_data) { + phi::Copy(ctx, dout, ctx.GetPlace(), false, dx); + } + } else { + // For inplace strategy, dx will be stored in addr of dout, which makes + // the result of dy wrong. + if (dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x.dims(), ctx.GetPlace()); + } + std::vector reduce_dims = + funcs::GetReduceDim(x.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceImpl>( + ctx, dout, dx, kps::IdentityFunctor(), reduce_dims, stream); + } + } + // dy + if (dy != nullptr) { + auto *dy_data = dy->mutable_data(ctx.GetPlace()); + if (dy->dims() == dout.dims()) { + if (dy_data != dout_data) { + dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); + auto size = dy->numel(); + dim3 grid_size = + dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1); + SimpleElemwiseSubGradCUDAKernel< + T><<>>( + dout.data(), size, nullptr, dy->mutable_data(ctx.GetPlace())); + } + } else { + std::vector reduce_dims = + funcs::GetReduceDim(y.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceImpl>( + ctx, dout, dy, kps::InverseFunctor(), reduce_dims, stream); + } + } +} + +template +void elementwise_sub_grad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy) { + dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); + auto size = x.numel(); + dim3 grid_size = + dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1); + SimpleElemwiseSubGradCUDAKernel< + T><<>>( + dout.data(), + size, + dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); +} +} // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 3c4c01b1dc8ff739ac87ca2e9fe7a6659ab4eac3..d00888aee67019befa99d49cd093a4e171f941c2 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -17,8 +17,9 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/gpu/elementwise.h" +#include "paddle/phi/kernels/gpu/elementwise_grad.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" namespace phi { @@ -33,9 +34,9 @@ void AddGradFunc(const GPUContext& 
dev_ctx, DenseTensor* dy, int axis = -1) { if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { - elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy); + ElementwiseAddGrad(dev_ctx, x, y, out, dout, dx, dy); } else { - default_elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy, axis); + DefaultElementwiseAddGrad(dev_ctx, x, y, out, dout, dx, dy, axis); } } @@ -58,15 +59,7 @@ void AddDoubleGradKernel(const Context& dev_ctx, const DenseTensor& dout, int axis, DenseTensor* ddout) { - phi::AddDoubleGradImpl(dev_ctx, - y, - ddx, - ddy, - dout, - axis, - ddout, - ElementwiseCompute, T>, - ElementwiseCompute, T>); + phi::AddDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); } template @@ -106,15 +99,7 @@ void SubtractDoubleGradKernel(const Context& dev_ctx, const DenseTensor& dout, int axis, DenseTensor* ddout) { - phi::SubtractDoubleGradImpl( - dev_ctx, - y, - ddx, - ddy, - dout, - axis, - ddout, - ElementwiseCompute, T>); + phi::SubtractDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/logical_kernel.cu b/paddle/phi/kernels/gpu/logical_kernel.cu index f32d4c77d4059f4c6c0157fc839d3fa345ed489c..1c0bafc932ee87756529325dedbc1394340e7dde 100644 --- a/paddle/phi/kernels/gpu/logical_kernel.cu +++ b/paddle/phi/kernels/gpu/logical_kernel.cu @@ -16,9 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/logical_functor.h" -#include "paddle/phi/kernels/gpu/elementwise.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/math_kernel.cu b/paddle/phi/kernels/gpu/math_kernel.cu index fc73ccca6de18ea169b60fc6e998d42a8cb03919..af9d5574aa9feaf4d44482bbf0e75f31a5139595 100644 --- a/paddle/phi/kernels/gpu/math_kernel.cu +++ b/paddle/phi/kernels/gpu/math_kernel.cu @@ -15,8 +15,8 @@ limitations under the License. */ #include "paddle/phi/kernels/math_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/gpu/elementwise.h" #include "paddle/phi/kernels/gpu/reduce.h" #ifdef __NVCC__ diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 460e74b58166a5132bdbd62703f4dc3d5ef34a91..ac7d6fd1a0e9ca18ef528399581a30bfe0be5597 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace phi { @@ -47,19 +47,14 @@ void AddGradImpl(const Context& dev_ctx, } } -template +template void AddDoubleGradImpl(const Context& dev_ctx, const DenseTensor& y, const paddle::optional& ddx, const paddle::optional& ddy, const DenseTensor& dout, int axis, - DenseTensor* ddout, - GradFunc grad_func, - GradInverseFunc grad_inverse_func) { + DenseTensor* ddout) { // ddOut = ddx + ddy if (ddout) { DenseTensor ddx_safe, ddy_safe; @@ -72,28 +67,28 @@ void AddDoubleGradImpl(const Context& dev_ctx, auto ddx_dims = ddx_safe.dims(); auto ddy_dims = ddy_safe.dims(); if (ddx_dims.size() >= ddy_dims.size()) { - grad_func( + funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor(), ddout); } else { - grad_inverse_func(dev_ctx, - ddx_safe, - ddy_safe, - axis, - funcs::InverseAddFunctor(), - ddout); + funcs::ElementwiseCompute, T>( + dev_ctx, + ddx_safe, + ddy_safe, + axis, + funcs::InverseAddFunctor(), + ddout); } } } -template +template void SubtractDoubleGradImpl(const Context& dev_ctx, const DenseTensor& y, const paddle::optional& ddx, const paddle::optional& ddy, const DenseTensor& dout, int axis, - DenseTensor* ddout, - GradFunc grad_func) { + DenseTensor* ddout) { // DDOut = ddx - ddy if (ddout) { DenseTensor ddx_safe, ddy_safe; @@ -103,7 +98,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx, dev_ctx, y, ddy.get_ptr(), &ddy_safe); ddout->mutable_data(dev_ctx.GetPlace()); - grad_func( + funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor(), ddout); } }
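
Note on the pre/n/post decomposition documented in the new paddle/phi/kernels/funcs/elementwise_utils.h: the two example shapes in the GetMidDims comment can be checked with a small standalone sketch. Everything below (the MidDimsSketch helper, the main harness, and the use of plain std::vector<int> instead of DDim) is illustrative only and mirrors the decomposition under the assumption that no common-broadcast fallback is needed; it is not part of the patch.

    // Standalone sketch of the pre/n/post split performed by GetMidDims.
    #include <cassert>
    #include <vector>

    void MidDimsSketch(const std::vector<int>& x_dims,
                       const std::vector<int>& y_dims,
                       int axis, int* pre, int* n, int* post) {
      *pre = *n = *post = 1;
      // Dims of X in front of the block that is aligned with Y.
      for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
      // The aligned block shared by X and Y (mismatches would take the
      // is_run_common_broadcast path, omitted here).
      for (size_t i = 0; i < y_dims.size(); ++i) {
        assert(x_dims[i + axis] == y_dims[i]);
        *n *= y_dims[i];
      }
      // Dims of X behind the aligned block.
      for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) {
        *post *= x_dims[i];
      }
    }

    int main() {
      int pre = 0, n = 0, post = 0;
      MidDimsSketch({2, 3, 4, 5}, {3, 4}, /*axis=*/1, &pre, &n, &post);
      assert(pre == 2 && n == 12 && post == 5);  // first example in the comment
      // With axis == -1 the callers infer axis = rank(X) - rank(Y) = 2 here.
      MidDimsSketch({2, 3, 4, 5}, {4, 5}, /*axis=*/2, &pre, &n, &post);
      assert(pre == 6 && n == 20 && post == 1);  // second example in the comment
      return 0;
    }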
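
Note on the vectorized add-grad kernel moved into paddle/phi/kernels/gpu/elementwise_grad.h: SimpleElemwiseAddGradCUDAKernel copies dout into both dx and dy using float4-wide accesses, and the thread whose id equals loop handles the scalar tail. The single-threaded CPU sketch below shows the same chunk-plus-remainder traversal; AddGradCopySketch is a made-up name for illustration and does not exist in phi.

    // CPU sketch of the per-grid work done by SimpleElemwiseAddGradCUDAKernel:
    // copy dout into dx and dy, vec_size elements at a time, then the tail.
    template <typename T>
    void AddGradCopySketch(const T* dout, int size, int vec_size, T* dx, T* dy) {
      const int loop = size / vec_size;       // full vec_size-wide chunks
      const int remainder = size % vec_size;  // trailing scalar elements
      for (int i = 0; i < loop; ++i) {
        for (int j = 0; j < vec_size; ++j) {
          const int idx = i * vec_size + j;
          dx[idx] = dout[idx];
          dy[idx] = dout[idx];
        }
      }
      for (int r = 0; r < remainder; ++r) {
        const int idx = size - remainder + r;
        dx[idx] = dout[idx];
        dy[idx] = dout[idx];
      }
    }

In the real kernel the outer loop is strided across the grid and the chunk copy is a single float4 load/store, which is why vec_size is clamped to at least 1 for element types wider than 4 bytes.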
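
Note on the simplified double-grad impls: AddDoubleGradImpl now calls funcs::ElementwiseCompute directly and switches to InverseAddFunctor when ddy has the higher rank. As far as the funcs convention goes, the compute routine feeds the functor its operands in reverse order when the second tensor is the larger one, and the Inverse functor undoes that reversal; for addition the two functors happen to agree, but the pattern is what non-commutative ops (subtract, divide) rely on. A toy sketch of that convention follows; every name in it is a stand-in, not a phi symbol.

    // Stand-in functors mirroring funcs::AddFunctor / funcs::InverseAddFunctor.
    template <typename T>
    struct AddF { T operator()(T a, T b) const { return a + b; } };
    template <typename T>
    struct InverseAddF { T operator()(T a, T b) const { return b + a; } };

    // Stand-in for the per-element step inside ElementwiseCompute: when the
    // second operand has the higher rank, the operands reach the functor in
    // reverse order.
    template <typename T, typename F>
    T ApplySketch(bool x_rank_is_larger, T x, T y, F f) {
      return x_rank_is_larger ? f(x, y) : f(y, x);
    }

    // ApplySketch(false, x, y, InverseAddF<T>{}) yields x + y, the same value
    // as ApplySketch(true, x, y, AddF<T>{}), which is why AddDoubleGradImpl
    // only has to pick the functor and never reorders ddx_safe / ddy_safe.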