Unverified commit 556d5097, authored by Y YuanRisheng, committed by GitHub

refactor impl of elementwise op part2 (#38898)

Parent 7f8d5bc8
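In outline, this change moves the elementwise gradient broadcast helpers into the pten namespace, while the fluid-side entry points either become thin wrappers that forward to the pten implementations (e.g. ElemwiseGradCompute now calls pten::ElemwiseGradComputeWithBroadcast) or are removed. Below is a minimal, self-contained sketch of that forwarding pattern, using simplified placeholder names rather than the real Paddle signatures:

// Sketch only: placeholder names and signatures, not the actual Paddle code.
#include <iostream>

namespace pten {
// The real implementation now lives under the pten namespace.
template <typename T, typename Functor>
void ElementwiseImpl(const T* x, const T* y, T* z, int n, Functor func) {
  for (int i = 0; i < n; ++i) z[i] = func(x[i], y[i]);
}
}  // namespace pten

namespace paddle {
namespace operators {
// The old fluid entry point keeps its signature but only forwards to pten.
template <typename T, typename Functor>
void ElementwiseCompute(const T* x, const T* y, T* z, int n, Functor func) {
  pten::ElementwiseImpl<T, Functor>(x, y, z, n, func);
}
}  // namespace operators
}  // namespace paddle

int main() {
  float x[3] = {1.f, 2.f, 3.f};
  float y[3] = {4.f, 5.f, 6.f};
  float z[3] = {0.f, 0.f, 0.f};
  paddle::operators::ElementwiseCompute<float>(
      x, y, z, 3, [](float a, float b) { return a + b; });
  std::cout << z[0] << " " << z[1] << " " << z[2] << std::endl;  // 5 7 9
  return 0;
}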
......@@ -49,12 +49,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
#define GetDivMod(dividend, divisor, div, mod) \
do { \
const auto dividend_copy = dividend; \
*div = dividend_copy / divisor; \
*mod = dividend_copy % divisor; \
} while (0)
#define DIVUP(x, y) (((x) + (y)-1) / (y))
......@@ -138,613 +132,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
axis);
}
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCPU(const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
int *x_dims_array, int *y_dims_array,
int *out_dims_array, int max_dim,
const platform::CPUDeviceContext &ctx,
Functor func,
const bool is_xsize_larger = true) {
pten::CommonForwardBroadcastCPU(x, y, z, x_dims_array, y_dims_array,
out_dims_array, max_dim, ctx, func,
is_xsize_larger);
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T, typename DX_OP, typename Tout = T>
__global__ void CommonGradBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const int *y_strides_order,
const int *y_dims_order, const T *x, const T *y, const Tout *out,
const Tout *dout, T *dx, int out_size, int max_dim, int thread_num,
DX_OP dx_op) {
T val(0);
int i = blockIdx.x;
int tid = threadIdx.x;
for (int j = tid; j < thread_num; j += blockDim.x) {
const int X_index = i * thread_num + j;
int out_index = X_index;
int C_index = 0;
int B_index = i * thread_num + j;
int remainder = 0;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(B_index, y_dims_order[d], &B_index, &remainder);
C_index += remainder * y_strides_order[d];
}
int x_index = 0;
int y_index = 0;
int C_index_val = C_index;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder);
x_index += remainder * x_strides_array[d];
y_index += remainder * y_strides_array[d];
}
out_index = C_index;
val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
}
val = paddle::platform::reduceSum(val, tid, thread_num);
if (threadIdx.x == 0) {
dx[i] = val;
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCUDA(
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out, const framework::Tensor &dout,
framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
int *y_dims_array, int *out_dims_array, int max_dim,
const platform::CUDADeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto cplace = platform::CPUPlace();
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
std::vector<int> x_one_indexs;
std::vector<int> y_one_indexs;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] != y_dims_array[i]) {
if (x_dims_array[i] == 1) {
x_one_indexs.push_back(i);
}
if (y_dims_array[i] == 1) {
y_one_indexs.push_back(i);
}
}
}
std::vector<int> x_trans_indexs(max_dim);
std::vector<int> y_trans_indexs(max_dim);
pten::ComputeBroadcastTranspositionArray(
x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size());
pten::ComputeBroadcastTranspositionArray(
y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size());
// compute array stride for cuda kernel;
// e.g. x.dims=[2,3,4], x_stride=[12,4,1]
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
std::vector<int> out_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
int z_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
out_strides_array[i] = z_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
z_stride *= out_dims_array[i];
}
std::vector<int> x_strides_order(max_dim);
std::vector<int> y_strides_order(max_dim);
std::vector<int> x_dims_order(max_dim);
std::vector<int> y_dims_order(max_dim);
for (int i = 0; i < max_dim; ++i) {
x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
}
std::vector<int> x_broadcast_pos;
std::vector<int> y_broadcast_pos;
int bytes = max_dim * sizeof(int);
for (int i = 0; i < max_dim; ++i) {
if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) {
x_broadcast_pos.emplace_back(i);
}
if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) {
y_broadcast_pos.emplace_back(i);
}
}
auto stream = ctx.stream();
bool can_split_x = false;
bool can_split_y = false;
auto FastCommonCUDAF = [&](const std::vector<int> &broadcast_pos, bool is_y) {
int h =
std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
1, std::multiplies<int>());
int w =
std::accumulate(out_dims_array + broadcast_pos.size(),
out_dims_array + max_dim, 1, std::multiplies<int>());
VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h
<< " is_y:" << is_y;
int split_h;
int split_w;
int kh = h;
int kw = w;
if (is_y) {
split_h =
std::accumulate(x_dims_array, x_dims_array + broadcast_pos.size(), 1,
std::multiplies<int>());
split_w =
std::accumulate(x_dims_array + broadcast_pos.size(),
x_dims_array + max_dim, 1, std::multiplies<int>());
} else {
split_h =
std::accumulate(y_dims_array, y_dims_array + broadcast_pos.size(), 1,
std::multiplies<int>());
split_w =
std::accumulate(y_dims_array + broadcast_pos.size(),
y_dims_array + max_dim, 1, std::multiplies<int>());
}
if (h > split_h) kh = split_h;
if (w > split_w) kw = split_w;
if (is_y) {
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
pten::CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
pten::FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size,
0, stream>>>(
x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
is_y);
}
} else {
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
pten::CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
pten::FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size,
0, stream>>>(
x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
is_y);
}
}
};
auto FastBroadCastHeightCUDAF = [&](const std::vector<int> &broadcast_pos,
bool x_large) {
int h =
std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
1, std::multiplies<int>());
int w =
std::accumulate(out_dims_array + broadcast_pos.size(),
out_dims_array + max_dim, 1, std::multiplies<int>());
VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h;
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
pten::ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
dx_data, dy_data);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
pten::FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
dx_data, dy_data);
}
};
auto FastBroadCastAllCUDAF = [&](const std::vector<int> &broadcast_pos,
int max_dim, bool is_x_large) {
int axis = broadcast_pos[0];
int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
std::multiplies<int>());
int mid = 1;
int post = 1;
if (broadcast_pos.size() == 1) {
mid = out_dims_array[axis];
post =
std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim,
1, std::multiplies<int>());
} else {
mid = std::accumulate(out_dims_array + axis,
out_dims_array + broadcast_pos.back() + 1, 1,
std::multiplies<int>());
post =
std::accumulate(out_dims_array + broadcast_pos.back() + 1,
out_dims_array + max_dim, 1, std::multiplies<int>());
}
VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
<< " post:" << post;
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
pten::FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, pre, mid, post, is_x_large, dx_op,
dy_op, dx_data, dy_data);
};
auto FastBroadCastOneCUDAF = [&](const std::vector<int> &broadcast_pos,
int max_dim, bool is_x) {
int axis = broadcast_pos[0];
int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
std::multiplies<int>());
int mid = out_dims_array[axis];
int post =
std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1,
std::multiplies<int>());
int k_pre;
int k_mid;
int k_post;
if (is_x) {
k_pre = std::accumulate(y_dims_array, y_dims_array + axis, 1,
std::multiplies<int>());
k_mid = y_dims_array[axis];
k_post = std::accumulate(y_dims_array + axis + 1, y_dims_array + max_dim,
1, std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
// We need the y offset for each block id, so divide the x-side pre by the
// y-side pre to get the left-over size.
if (k_pre != pre) k_pre = pre / k_pre;
pten::FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
k_post, true, dx_op, dx_data);
} else {
k_pre = std::accumulate(x_dims_array, x_dims_array + axis, 1,
std::multiplies<int>());
k_mid = x_dims_array[axis];
k_post = std::accumulate(x_dims_array + axis + 1, x_dims_array + max_dim,
1, std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
if (k_pre != pre) k_pre = pre / k_pre;
pten::FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
stream>>>(
x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
k_post, false, dy_op, dy_data);
}
VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid
<< " post:" << post;
};
// Take the fast elementwise path if:
// 1. only one input needs broadcasting, so we can fall back to the old fast
//    path;
// 2. both x and y need broadcasting, in which case handle them one by one.
bool fast_broadcast = false;
if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
can_split_y = pten::SplitDims(y_broadcast_pos, max_dim);
if (can_split_y) {
// only y needs to broadcast along h
if (y_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(y_broadcast_pos, true);
fast_broadcast = true;
}
} else if (y_broadcast_pos.size() == 1 ||
pten::CheckContiguousDims(
y_broadcast_pos)) {  // only one broadcast dim, or a
// contiguous run of broadcast dims.
// If the dims cannot be split, the input decomposes into 3 parts.
FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
fast_broadcast = true;
}
} else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
// only x need broadcast
can_split_x = pten::SplitDims(x_broadcast_pos, max_dim);
if (can_split_x) {
if (x_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(x_broadcast_pos, false);
fast_broadcast = true;
}
} else if (x_broadcast_pos.size() == 1 ||
pten::CheckContiguousDims(x_broadcast_pos)) {
FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
fast_broadcast = true;
}
} else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
// do x and y broadcast each.
can_split_y = pten::SplitDims(y_broadcast_pos, max_dim);
bool fast_broadcast_x = false;
bool fast_broadcast_y = false;
if (can_split_y) {
// the broadcast dims begin at dim 0.
if (y_broadcast_pos[0] == 0) {
FastCommonCUDAF(y_broadcast_pos, true);
fast_broadcast_y = true;
}
} else if (y_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
can_split_y = true;
fast_broadcast_y = true;
}
can_split_x = pten::SplitDims(x_broadcast_pos, max_dim);
if (can_split_x) {
if (x_broadcast_pos[0] == 0) {
FastCommonCUDAF(x_broadcast_pos, false);
fast_broadcast_x = true;
}
} else if (x_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
can_split_x = true;
fast_broadcast_x = true;
}
VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
<< " can_split_x:" << can_split_x;
// if both x and y took the fast path then return
if (fast_broadcast_x && fast_broadcast_y) {
fast_broadcast = true;
}
if (can_split_y && can_split_x && fast_broadcast) return;
}
// Should remove memory copy, use reg instead.
if (fast_broadcast) {
return;
}
int x_blocks = 0;
int x_threads = 0;
pten::ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
&x_threads, max_dim);
int y_blocks = 0;
int y_threads = 0;
pten::ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks,
&y_threads, max_dim);
auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
bytes, ctx.stream());
auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
bytes, ctx.stream());
auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
ctx.stream());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
if (dx) {
auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
int *x_strides_order_gpu =
reinterpret_cast<int *>(x_strides_order_tmp->ptr());
memory::Copy(gplace, x_strides_order_gpu, cplace, x_strides_order.data(),
bytes, ctx.stream());
auto x_dims_order_tmp = memory::Alloc(ctx, bytes);
int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DX_OP, Tout><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
}
if (dy) {
auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
int *y_strides_order_gpu =
reinterpret_cast<int *>(y_strides_order_tmp->ptr());
memory::Copy(gplace, y_strides_order_gpu, cplace, y_strides_order.data(),
bytes, ctx.stream());
auto y_dims_order_tmp = memory::Alloc(ctx, bytes);
int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DY_OP, Tout><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
}
}
#endif // __NVCC__ or __HIPCC__
inline framework::DDim trim_trailing_singular_dims(
const framework::DDim &dims) {
return pten::funcs::trim_trailing_singular_dims(dims);
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void CommonElementwiseBroadcastBackward(
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
// For the inplace strategy: when dx shares its buffer with dout, a memset on
// dx would also clear dout and give a wrong result, so reallocate dx first.
if (dx && dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
<< framework::make_ddim(x_dims_array)
<< " ydim:" << framework::make_ddim(y_dims_array);
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
CommonGradBroadcastCUDA<T, DX_OP, DY_OP, Tout>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
dy_op);
#endif
} else {
pten::CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
dy_op);
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim, axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
pten::funcs::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
pten::funcs::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP, Tout>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
if (post == 1) {
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
pten::ElemwiseGradBroadcast1CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
pten::ElemwiseGradBroadcast1CPU(
x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
pten::ElemwiseGradBroadcast2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n, post,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
pten::ElemwiseGradBroadcast2CPU(
x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
post, is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
}
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void CommonElementwiseBroadcastForward(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
int axis, const bool is_xsize_larger = true) {
z->mutable_data<OutType>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
pten::CommonElementwiseBroadcastForward(dev_ctx, *pt_x.get(), *pt_y.get(),
pt_z.get(), x_dims, y_dims, func,
axis, is_xsize_larger);
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
......@@ -755,14 +147,14 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
DX_OP dx_op, DY_OP dy_op) {
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
if (x.dims() == y.dims()) {
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
pten::funcs::ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP,
Tout>(
dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
pten::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP, Tout>(
dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
}
......@@ -780,14 +172,15 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
DX_OP dx_op, DY_OP dy_op) {
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
if (x.dims() == y.dims()) {
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
pten::funcs::ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op,
dy_op);
} else {
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
pten::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP>(
dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op,
dy_op);
}
}
......
......@@ -549,4 +549,148 @@ static void ElemwiseGradBroadcast2CPU(const T* x,
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonElementwiseBroadcastBackward(const CPUContext& ctx,
const DDim& x_dims,
const DDim& y_dims,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
// For the inplace strategy: when dx shares its buffer with dout, a memset on
// dx would also clear dout and give a wrong result, so reallocate dx first.
if (dx && dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
<< paddle::framework::make_ddim(x_dims_array)
<< " ydim:" << paddle::framework::make_ddim(y_dims_array);
CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(x,
y,
out,
dout,
dx,
dy,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
ctx,
dx_op,
dy_op);
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx,
const DDim& x_dims,
const DDim& y_dims,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
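// Roughly, get_mid_dims factors the larger shape around the smaller one into
// (pre, n, post); e.g. x.dims=[2,3,4,5], y.dims=[3,4], axis=1 is expected to
// give pre=2, n=3*4=12, post=5, and is_run_common_broadcast is set when no
// such 3-way split exists, forcing the slow common-broadcast path.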
if (is_xsize_larger) {
auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
funcs::get_mid_dims(x_dims,
y_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
funcs::get_mid_dims(y_dims,
x_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<T, DX_OP, DY_OP, Tout>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
if (post == 1) {
ElemwiseGradBroadcast1CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
} // namespace pten
......@@ -18,7 +18,10 @@ limitations under the License. */
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
#ifdef __HIPCC__
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
......@@ -28,6 +31,13 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#define BLOCK_X 32
#define BLOCK_Y 32
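// GetDivMod(a, b, &d, &m) stores a / b in *d and a % b in *m. The temporary
// copy of the dividend lets callers alias the quotient with the dividend,
// e.g. GetDivMod(index, dims[d], &index, &remainder) as used in the kernels
// below.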
#define GetDivMod(dividend, divisor, div, mod) \
do { \
const auto dividend_copy = dividend; \
*div = dividend_copy / divisor; \
*mod = dividend_copy % divisor; \
} while (0)
namespace pten {
namespace kps = paddle::operators::kernel_primitives;
......@@ -1469,4 +1479,762 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
}
template <typename T, typename DX_OP, typename Tout = T>
__global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array,
const int *y_strides_array,
const int *out_dims_array,
const int *y_strides_order,
const int *y_dims_order,
const T *x,
const T *y,
const Tout *out,
const Tout *dout,
T *dx,
int out_size,
int max_dim,
int thread_num,
DX_OP dx_op) {
T val(0);
int i = blockIdx.x;
int tid = threadIdx.x;
for (int j = tid; j < thread_num; j += blockDim.x) {
const int X_index = i * thread_num + j;
int out_index = X_index;
int C_index = 0;
int B_index = i * thread_num + j;
int remainder = 0;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(B_index, y_dims_order[d], &B_index, &remainder);
C_index += remainder * y_strides_order[d];
}
int x_index = 0;
int y_index = 0;
int C_index_val = C_index;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder);
x_index += remainder * x_strides_array[d];
y_index += remainder * y_strides_array[d];
}
out_index = C_index;
val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
}
val = paddle::platform::reduceSum(val, tid, thread_num);
if (threadIdx.x == 0) {
dx[i] = val;
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCUDA(const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int *x_dims_array,
int *y_dims_array,
int *out_dims_array,
int max_dim,
const GPUContext &ctx,
DX_OP dx_op,
DY_OP dy_op) {
const auto gplace =
BOOST_GET_CONST(paddle::platform::CUDAPlace, ctx.GetPlace());
auto cplace = paddle::platform::CPUPlace();
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
std::vector<int> x_one_indexs;
std::vector<int> y_one_indexs;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] != y_dims_array[i]) {
if (x_dims_array[i] == 1) {
x_one_indexs.push_back(i);
}
if (y_dims_array[i] == 1) {
y_one_indexs.push_back(i);
}
}
}
std::vector<int> x_trans_indexs(max_dim);
std::vector<int> y_trans_indexs(max_dim);
ComputeBroadcastTranspositionArray(
x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size());
ComputeBroadcastTranspositionArray(
y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size());
// compute array stride for cuda kernel;
// e.g. x.dims=[2,3,4], x_stride=[12,4,1]
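// A size-1 dimension gets stride 0, so the same element is re-read along that
// axis; this is how the broadcast is encoded for the kernel. For example,
// x.dims=[2,1,4] against out.dims=[2,3,4] yields x_strides_array=[4,0,1].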
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
std::vector<int> out_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
int z_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
out_strides_array[i] = z_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
z_stride *= out_dims_array[i];
}
std::vector<int> x_strides_order(max_dim);
std::vector<int> y_strides_order(max_dim);
std::vector<int> x_dims_order(max_dim);
std::vector<int> y_dims_order(max_dim);
for (int i = 0; i < max_dim; ++i) {
x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
}
std::vector<int> x_broadcast_pos;
std::vector<int> y_broadcast_pos;
int bytes = max_dim * sizeof(int);
for (int i = 0; i < max_dim; ++i) {
if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) {
x_broadcast_pos.emplace_back(i);
}
if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) {
y_broadcast_pos.emplace_back(i);
}
}
auto stream = ctx.stream();
bool can_split_x = false;
bool can_split_y = false;
auto FastCommonCUDAF = [&](const std::vector<int> &broadcast_pos, bool is_y) {
int h = std::accumulate(out_dims_array,
out_dims_array + broadcast_pos.size(),
1,
std::multiplies<int>());
int w = std::accumulate(out_dims_array + broadcast_pos.size(),
out_dims_array + max_dim,
1,
std::multiplies<int>());
VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h
<< " is_y:" << is_y;
int split_h;
int split_w;
int kh = h;
int kw = w;
if (is_y) {
split_h = std::accumulate(x_dims_array,
x_dims_array + broadcast_pos.size(),
1,
std::multiplies<int>());
split_w = std::accumulate(x_dims_array + broadcast_pos.size(),
x_dims_array + max_dim,
1,
std::multiplies<int>());
} else {
split_h = std::accumulate(y_dims_array,
y_dims_array + broadcast_pos.size(),
1,
std::multiplies<int>());
split_w = std::accumulate(y_dims_array + broadcast_pos.size(),
y_dims_array + max_dim,
1,
std::multiplies<int>());
}
if (h > split_h) kh = split_h;
if (w > split_w) kw = split_w;
if (is_y) {
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
CommonGradBroadcast1CUDAKernelHeight<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
h,
w,
dy_op,
dy_data,
kh,
kw,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
h,
w,
dy_op,
dy_data,
kh,
kw,
is_y);
}
} else {
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
CommonGradBroadcast1CUDAKernelHeight<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
h,
w,
dx_op,
dx_data,
kh,
kw,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
h,
w,
dx_op,
dx_data,
kh,
kw,
is_y);
}
}
};
auto FastBroadCastHeightCUDAF = [&](const std::vector<int> &broadcast_pos,
bool x_large) {
int h = std::accumulate(out_dims_array,
out_dims_array + broadcast_pos.size(),
1,
std::multiplies<int>());
int w = std::accumulate(out_dims_array + broadcast_pos.size(),
out_dims_array + max_dim,
1,
std::multiplies<int>());
VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h;
if (w < 16 || h < 16) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int grid_size = w;
ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x_data,
y_data,
out_data,
dout_data,
h,
w,
x_large,
dx_op,
dy_op,
dx_data,
dy_data);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
FastElemwiseGradBroadcast1CUDAKernel<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
h,
w,
x_large,
dx_op,
dy_op,
dx_data,
dy_data);
}
};
auto FastBroadCastAllCUDAF = [&](
const std::vector<int> &broadcast_pos, int max_dim, bool is_x_large) {
int axis = broadcast_pos[0];
int pre = std::accumulate(
out_dims_array, out_dims_array + axis, 1, std::multiplies<int>());
int mid = 1;
int post = 1;
if (broadcast_pos.size() == 1) {
mid = out_dims_array[axis];
post = std::accumulate(out_dims_array + axis + 1,
out_dims_array + max_dim,
1,
std::multiplies<int>());
} else {
mid = std::accumulate(out_dims_array + axis,
out_dims_array + broadcast_pos.back() + 1,
1,
std::multiplies<int>());
post = std::accumulate(out_dims_array + broadcast_pos.back() + 1,
out_dims_array + max_dim,
1,
std::multiplies<int>());
}
VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
<< " post:" << post;
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
x_data,
y_data,
out_data,
dout_data,
pre,
mid,
post,
is_x_large,
dx_op,
dy_op,
dx_data,
dy_data);
};
auto FastBroadCastOneCUDAF = [&](
const std::vector<int> &broadcast_pos, int max_dim, bool is_x) {
int axis = broadcast_pos[0];
int pre = std::accumulate(
out_dims_array, out_dims_array + axis, 1, std::multiplies<int>());
int mid = out_dims_array[axis];
int post = std::accumulate(out_dims_array + axis + 1,
out_dims_array + max_dim,
1,
std::multiplies<int>());
int k_pre;
int k_mid;
int k_post;
if (is_x) {
k_pre = std::accumulate(
y_dims_array, y_dims_array + axis, 1, std::multiplies<int>());
k_mid = y_dims_array[axis];
k_post = std::accumulate(y_dims_array + axis + 1,
y_dims_array + max_dim,
1,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
// We need the y offset for each block id, so divide the x-side pre by the
// y-side pre to get the left-over size.
if (k_pre != pre) k_pre = pre / k_pre;
FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
pre,
mid,
post,
k_pre,
k_mid,
k_post,
true,
dx_op,
dx_data);
} else {
k_pre = std::accumulate(
x_dims_array, x_dims_array + axis, 1, std::multiplies<int>());
k_mid = x_dims_array[axis];
k_post = std::accumulate(x_dims_array + axis + 1,
x_dims_array + max_dim,
1,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
if (k_pre != pre) k_pre = pre / k_pre;
FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
block_size,
0,
stream>>>(x_data,
y_data,
out_data,
dout_data,
pre,
mid,
post,
k_pre,
k_mid,
k_post,
false,
dy_op,
dy_data);
}
VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid
<< " post:" << post;
};
// Take the fast elementwise path if:
// 1. only one input needs broadcasting, so we can fall back to the old fast
//    path;
// 2. both x and y need broadcasting, in which case handle them one by one.
bool fast_broadcast = false;
if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
can_split_y = SplitDims(y_broadcast_pos, max_dim);
if (can_split_y) {
// only y needs to broadcast along h
if (y_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(y_broadcast_pos, true);
fast_broadcast = true;
}
} else if (y_broadcast_pos.size() == 1 ||
CheckContiguousDims(y_broadcast_pos)) {  // only one broadcast dim, or a
// contiguous run of broadcast dims.
// If the dims cannot be split, the input decomposes into 3 parts.
FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
fast_broadcast = true;
}
} else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
// only x need broadcast
can_split_x = SplitDims(x_broadcast_pos, max_dim);
if (can_split_x) {
if (x_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(x_broadcast_pos, false);
fast_broadcast = true;
}
} else if (x_broadcast_pos.size() == 1 ||
CheckContiguousDims(x_broadcast_pos)) {
FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
fast_broadcast = true;
}
} else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
// do x and y broadcast each.
can_split_y = SplitDims(y_broadcast_pos, max_dim);
bool fast_broadcast_x = false;
bool fast_broadcast_y = false;
if (can_split_y) {
// the broadcast dims begin at dim 0.
if (y_broadcast_pos[0] == 0) {
FastCommonCUDAF(y_broadcast_pos, true);
fast_broadcast_y = true;
}
} else if (y_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
can_split_y = true;
fast_broadcast_y = true;
}
can_split_x = SplitDims(x_broadcast_pos, max_dim);
if (can_split_x) {
if (x_broadcast_pos[0] == 0) {
FastCommonCUDAF(x_broadcast_pos, false);
fast_broadcast_x = true;
}
} else if (x_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
can_split_x = true;
fast_broadcast_x = true;
}
VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
<< " can_split_x:" << can_split_x;
// if both x and y took the fast path then return
if (fast_broadcast_x && fast_broadcast_y) {
fast_broadcast = true;
}
if (can_split_y && can_split_x && fast_broadcast) return;
}
// Should remove memory copy, use reg instead.
if (fast_broadcast) {
return;
}
int x_blocks = 0;
int x_threads = 0;
ComputeBroadcastKernelSize(
x_dims_array, out_dims_array, &x_blocks, &x_threads, max_dim);
int y_blocks = 0;
int y_threads = 0;
ComputeBroadcastKernelSize(
y_dims_array, out_dims_array, &y_blocks, &y_threads, max_dim);
auto x_strides_array_tmp = paddle::memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
paddle::memory::Copy(gplace,
x_strides_array_gpu,
cplace,
x_strides_array.data(),
bytes,
ctx.stream());
auto y_strides_array_tmp = paddle::memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
paddle::memory::Copy(gplace,
y_strides_array_gpu,
cplace,
y_strides_array.data(),
bytes,
ctx.stream());
auto out_dims_array_tmp = paddle::memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
paddle::memory::Copy(
gplace, out_dims_array_gpu, cplace, out_dims_array, bytes, ctx.stream());
const int out_size = std::accumulate(
out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
if (dx) {
auto x_strides_order_tmp = paddle::memory::Alloc(ctx, bytes);
int *x_strides_order_gpu =
reinterpret_cast<int *>(x_strides_order_tmp->ptr());
paddle::memory::Copy(gplace,
x_strides_order_gpu,
cplace,
x_strides_order.data(),
bytes,
ctx.stream());
auto x_dims_order_tmp = paddle::memory::Alloc(ctx, bytes);
int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
paddle::memory::Copy(gplace,
x_dims_order_gpu,
cplace,
x_dims_order.data(),
bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T,
DX_OP,
Tout><<<x_blocks, x_block_size, 0, ctx.stream()>>>(x_strides_array_gpu,
y_strides_array_gpu,
out_dims_array_gpu,
x_strides_order_gpu,
x_dims_order_gpu,
x_data,
y_data,
out_data,
dout_data,
dx_data,
out_size,
max_dim,
x_threads,
dx_op);
}
if (dy) {
auto y_strides_order_tmp = paddle::memory::Alloc(ctx, bytes);
int *y_strides_order_gpu =
reinterpret_cast<int *>(y_strides_order_tmp->ptr());
paddle::memory::Copy(gplace,
y_strides_order_gpu,
cplace,
y_strides_order.data(),
bytes,
ctx.stream());
auto y_dims_order_tmp = paddle::memory::Alloc(ctx, bytes);
int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
paddle::memory::Copy(gplace,
y_dims_order_gpu,
cplace,
y_dims_order.data(),
bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T,
DY_OP,
Tout><<<y_blocks, y_block_size, 0, ctx.stream()>>>(x_strides_array_gpu,
y_strides_array_gpu,
out_dims_array_gpu,
y_strides_order_gpu,
y_dims_order_gpu,
x_data,
y_data,
out_data,
dout_data,
dy_data,
out_size,
max_dim,
y_threads,
dy_op);
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonElementwiseBroadcastBackward(const GPUContext &ctx,
const DDim &x_dims,
const DDim &y_dims,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
int axis,
DenseTensor *dx,
DenseTensor *dy,
DX_OP dx_op,
DY_OP dy_op) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
// For the inplace strategy: when dx shares its buffer with dout, a memset on
// dx would also clear dout and give a wrong result, so reallocate dx first.
if (dx && dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
<< paddle::framework::make_ddim(x_dims_array)
<< " ydim:" << paddle::framework::make_ddim(y_dims_array);
CommonGradBroadcastCUDA<T, DX_OP, DY_OP, Tout>(x,
y,
out,
dout,
dx,
dy,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
ctx,
dx_op,
dy_op);
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
const DDim &x_dims,
const DDim &y_dims,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
int axis,
DenseTensor *dx,
DenseTensor *dy,
DX_OP dx_op,
DY_OP dy_op) {
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
funcs::get_mid_dims(x_dims,
y_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
funcs::get_mid_dims(y_dims,
x_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<T, DX_OP, DY_OP, Tout>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
if (post == 1) {
ElemwiseGradBroadcast1CUDA(
ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
} else {
ElemwiseGradBroadcast2CUDA(
ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
} // namespace pten