未验证 提交 7b2dc4e6 编写于 作者: W wangchaochaohu 提交者: GitHub

optimization for fp16 elementwise add (#29744)

上级 27bdbec7
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -176,6 +177,25 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
}
}
// Atomically adds the two float lanes of `val` into the packed half2 stored at
// `addr`.
//
// Uses the native atomicAdd(__half2*, __half2) where it exists (SM60+ with a
// CUDA 10+ toolkit). Elsewhere it falls back to a compare-and-swap loop on the
// underlying 32-bit word, so the file still compiles and produces correct
// results when built for older targets.
static __device__ __forceinline__ void VecFP16AtomicAddFloat2(__half2 *addr,
                                                              float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 && \
    defined(CUDART_VERSION) && CUDART_VERSION >= 10000
  atomicAdd(addr, __float22half2_rn(val));
#else
  unsigned int *word = reinterpret_cast<unsigned int *>(addr);
  unsigned int observed = *word;
  unsigned int assumed;
  do {
    assumed = observed;
    // Only conversion intrinsics are used here; the addition itself is done in
    // float so no half-precision arithmetic support is required.
    float2 current = __half22float2(*reinterpret_cast<__half2 *>(&assumed));
    current.x += val.x;
    current.y += val.y;
    __half2 updated = __float22half2_rn(current);
    observed = atomicCAS(word, assumed,
                         *reinterpret_cast<unsigned int *>(&updated));
  } while (assumed != observed);
#endif
}

// Column-wise sum of a row-major fp16 matrix, processed two columns at a time
// as __half2 pairs.
//
// Launch layout: blockIdx.y selects a tile of SIZE consecutive rows; the x
// dimension (threads and blocks) strides over the width / 2 packed columns.
// Each thread sums its packed column over the rows of its tile and then
// atomically adds that partial sum into out[column].
//
// Parameters:
//   in     - input matrix viewed as __half2; row r, packed column c is read
//            from in[c + r * (width / 2)].
//   out    - width / 2 packed column sums. The kernel only accumulates into
//            `out`, so the caller must zero it before the launch.
//   width  - number of fp16 columns. Only width / 2 pairs are processed, so a
//            trailing odd column is ignored.
//   height - number of rows. Rows at or beyond `height` are never read, so the
//            last tile may be partial.
//
// NOTE(review): both pointers are dereferenced as __half2, so they presumably
// must be 4-byte aligned -- confirm at the call site.
template <int SIZE>
__global__ void VecFP16MatrixColReduce(const __half2 *__restrict__ in,
                                       __half2 *__restrict__ out, size_t width,
                                       size_t height) {
  const size_t cols = width / 2;
  const size_t row_begin = static_cast<size_t>(blockIdx.y) * SIZE;
  if (row_begin >= height) {
    return;  // this tile lies entirely past the last row
  }
  // Clamp the tile so a height that is not a multiple of SIZE stays in bounds.
  const size_t tile_end = row_begin + static_cast<size_t>(SIZE);
  const size_t row_end = tile_end < height ? tile_end : height;

  // 64-bit indexing: idx + row * cols can exceed INT_MAX for tall inputs.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  size_t idx = threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x;
  for (; idx < cols; idx += stride) {
    // Accumulate in float: summing SIZE values directly in half precision
    // rounds at every step and loses accuracy.
    float2 acc = make_float2(0.0f, 0.0f);
    for (size_t row = row_begin; row < row_end; ++row) {
      const float2 v = __half22float2(in[idx + row * cols]);
      acc.x += v.x;
      acc.y += v.y;
    }
    VecFP16AtomicAddFloat2(&out[idx], acc);
  }
}
template <typename T>
__global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out,
size_t width, size_t height) {
......@@ -198,7 +218,7 @@ __global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int w = idx * VEC_SIZE;
int width_stride = blockDim.x * gridDim.x * VEC_SIZE;
for (; w < width; w += width) {
for (; w < width; w += width_stride) {
T zero = static_cast<T>(0);
T sum[VEC_SIZE] = {zero};
T tmp_vec[VEC_SIZE] = {zero};
......@@ -341,6 +361,23 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
int theory_block = (width + blocks.x - 1) / blocks.x;
dim3 grids(std::min(theory_block, max_blocks));
if (std::is_same<T, paddle::platform::float16>::value && width < 2048 &&
width % 2 == 0 && height % 64 == 0) {
auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor;
if (dout->dims() == dx->dims())
functor(dev_ctx, dy, static_cast<T>(0));
else
functor(dev_ctx, dx, static_cast<T>(0));
const __half2 *ptr1 = reinterpret_cast<const __half2 *>(dout_data);
__half2 *ptr2 = reinterpret_cast<__half2 *>(out_data);
const int threads = 128;
dim3 grid(1, (height + 64 - 1) / 64);
VecFP16MatrixColReduce<64><<<grid, threads, 0, stream>>>(ptr1, ptr2,
width, height);
return;
}
if (width / height < 32) {
MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册