From 556d509791b2b0a6c12781f7ecb6bbf811ee3bec Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 14 Jan 2022 11:47:16 +0800 Subject: [PATCH] refactor impl of elementwise op part2 (#38898) --- .../elementwise/elementwise_op_function.h | 621 +------------- paddle/pten/kernels/cpu/elementwise.h | 144 ++++ paddle/pten/kernels/gpu/elementwise.h | 768 ++++++++++++++++++ 3 files changed, 919 insertions(+), 614 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 626046890fb..7cd04318d3f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -49,12 +49,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" -#define GetDivMod(dividend, divisor, div, mod) \ - do { \ - const auto dividend_copy = dividend; \ - *div = dividend_copy / divisor; \ - *mod = dividend_copy % divisor; \ - } while (0) #define DIVUP(x, y) (((x) + (y)-1) / (y)) @@ -138,613 +132,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, axis); } -template -void CommonForwardBroadcastCPU(const framework::Tensor *x, - const framework::Tensor *y, framework::Tensor *z, - int *x_dims_array, int *y_dims_array, - int *out_dims_array, int max_dim, - const platform::CPUDeviceContext &ctx, - Functor func, - const bool is_xsize_larger = true) { - pten::CommonForwardBroadcastCPU(x, y, z, x_dims_array, y_dims_array, - out_dims_array, max_dim, ctx, func, - is_xsize_larger); -} - -#if defined(__NVCC__) || defined(__HIPCC__) - -template -__global__ void CommonGradBroadcastCUDAKernel( - const int *x_strides_array, const int *y_strides_array, - const int *out_dims_array, const int *y_strides_order, - const int *y_dims_order, const T *x, const T *y, const Tout *out, - const Tout *dout, T *dx, int out_size, int max_dim, int thread_num, - DX_OP dx_op) { - T val(0); - int i = blockIdx.x; - int tid = threadIdx.x; - for (int j = tid; j < thread_num; j += blockDim.x) { - const int X_index = i * thread_num + j; - int out_index = X_index; - int C_index = 0; - int B_index = i * thread_num + j; - int remainder = 0; -#pragma unroll - for (int d = max_dim - 1; d >= 0; --d) { - GetDivMod(B_index, y_dims_order[d], &B_index, &remainder); - C_index += remainder * y_strides_order[d]; - } - int x_index = 0; - int y_index = 0; - int C_index_val = C_index; -#pragma unroll - for (int d = max_dim - 1; d >= 0; --d) { - GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder); - x_index += remainder * x_strides_array[d]; - y_index += remainder * y_strides_array[d]; - } - out_index = C_index; - val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]); - } - val = paddle::platform::reduceSum(val, tid, thread_num); - if (threadIdx.x == 0) { - dx[i] = val; - } -} - -template -void CommonGradBroadcastCUDA( - const framework::Tensor &x, const framework::Tensor &y, - const framework::Tensor &out, const framework::Tensor &dout, - framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array, - int *y_dims_array, int *out_dims_array, int max_dim, - const platform::CUDADeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) { - const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); - auto cplace = platform::CPUPlace(); - const T *x_data = x.data(); - const T *y_data = y.data(); - const Tout *out_data = out.data(); - const Tout *dout_data = dout.data(); - T *dx_data = dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()); - T *dy_data = dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()); - - std::vector x_one_indexs; - std::vector y_one_indexs; - for (int i = 0; i < max_dim; i++) { - if (x_dims_array[i] != y_dims_array[i]) { - if (x_dims_array[i] == 1) { - x_one_indexs.push_back(i); - } - if (y_dims_array[i] == 1) { - y_one_indexs.push_back(i); - } - } - } - - std::vector x_trans_indexs(max_dim); - std::vector y_trans_indexs(max_dim); - pten::ComputeBroadcastTranspositionArray( - x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size()); - pten::ComputeBroadcastTranspositionArray( - y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size()); - - // compute array stride for cuda kernel; - // e.g. x.dims=[2,3,4], x_stride=[12,4,1] - std::vector x_strides_array(max_dim); - std::vector y_strides_array(max_dim); - std::vector out_strides_array(max_dim); - int x_stride = 1; - int y_stride = 1; - int z_stride = 1; - for (int i = max_dim - 1; i >= 0; i--) { - x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride; - y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride; - out_strides_array[i] = z_stride; - x_stride *= x_dims_array[i]; - y_stride *= y_dims_array[i]; - z_stride *= out_dims_array[i]; - } - - std::vector x_strides_order(max_dim); - std::vector y_strides_order(max_dim); - std::vector x_dims_order(max_dim); - std::vector y_dims_order(max_dim); - for (int i = 0; i < max_dim; ++i) { - x_strides_order[i] = out_strides_array[x_trans_indexs[i]]; - y_strides_order[i] = out_strides_array[y_trans_indexs[i]]; - x_dims_order[i] = out_dims_array[x_trans_indexs[i]]; - y_dims_order[i] = out_dims_array[y_trans_indexs[i]]; - } - std::vector x_broadcast_pos; - std::vector y_broadcast_pos; - - int bytes = max_dim * sizeof(int); - - for (int i = 0; i < max_dim; ++i) { - if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) { - x_broadcast_pos.emplace_back(i); - } - if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) { - y_broadcast_pos.emplace_back(i); - } - } - - auto stream = ctx.stream(); - bool can_split_x = false; - bool can_split_y = false; - - auto FastCommonCUDAF = [&](const std::vector &broadcast_pos, bool is_y) { - int h = - std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(), - 1, std::multiplies()); - int w = - std::accumulate(out_dims_array + broadcast_pos.size(), - out_dims_array + max_dim, 1, std::multiplies()); - - VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h - << " is_y:" << is_y; - - int split_h; - int split_w; - int kh = h; - int kw = w; - - if (is_y) { - split_h = - std::accumulate(x_dims_array, x_dims_array + broadcast_pos.size(), 1, - std::multiplies()); - split_w = - std::accumulate(x_dims_array + broadcast_pos.size(), - x_dims_array + max_dim, 1, std::multiplies()); - - } else { - split_h = - std::accumulate(y_dims_array, y_dims_array + broadcast_pos.size(), 1, - std::multiplies()); - split_w = - std::accumulate(y_dims_array + broadcast_pos.size(), - y_dims_array + max_dim, 1, std::multiplies()); - } - - if (h > split_h) kh = split_h; - if (w > split_w) kw = split_w; - - if (is_y) { - if (w < 16 || h < 16) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int grid_size = w; - pten::CommonGradBroadcast1CUDAKernelHeight<<>>( - x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw, - is_y); - } else { - dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - pten::FastCommonGradBroadcastCUDAKernelHeight<<>>( - x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw, - is_y); - } - } else { - if (w < 16 || h < 16) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int grid_size = w; - pten::CommonGradBroadcast1CUDAKernelHeight<<>>( - x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw, - is_y); - } else { - dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - pten::FastCommonGradBroadcastCUDAKernelHeight<<>>( - x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw, - is_y); - } - } - }; - - auto FastBroadCastHeightCUDAF = [&](const std::vector &broadcast_pos, - bool x_large) { - int h = - std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(), - 1, std::multiplies()); - int w = - std::accumulate(out_dims_array + broadcast_pos.size(), - out_dims_array + max_dim, 1, std::multiplies()); - - VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h; - - if (w < 16 || h < 16) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int grid_size = w; - pten::ElemwiseGradBroadcast1CUDAKernel<<>>( - x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op, - dx_data, dy_data); - } else { - dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - pten::FastElemwiseGradBroadcast1CUDAKernel<<>>( - x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op, - dx_data, dy_data); - } - }; - - auto FastBroadCastAllCUDAF = [&](const std::vector &broadcast_pos, - int max_dim, bool is_x_large) { - int axis = broadcast_pos[0]; - int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1, - std::multiplies()); - int mid = 1; - int post = 1; - - if (broadcast_pos.size() == 1) { - mid = out_dims_array[axis]; - post = - std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, - 1, std::multiplies()); - } else { - mid = std::accumulate(out_dims_array + axis, - out_dims_array + broadcast_pos.back() + 1, 1, - std::multiplies()); - post = - std::accumulate(out_dims_array + broadcast_pos.back() + 1, - out_dims_array + max_dim, 1, std::multiplies()); - } - - VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid - << " post:" << post; - - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - - pten::FastCommonGradBroadcastAllCUDAKernel<<>>( - x_data, y_data, out_data, dout_data, pre, mid, post, is_x_large, dx_op, - dy_op, dx_data, dy_data); - }; - - auto FastBroadCastOneCUDAF = [&](const std::vector &broadcast_pos, - int max_dim, bool is_x) { - int axis = broadcast_pos[0]; - int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1, - std::multiplies()); - int mid = out_dims_array[axis]; - int post = - std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1, - std::multiplies()); - - int k_pre; - int k_mid; - int k_post; - - if (is_x) { - k_pre = std::accumulate(y_dims_array, y_dims_array + axis, 1, - std::multiplies()); - k_mid = y_dims_array[axis]; - k_post = std::accumulate(y_dims_array + axis + 1, y_dims_array + max_dim, - 1, std::multiplies()); - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - // we need to calc y offset with blockid, so do x_pre/y_pre to get left - // size. - if (k_pre != pre) k_pre = pre / k_pre; - - pten::FastCommonGradBroadcastOneCUDAKernel<<>>( - x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid, - k_post, true, dx_op, dx_data); - } else { - k_pre = std::accumulate(x_dims_array, x_dims_array + axis, 1, - std::multiplies()); - k_mid = x_dims_array[axis]; - k_post = std::accumulate(x_dims_array + axis + 1, x_dims_array + max_dim, - 1, std::multiplies()); - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - if (k_pre != pre) k_pre = pre / k_pre; - - pten::FastCommonGradBroadcastOneCUDAKernel<<>>( - x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid, - k_post, false, dy_op, dy_data); - } - VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid - << " post:" << post; - }; - - // do fast elementwise if: 1. only one input need to do broadcast, we can - // fallback - // to old fast path. - // 2. if both x and y need broadcast, then do it one by one. - bool fast_broadcast = false; - if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { - can_split_y = pten::SplitDims(y_broadcast_pos, max_dim); - if (can_split_y) { - // only y need to do broadcast on h - if (y_broadcast_pos[0] == 0) { - FastBroadCastHeightCUDAF(y_broadcast_pos, true); - fast_broadcast = true; - } - } else if (y_broadcast_pos.size() == 1 || - pten::CheckContiguousDims( - y_broadcast_pos)) { // for only one dim and - // contiguous broadcast. - // If cannot split, which means input has 3 parts - FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true); - fast_broadcast = true; - } - } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) { - // only x need broadcast - can_split_x = pten::SplitDims(x_broadcast_pos, max_dim); - if (can_split_x) { - if (x_broadcast_pos[0] == 0) { - FastBroadCastHeightCUDAF(x_broadcast_pos, false); - fast_broadcast = true; - } - } else if (x_broadcast_pos.size() == 1 || - pten::CheckContiguousDims(x_broadcast_pos)) { - FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false); - fast_broadcast = true; - } - } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { - // do x and y broadcast each. - can_split_y = pten::SplitDims(y_broadcast_pos, max_dim); - bool fast_broadcast_x = false; - bool fast_broadcast_y = false; - if (can_split_y) { - // begin at start. - if (y_broadcast_pos[0] == 0) { - FastCommonCUDAF(y_broadcast_pos, true); - fast_broadcast_y = true; - } - } else if (y_broadcast_pos.size() == 1) { - FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false); - can_split_y = true; - fast_broadcast_y = true; - } - can_split_x = pten::SplitDims(x_broadcast_pos, max_dim); - if (can_split_x) { - if (x_broadcast_pos[0] == 0) { - FastCommonCUDAF(x_broadcast_pos, false); - fast_broadcast_x = true; - } - } else if (x_broadcast_pos.size() == 1) { - FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true); - can_split_x = true; - fast_broadcast_x = true; - } - VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y - << " can_split_x:" << can_split_x; - // if both x and y into fast path then return - if (fast_broadcast_x && fast_broadcast_y) { - fast_broadcast = true; - } - if (can_split_y && can_split_x && fast_broadcast) return; - } - - // Should remove memory copy, use reg instead. - if (fast_broadcast) { - return; - } - int x_blocks = 0; - int x_threads = 0; - pten::ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks, - &x_threads, max_dim); - int y_blocks = 0; - int y_threads = 0; - pten::ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks, - &y_threads, max_dim); - - auto x_strides_array_tmp = memory::Alloc(ctx, bytes); - int *x_strides_array_gpu = - reinterpret_cast(x_strides_array_tmp->ptr()); - memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(), - bytes, ctx.stream()); - - auto y_strides_array_tmp = memory::Alloc(ctx, bytes); - int *y_strides_array_gpu = - reinterpret_cast(y_strides_array_tmp->ptr()); - memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(), - bytes, ctx.stream()); - - auto out_dims_array_tmp = memory::Alloc(ctx, bytes); - int *out_dims_array_gpu = reinterpret_cast(out_dims_array_tmp->ptr()); - memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes, - ctx.stream()); - - const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim, - 1, std::multiplies()); - int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads); - int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads); - if (dx) { - auto x_strides_order_tmp = memory::Alloc(ctx, bytes); - int *x_strides_order_gpu = - reinterpret_cast(x_strides_order_tmp->ptr()); - memory::Copy(gplace, x_strides_order_gpu, cplace, x_strides_order.data(), - bytes, ctx.stream()); - - auto x_dims_order_tmp = memory::Alloc(ctx, bytes); - int *x_dims_order_gpu = reinterpret_cast(x_dims_order_tmp->ptr()); - memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes, - ctx.stream()); - CommonGradBroadcastCUDAKernel< - T, DX_OP, Tout><<>>( - x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, - x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data, - dout_data, dx_data, out_size, max_dim, x_threads, dx_op); - } - if (dy) { - auto y_strides_order_tmp = memory::Alloc(ctx, bytes); - int *y_strides_order_gpu = - reinterpret_cast(y_strides_order_tmp->ptr()); - memory::Copy(gplace, y_strides_order_gpu, cplace, y_strides_order.data(), - bytes, ctx.stream()); - - auto y_dims_order_tmp = memory::Alloc(ctx, bytes); - int *y_dims_order_gpu = reinterpret_cast(y_dims_order_tmp->ptr()); - memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes, - ctx.stream()); - CommonGradBroadcastCUDAKernel< - T, DY_OP, Tout><<>>( - x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, - y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data, - dout_data, dy_data, out_size, max_dim, y_threads, dy_op); - } -} - -#endif // __NVCC__ or __HIPCC__ - inline framework::DDim trim_trailing_singular_dims( const framework::DDim &dims) { return pten::funcs::trim_trailing_singular_dims(dims); } -template -void CommonElementwiseBroadcastBackward( - const framework::ExecutionContext &ctx, const framework::DDim &x_dims, - const framework::DDim &y_dims, const framework::Tensor &x, - const framework::Tensor &y, const framework::Tensor &out, - const framework::Tensor &dout, int axis, framework::Tensor *dx, - framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { - int max_dim = std::max(x_dims.size(), y_dims.size()); - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), - y_dims_array.data(), out_dims_array.data(), max_dim, - axis); - // for inplace strategy. memset will make dx and dout clear and get wrong - // result. - if (dx && dx->IsSharedBufferWith(dout)) { - dx->clear(); - dx->mutable_data(x_dims, ctx.GetPlace()); - } - - VLOG(3) << "CommonElementwiseBroadcastBackward xdims:" - << framework::make_ddim(x_dims_array) - << " ydim:" << framework::make_ddim(y_dims_array); - - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - CommonGradBroadcastCUDA( - x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(), - out_dims_array.data(), max_dim, - ctx.template device_context(), dx_op, - dy_op); -#endif - } else { - pten::CommonGradBroadcastCPU( - x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(), - out_dims_array.data(), max_dim, - ctx.template device_context(), dx_op, - dy_op); - } -} - -template -void ElemwiseGradComputeWithBroadcast( - const framework::ExecutionContext &ctx, const framework::DDim &x_dims, - const framework::DDim &y_dims, const framework::Tensor &x, - const framework::Tensor &y, const framework::Tensor &out, - const framework::Tensor &dout, int axis, framework::Tensor *dx, - framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { - bool is_xsize_larger = true; - - int max_dim = x_dims.size(); - if (x_dims.size() < y_dims.size()) { - is_xsize_larger = false; - max_dim = y_dims.size(); - } - - axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); - PADDLE_ENFORCE_GE( - axis, 0, - platform::errors::InvalidArgument( - "Axis should be great than or equal to 0, but received axis is %d.", - axis)); - PADDLE_ENFORCE_LT(axis, max_dim, - platform::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, axis)); - - int pre, n, post, is_run_common_broadcast, axis_trim = 0; - if (is_xsize_larger) { - auto y_dims_trimed = trim_trailing_singular_dims(y_dims); - axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; - pten::funcs::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post, - &is_run_common_broadcast); - } else { - auto x_dims_trimed = trim_trailing_singular_dims(x_dims); - axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; - pten::funcs::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, - &is_run_common_broadcast); - } - // special case for common backward implementation. - if (is_run_common_broadcast) { - CommonElementwiseBroadcastBackward( - ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); - return; - } - if (post == 1) { - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - pten::ElemwiseGradBroadcast1CUDA( - ctx.template device_context().stream(), x.data(), - y.data(), out.data(), dout.data(), pre, n, - is_xsize_larger, dx_op, dy_op, - dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); -#endif - } else { - pten::ElemwiseGradBroadcast1CPU( - x.data(), y.data(), out.data(), dout.data(), pre, n, - is_xsize_larger, dx_op, dy_op, - dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); - } - } else { - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - pten::ElemwiseGradBroadcast2CUDA( - ctx.template device_context().stream(), x.data(), - y.data(), out.data(), dout.data(), pre, n, post, - is_xsize_larger, dx_op, dy_op, - dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); -#endif - } else { - pten::ElemwiseGradBroadcast2CPU( - x.data(), y.data(), out.data(), dout.data(), pre, n, - post, is_xsize_larger, dx_op, dy_op, - dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); - } - } -} - -template -void CommonElementwiseBroadcastForward( - const framework::ExecutionContext &ctx, const framework::Tensor *x, - const framework::Tensor *y, framework::Tensor *z, - const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func, - int axis, const bool is_xsize_larger = true) { - z->mutable_data(ctx.GetPlace()); - auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); - auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); - auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - const auto &dev_ctx = ctx.template device_context(); - pten::CommonElementwiseBroadcastForward(dev_ctx, *pt_x.get(), *pt_y.get(), - pt_z.get(), x_dims, y_dims, func, - axis, is_xsize_larger); -} - template void ElemwiseGradCompute(const framework::ExecutionContext &ctx, @@ -755,14 +147,14 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx, DX_OP dx_op, DY_OP dy_op) { const framework::DDim &x_dim = x.dims(); const framework::DDim &y_dim = y.dims(); + const auto &dev_ctx = ctx.template device_context(); if (x.dims() == y.dims()) { - const auto &dev_ctx = ctx.template device_context(); pten::funcs::ElemwiseGradComputeNoBroadcast( dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } else { - ElemwiseGradComputeWithBroadcast( - ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + pten::ElemwiseGradComputeWithBroadcast( + dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } } @@ -780,14 +172,15 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx, DX_OP dx_op, DY_OP dy_op) { const framework::DDim &x_dim = x.dims(); const framework::DDim &y_dim = y.dims(); + const auto &dev_ctx = ctx.template device_context(); if (x.dims() == y.dims()) { - const auto &dev_ctx = ctx.template device_context(); pten::funcs::ElemwiseGradComputeNoBroadcast( dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op); } else { - ElemwiseGradComputeWithBroadcast( - ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op); + pten::ElemwiseGradComputeWithBroadcast( + dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, + dy_op); } } diff --git a/paddle/pten/kernels/cpu/elementwise.h b/paddle/pten/kernels/cpu/elementwise.h index 97db997a164..b448586754d 100644 --- a/paddle/pten/kernels/cpu/elementwise.h +++ b/paddle/pten/kernels/cpu/elementwise.h @@ -549,4 +549,148 @@ static void ElemwiseGradBroadcast2CPU(const T* x, } } +template +void CommonElementwiseBroadcastBackward(const CPUContext& ctx, + const DDim& x_dims, + const DDim& y_dims, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DX_OP dx_op, + DY_OP dy_op) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + funcs::GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + // for inplace strategy. memset will make dx and dout clear and get wrong + // result. + if (dx && dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x_dims, ctx.GetPlace()); + } + + VLOG(3) << "CommonElementwiseBroadcastBackward xdims:" + << paddle::framework::make_ddim(x_dims_array) + << " ydim:" << paddle::framework::make_ddim(y_dims_array); + + CommonGradBroadcastCPU(x, + y, + out, + dout, + dx, + dy, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + ctx, + dx_op, + dy_op); +} + +template +void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx, + const DDim& x_dims, + const DDim& y_dims, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DX_OP dx_op, + DY_OP dy_op) { + bool is_xsize_larger = true; + + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, + 0, + paddle::platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + max_dim, + paddle::platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, + axis)); + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + funcs::get_mid_dims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + funcs::get_mid_dims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common backward implementation. + if (is_run_common_broadcast) { + CommonElementwiseBroadcastBackward( + ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + return; + } + if (post == 1) { + ElemwiseGradBroadcast1CPU( + x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } else { + ElemwiseGradBroadcast2CPU( + x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + post, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } +} + } // namespace pten diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h index 4dfcd7a2152..5abc40c75d1 100644 --- a/paddle/pten/kernels/gpu/elementwise.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -18,7 +18,10 @@ limitations under the License. */ #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/function_traits.h" +#include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/funcs/cuda_kernel_config.h" +#include "paddle/pten/kernels/funcs/elementwise_base.h" #ifdef __HIPCC__ constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; @@ -28,6 +31,13 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #define BLOCK_X 32 #define BLOCK_Y 32 +#define GetDivMod(dividend, divisor, div, mod) \ + do { \ + const auto dividend_copy = dividend; \ + *div = dividend_copy / divisor; \ + *mod = dividend_copy % divisor; \ + } while (0) + namespace pten { namespace kps = paddle::operators::kernel_primitives; @@ -1469,4 +1479,762 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy); } +template +__global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array, + const int *y_strides_array, + const int *out_dims_array, + const int *y_strides_order, + const int *y_dims_order, + const T *x, + const T *y, + const Tout *out, + const Tout *dout, + T *dx, + int out_size, + int max_dim, + int thread_num, + DX_OP dx_op) { + T val(0); + int i = blockIdx.x; + int tid = threadIdx.x; + for (int j = tid; j < thread_num; j += blockDim.x) { + const int X_index = i * thread_num + j; + int out_index = X_index; + int C_index = 0; + int B_index = i * thread_num + j; + int remainder = 0; +#pragma unroll + for (int d = max_dim - 1; d >= 0; --d) { + GetDivMod(B_index, y_dims_order[d], &B_index, &remainder); + C_index += remainder * y_strides_order[d]; + } + int x_index = 0; + int y_index = 0; + int C_index_val = C_index; +#pragma unroll + for (int d = max_dim - 1; d >= 0; --d) { + GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder); + x_index += remainder * x_strides_array[d]; + y_index += remainder * y_strides_array[d]; + } + out_index = C_index; + val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]); + } + val = paddle::platform::reduceSum(val, tid, thread_num); + if (threadIdx.x == 0) { + dx[i] = val; + } +} + +template +void CommonGradBroadcastCUDA(const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + int max_dim, + const GPUContext &ctx, + DX_OP dx_op, + DY_OP dy_op) { + const auto gplace = + BOOST_GET_CONST(paddle::platform::CUDAPlace, ctx.GetPlace()); + auto cplace = paddle::platform::CPUPlace(); + const T *x_data = x.data(); + const T *y_data = y.data(); + const Tout *out_data = out.data(); + const Tout *dout_data = dout.data(); + T *dx_data = dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()); + T *dy_data = dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()); + + std::vector x_one_indexs; + std::vector y_one_indexs; + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] != y_dims_array[i]) { + if (x_dims_array[i] == 1) { + x_one_indexs.push_back(i); + } + if (y_dims_array[i] == 1) { + y_one_indexs.push_back(i); + } + } + } + + std::vector x_trans_indexs(max_dim); + std::vector y_trans_indexs(max_dim); + ComputeBroadcastTranspositionArray( + x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size()); + ComputeBroadcastTranspositionArray( + y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size()); + + // compute array stride for cuda kernel; + // e.g. x.dims=[2,3,4], x_stride=[12,4,1] + std::vector x_strides_array(max_dim); + std::vector y_strides_array(max_dim); + std::vector out_strides_array(max_dim); + int x_stride = 1; + int y_stride = 1; + int z_stride = 1; + for (int i = max_dim - 1; i >= 0; i--) { + x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride; + y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride; + out_strides_array[i] = z_stride; + x_stride *= x_dims_array[i]; + y_stride *= y_dims_array[i]; + z_stride *= out_dims_array[i]; + } + + std::vector x_strides_order(max_dim); + std::vector y_strides_order(max_dim); + std::vector x_dims_order(max_dim); + std::vector y_dims_order(max_dim); + for (int i = 0; i < max_dim; ++i) { + x_strides_order[i] = out_strides_array[x_trans_indexs[i]]; + y_strides_order[i] = out_strides_array[y_trans_indexs[i]]; + x_dims_order[i] = out_dims_array[x_trans_indexs[i]]; + y_dims_order[i] = out_dims_array[y_trans_indexs[i]]; + } + std::vector x_broadcast_pos; + std::vector y_broadcast_pos; + + int bytes = max_dim * sizeof(int); + + for (int i = 0; i < max_dim; ++i) { + if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) { + x_broadcast_pos.emplace_back(i); + } + if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) { + y_broadcast_pos.emplace_back(i); + } + } + + auto stream = ctx.stream(); + bool can_split_x = false; + bool can_split_y = false; + + auto FastCommonCUDAF = [&](const std::vector &broadcast_pos, bool is_y) { + int h = std::accumulate(out_dims_array, + out_dims_array + broadcast_pos.size(), + 1, + std::multiplies()); + int w = std::accumulate(out_dims_array + broadcast_pos.size(), + out_dims_array + max_dim, + 1, + std::multiplies()); + + VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h + << " is_y:" << is_y; + + int split_h; + int split_w; + int kh = h; + int kw = w; + + if (is_y) { + split_h = std::accumulate(x_dims_array, + x_dims_array + broadcast_pos.size(), + 1, + std::multiplies()); + split_w = std::accumulate(x_dims_array + broadcast_pos.size(), + x_dims_array + max_dim, + 1, + std::multiplies()); + + } else { + split_h = std::accumulate(y_dims_array, + y_dims_array + broadcast_pos.size(), + 1, + std::multiplies()); + split_w = std::accumulate(y_dims_array + broadcast_pos.size(), + y_dims_array + max_dim, + 1, + std::multiplies()); + } + + if (h > split_h) kh = split_h; + if (w > split_w) kw = split_w; + + if (is_y) { + if (w < 16 || h < 16) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int grid_size = w; + CommonGradBroadcast1CUDAKernelHeight<<>>(x_data, + y_data, + out_data, + dout_data, + h, + w, + dy_op, + dy_data, + kh, + kw, + is_y); + } else { + dim3 block_size = dim3(BLOCK_X, BLOCK_Y); + int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + FastCommonGradBroadcastCUDAKernelHeight<<>>(x_data, + y_data, + out_data, + dout_data, + h, + w, + dy_op, + dy_data, + kh, + kw, + is_y); + } + } else { + if (w < 16 || h < 16) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int grid_size = w; + CommonGradBroadcast1CUDAKernelHeight<<>>(x_data, + y_data, + out_data, + dout_data, + h, + w, + dx_op, + dx_data, + kh, + kw, + is_y); + } else { + dim3 block_size = dim3(BLOCK_X, BLOCK_Y); + int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + FastCommonGradBroadcastCUDAKernelHeight<<>>(x_data, + y_data, + out_data, + dout_data, + h, + w, + dx_op, + dx_data, + kh, + kw, + is_y); + } + } + }; + + auto FastBroadCastHeightCUDAF = [&](const std::vector &broadcast_pos, + bool x_large) { + int h = std::accumulate(out_dims_array, + out_dims_array + broadcast_pos.size(), + 1, + std::multiplies()); + int w = std::accumulate(out_dims_array + broadcast_pos.size(), + out_dims_array + max_dim, + 1, + std::multiplies()); + + VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h; + + if (w < 16 || h < 16) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int grid_size = w; + ElemwiseGradBroadcast1CUDAKernel<<>>( + x_data, + y_data, + out_data, + dout_data, + h, + w, + x_large, + dx_op, + dy_op, + dx_data, + dy_data); + } else { + dim3 block_size = dim3(BLOCK_X, BLOCK_Y); + int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + FastElemwiseGradBroadcast1CUDAKernel<<>>(x_data, + y_data, + out_data, + dout_data, + h, + w, + x_large, + dx_op, + dy_op, + dx_data, + dy_data); + } + }; + + auto FastBroadCastAllCUDAF = [&]( + const std::vector &broadcast_pos, int max_dim, bool is_x_large) { + int axis = broadcast_pos[0]; + int pre = std::accumulate( + out_dims_array, out_dims_array + axis, 1, std::multiplies()); + int mid = 1; + int post = 1; + + if (broadcast_pos.size() == 1) { + mid = out_dims_array[axis]; + post = std::accumulate(out_dims_array + axis + 1, + out_dims_array + max_dim, + 1, + std::multiplies()); + } else { + mid = std::accumulate(out_dims_array + axis, + out_dims_array + broadcast_pos.back() + 1, + 1, + std::multiplies()); + post = std::accumulate(out_dims_array + broadcast_pos.back() + 1, + out_dims_array + max_dim, + 1, + std::multiplies()); + } + + VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid + << " post:" << post; + + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); + int grid_size = pre * post; + + FastCommonGradBroadcastAllCUDAKernel<<>>( + x_data, + y_data, + out_data, + dout_data, + pre, + mid, + post, + is_x_large, + dx_op, + dy_op, + dx_data, + dy_data); + }; + + auto FastBroadCastOneCUDAF = [&]( + const std::vector &broadcast_pos, int max_dim, bool is_x) { + int axis = broadcast_pos[0]; + int pre = std::accumulate( + out_dims_array, out_dims_array + axis, 1, std::multiplies()); + int mid = out_dims_array[axis]; + int post = std::accumulate(out_dims_array + axis + 1, + out_dims_array + max_dim, + 1, + std::multiplies()); + + int k_pre; + int k_mid; + int k_post; + + if (is_x) { + k_pre = std::accumulate( + y_dims_array, y_dims_array + axis, 1, std::multiplies()); + k_mid = y_dims_array[axis]; + k_post = std::accumulate(y_dims_array + axis + 1, + y_dims_array + max_dim, + 1, + std::multiplies()); + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); + int grid_size = pre * post; + // we need to calc y offset with blockid, so do x_pre/y_pre to get left + // size. + if (k_pre != pre) k_pre = pre / k_pre; + + FastCommonGradBroadcastOneCUDAKernel<<>>(x_data, + y_data, + out_data, + dout_data, + pre, + mid, + post, + k_pre, + k_mid, + k_post, + true, + dx_op, + dx_data); + } else { + k_pre = std::accumulate( + x_dims_array, x_dims_array + axis, 1, std::multiplies()); + k_mid = x_dims_array[axis]; + k_post = std::accumulate(x_dims_array + axis + 1, + x_dims_array + max_dim, + 1, + std::multiplies()); + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); + int grid_size = pre * post; + if (k_pre != pre) k_pre = pre / k_pre; + + FastCommonGradBroadcastOneCUDAKernel<<>>(x_data, + y_data, + out_data, + dout_data, + pre, + mid, + post, + k_pre, + k_mid, + k_post, + false, + dy_op, + dy_data); + } + VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid + << " post:" << post; + }; + + // do fast elementwise if: 1. only one input need to do broadcast, we can + // fallback + // to old fast path. + // 2. if both x and y need broadcast, then do it one by one. + bool fast_broadcast = false; + if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { + can_split_y = SplitDims(y_broadcast_pos, max_dim); + if (can_split_y) { + // only y need to do broadcast on h + if (y_broadcast_pos[0] == 0) { + FastBroadCastHeightCUDAF(y_broadcast_pos, true); + fast_broadcast = true; + } + } else if (y_broadcast_pos.size() == 1 || + CheckContiguousDims(y_broadcast_pos)) { // for only one dim and + // contiguous broadcast. + // If cannot split, which means input has 3 parts + FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true); + fast_broadcast = true; + } + } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) { + // only x need broadcast + can_split_x = SplitDims(x_broadcast_pos, max_dim); + if (can_split_x) { + if (x_broadcast_pos[0] == 0) { + FastBroadCastHeightCUDAF(x_broadcast_pos, false); + fast_broadcast = true; + } + } else if (x_broadcast_pos.size() == 1 || + CheckContiguousDims(x_broadcast_pos)) { + FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false); + fast_broadcast = true; + } + } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { + // do x and y broadcast each. + can_split_y = SplitDims(y_broadcast_pos, max_dim); + bool fast_broadcast_x = false; + bool fast_broadcast_y = false; + if (can_split_y) { + // begin at start. + if (y_broadcast_pos[0] == 0) { + FastCommonCUDAF(y_broadcast_pos, true); + fast_broadcast_y = true; + } + } else if (y_broadcast_pos.size() == 1) { + FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false); + can_split_y = true; + fast_broadcast_y = true; + } + can_split_x = SplitDims(x_broadcast_pos, max_dim); + if (can_split_x) { + if (x_broadcast_pos[0] == 0) { + FastCommonCUDAF(x_broadcast_pos, false); + fast_broadcast_x = true; + } + } else if (x_broadcast_pos.size() == 1) { + FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true); + can_split_x = true; + fast_broadcast_x = true; + } + VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y + << " can_split_x:" << can_split_x; + // if both x and y into fast path then return + if (fast_broadcast_x && fast_broadcast_y) { + fast_broadcast = true; + } + if (can_split_y && can_split_x && fast_broadcast) return; + } + + // Should remove memory copy, use reg instead. + if (fast_broadcast) { + return; + } + int x_blocks = 0; + int x_threads = 0; + ComputeBroadcastKernelSize( + x_dims_array, out_dims_array, &x_blocks, &x_threads, max_dim); + int y_blocks = 0; + int y_threads = 0; + ComputeBroadcastKernelSize( + y_dims_array, out_dims_array, &y_blocks, &y_threads, max_dim); + + auto x_strides_array_tmp = paddle::memory::Alloc(ctx, bytes); + int *x_strides_array_gpu = + reinterpret_cast(x_strides_array_tmp->ptr()); + paddle::memory::Copy(gplace, + x_strides_array_gpu, + cplace, + x_strides_array.data(), + bytes, + ctx.stream()); + + auto y_strides_array_tmp = paddle::memory::Alloc(ctx, bytes); + int *y_strides_array_gpu = + reinterpret_cast(y_strides_array_tmp->ptr()); + paddle::memory::Copy(gplace, + y_strides_array_gpu, + cplace, + y_strides_array.data(), + bytes, + ctx.stream()); + + auto out_dims_array_tmp = paddle::memory::Alloc(ctx, bytes); + int *out_dims_array_gpu = reinterpret_cast(out_dims_array_tmp->ptr()); + paddle::memory::Copy( + gplace, out_dims_array_gpu, cplace, out_dims_array, bytes, ctx.stream()); + + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads); + int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads); + if (dx) { + auto x_strides_order_tmp = paddle::memory::Alloc(ctx, bytes); + int *x_strides_order_gpu = + reinterpret_cast(x_strides_order_tmp->ptr()); + paddle::memory::Copy(gplace, + x_strides_order_gpu, + cplace, + x_strides_order.data(), + bytes, + ctx.stream()); + + auto x_dims_order_tmp = paddle::memory::Alloc(ctx, bytes); + int *x_dims_order_gpu = reinterpret_cast(x_dims_order_tmp->ptr()); + paddle::memory::Copy(gplace, + x_dims_order_gpu, + cplace, + x_dims_order.data(), + bytes, + ctx.stream()); + CommonGradBroadcastCUDAKernel< + T, + DX_OP, + Tout><<>>(x_strides_array_gpu, + y_strides_array_gpu, + out_dims_array_gpu, + x_strides_order_gpu, + x_dims_order_gpu, + x_data, + y_data, + out_data, + dout_data, + dx_data, + out_size, + max_dim, + x_threads, + dx_op); + } + if (dy) { + auto y_strides_order_tmp = paddle::memory::Alloc(ctx, bytes); + int *y_strides_order_gpu = + reinterpret_cast(y_strides_order_tmp->ptr()); + paddle::memory::Copy(gplace, + y_strides_order_gpu, + cplace, + y_strides_order.data(), + bytes, + ctx.stream()); + + auto y_dims_order_tmp = paddle::memory::Alloc(ctx, bytes); + int *y_dims_order_gpu = reinterpret_cast(y_dims_order_tmp->ptr()); + paddle::memory::Copy(gplace, + y_dims_order_gpu, + cplace, + y_dims_order.data(), + bytes, + ctx.stream()); + CommonGradBroadcastCUDAKernel< + T, + DY_OP, + Tout><<>>(x_strides_array_gpu, + y_strides_array_gpu, + out_dims_array_gpu, + y_strides_order_gpu, + y_dims_order_gpu, + x_data, + y_data, + out_data, + dout_data, + dy_data, + out_size, + max_dim, + y_threads, + dy_op); + } +} + +template +void CommonElementwiseBroadcastBackward(const GPUContext &ctx, + const DDim &x_dims, + const DDim &y_dims, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + funcs::GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + // for inplace strategy. memset will make dx and dout clear and get wrong + // result. + if (dx && dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x_dims, ctx.GetPlace()); + } + + VLOG(3) << "CommonElementwiseBroadcastBackward xdims:" + << paddle::framework::make_ddim(x_dims_array) + << " ydim:" << paddle::framework::make_ddim(y_dims_array); + + CommonGradBroadcastCUDA(x, + y, + out, + dout, + dx, + dy, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + ctx, + dx_op, + dy_op); +} + +template +void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, + const DDim &x_dims, + const DDim &y_dims, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + bool is_xsize_larger = true; + + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, + 0, + paddle::platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + max_dim, + paddle::platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, + axis)); + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + funcs::get_mid_dims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + funcs::get_mid_dims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common backward implementation. + if (is_run_common_broadcast) { + CommonElementwiseBroadcastBackward( + ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + return; + } + if (post == 1) { + ElemwiseGradBroadcast1CUDA( + ctx.stream(), + x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } else { + ElemwiseGradBroadcast2CUDA( + ctx.stream(), + x.data(), + y.data(), + out.data(), + dout.data(), + pre, + n, + post, + is_xsize_larger, + dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } +} + } // namespace pten -- GitLab