Unverified commit 1b69e528, authored by: W wangchaochaohu, committed by: GitHub

optimize for long width for elementwise (#29602)

Parent 78dad786
@@ -19,6 +19,11 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#endif
namespace paddle {
namespace operators {
@@ -121,6 +126,20 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
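// AlignedVector packs Size elements of T with alignment sizeof(T) * Size, so a
// whole group can be moved with one wide aligned load/store instead of Size
// scalar accesses.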
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
};
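// Returns the vector width usable for `pointer`: 4 when the address satisfies
// the alignment of AlignedVector<T, 4>, otherwise 1 (fall back to scalar).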
template <typename T>
inline int VectorizedSize(const T *pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
}
return 1;
}
template <typename T, int BLOCK_W, int BLOCK_H>
__global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
size_t width, size_t height) {
@@ -200,6 +219,45 @@ __global__ void FP16MatrixColReduce(
if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
}
}
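// Column-wise sum for wide matrices: each thread owns one column (grid-stride
// loop over columns) and walks down its `height` rows sequentially; columns
// are independent, so no shared memory or inter-thread reduction is needed.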
template <typename T>
__global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out,
size_t width, size_t height) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (; idx < width; idx += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int row = 0; row < height; row++) {
sum += in[idx + row * width];
}
out[idx] = sum;
}
}
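// Vectorized variant: each thread reduces VEC_SIZE adjacent columns at once,
// loading VEC_SIZE values per row through an AlignedVector, which is intended
// to improve memory throughput when width is a multiple of VEC_SIZE and the
// data is suitably aligned.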
template <typename T, int VEC_SIZE>
__global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
size_t width, size_t height) {
using LoadT = AlignedVector<T, VEC_SIZE>;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int w = idx * VEC_SIZE;
int width_stride = blockDim.x * gridDim.x * VEC_SIZE;
for (; w < width; w += width_stride) {
T zero = static_cast<T>(0);
T sum[VEC_SIZE] = {zero};
T tmp_vec[VEC_SIZE] = {zero};
LoadT *tmp_ptr = reinterpret_cast<LoadT *>(&tmp_vec);
for (int row = 0; row < height; row++) {
int offset = width * row + w;
*tmp_ptr = *reinterpret_cast<const LoadT *>(&in[offset]);
for (int v = 0; v < VEC_SIZE; v++) {
sum[v] += tmp_vec[v];
}
}
for (int v = 0; v < VEC_SIZE; v++) out[w + v] = sum[v];
}
}
#endif
#endif
bool static RunSpecialDims(const framework::DDim &dx_dims,
@@ -301,6 +359,21 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
}
// special optimization using cub
if (width == 1) {
int nums = height;
size_t temp_storage_bytes = 0;
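// cub::DeviceReduce::Sum is called twice: the first call, with a null
// temp-storage pointer, only reports the scratch size it needs; the second
// call performs the actual reduction once that buffer is allocated.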
auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes,
dout_data, out_data, nums, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
framework::Tensor tmp;
auto *temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
ctx.GetPlace());
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
dout_data, out_data, nums, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
}
constexpr int block_x = 32;
constexpr int block_y = 32;
@@ -311,7 +384,8 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
int theory_block = (width + blocks.x - 1) / blocks.x;
dim3 grids(std::min(theory_block, max_blocks));
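// The dedicated FP16 shared-memory reduction is now taken only when
// width / height < 32; wider FP16 matrices fall through to the long-width
// kernels below.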
if (std::is_same<T, paddle::platform::float16>::value) {
if (std::is_same<T, paddle::platform::float16>::value &&
(width / height) < 32) {
const paddle::platform::float16 *ptr1 =
reinterpret_cast<const paddle::platform::float16 *>(dout_data);
paddle::platform::float16 *ptr2 =
@@ -325,8 +399,24 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
}
return;
}
MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
dout_data, out_data, width, height);
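// Dispatch on the aspect ratio: keep the tiled shared-memory column reduction
// for width / height < 32, otherwise use the simple one-thread-per-column
// kernels, vectorized by 4 when the pointer alignment and width allow it.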
if (width / height < 32) {
MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
dout_data, out_data, width, height);
} else {
size_t thread_nums = 1024;
size_t block_nums = (width + thread_nums - 1) / thread_nums;
int vec_size = VectorizedSize<T>(dx_data);
if (vec_size == 4 && width % 4 == 0) {
block_nums = (width / vec_size + thread_nums - 1) / thread_nums;
VecMatrixReduceLongWidth<T,
4><<<block_nums, thread_nums, 0, stream>>>(
dout_data, out_data, width, height);
} else {
MatrixReduceLongWidth<T><<<block_nums, thread_nums, 0, stream>>>(
dout_data, out_data, width, height);
}
}
return;
}