From e59524f86d472f0f36e09cc41c3ca882d3fc2841 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 11 Jan 2021 16:11:33 +0800
Subject: [PATCH] [cherry-pick] Elementwise add grad GPU kernel optimization
 (#30276)

* elementwise_add_grad Op optimization (#29575)
* optimize for long width for elementwise (#29602)
* refine (#29622)
* delete the code for fp16 optimization because it is not faster than
  common template code (#29715)
* fix the shape choose of vectorize for cuda
* optimization for fp16 elementwise add (#29744)
* Fix the compiler error for half type (#29799)
* refine the compiler error for half2 operation (#29816)
* fix the compiler error when gcc4 cuda9.0 (#29997)
---
 .../elementwise/elementwise_add_op.h          | 265 ++++++++++++++++++
 .../unittests/test_elementwise_add_op.py      |  11 +
 2 files changed, 276 insertions(+)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index acda31e0f23..41e97a39466 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -13,10 +13,20 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
+#include <utility>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include "cub/cub.cuh"
+#endif
+#endif
 
 namespace paddle {
 namespace operators {
@@ -116,6 +126,167 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
   default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
 }
 
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+};
+
+template <typename T>
+inline int VectorizedSize(const T *pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
+  }
+  return 1;
+}
+
+template <typename T, int BLOCK_W, int BLOCK_H>
+__global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
+                                size_t width, size_t height) {
+  __shared__ T sdata[BLOCK_H][BLOCK_W + 1];
+  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  size_t width_stride = gridDim.x * blockDim.x;
+  size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
+                      ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
+  size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
+                       ((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
+
+#pragma unroll
+  for (size_t w = idx; w < full_width; w += width_stride) {
+    sdata[threadIdx.y][threadIdx.x] = 0;
+    __syncthreads();
+    size_t offset = w + threadIdx.y * width;
+#pragma unroll
+    for (size_t h = threadIdx.y; h < full_height;
+         h += BLOCK_H) {  // block-stride loop across matrix height
+      sdata[threadIdx.y][threadIdx.x] +=
+          (w < width && h < height) ? in[offset] : (static_cast<T>(0));
+      offset += width * BLOCK_H;
+    }
+    __syncthreads();
+
+    T val = sdata[threadIdx.x][threadIdx.y];
+    for (int i = warpSize >> 1; i > 0; i >>= 1)
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+
+    __syncthreads();
+    if (threadIdx.x == 0) sdata[0][threadIdx.y] = val;
+    __syncthreads();
+    if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
+  }
+}
+
+#if CUDA_VERSION >= 10000
+template <int SIZE>
+__global__ void VecFP16MatrixColReduce(const __half2 *__restrict__ in,
+                                       __half2 *__restrict__ out, size_t width,
+                                       size_t height) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int by = blockIdx.y;
+  __half2 zero = __half2half2(static_cast<__half>(0));
+  const int cols = width / 2;
+  for (; idx < cols; idx += blockDim.x * gridDim.x) {
+    __half2 sum = zero;
+    for (int row = 0; row < SIZE; row++) {
+      int index = idx + (row + by * SIZE) * cols;
+      sum = __hadd2(sum, in[index]);
+    }
+
+    atomicAdd(&(out[idx]), sum);
+  }
+#endif
+}
+#endif
+
+template <typename T>
+__global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out,
+                                      size_t width, size_t height) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  for (; idx < width; idx += blockDim.x * gridDim.x) {
+    T sum = static_cast<T>(0);
+    for (int row = 0; row < height; row++) {
+      sum += in[idx + row * width];
+    }
+
+    out[idx] = sum;
+  }
+}
+
+template <typename T, int VEC_SIZE>
+__global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
+                                         size_t width, size_t height) {
+  using LoadT = AlignedVector<T, VEC_SIZE>;
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int w = idx * VEC_SIZE;
+  int width_stride = blockDim.x * gridDim.x * VEC_SIZE;
+  for (; w < width; w += width_stride) {
+    T zero = static_cast<T>(0);
+    T sum[VEC_SIZE] = {zero};
+    T tmp_vec[VEC_SIZE] = {zero};
+    LoadT *tmp_ptr = reinterpret_cast<LoadT *>(&tmp_vec);
+    for (int row = 0; row < height; row++) {
+      int offset = width * row + w;
+      *tmp_ptr = *reinterpret_cast<const LoadT *>(&in[offset]);
+      for (int v = 0; v < VEC_SIZE; v++) {
+        sum[v] += tmp_vec[v];
+      }
+    }
+
+    for (int v = 0; v < VEC_SIZE; v++) out[w + v] = sum[v];
+  }
+}
+#endif
+#endif
+bool static RunSpecialDims(const framework::DDim &dx_dims,
+                           const framework::DDim &dy_dims,
+                           const framework::DDim &dout_dims, int axis) {
+  auto smaller_dims = dx_dims;
+  auto bigger_dims = dy_dims;
+  auto smaller_dims_size = smaller_dims.size();
+  auto bigger_dims_size = bigger_dims.size();
+  int smaller_ignore_size = 0;
+  int bigger_ignore_size = 0;
+  for (int i = 0; i < smaller_dims_size; i++) {
+    if (smaller_dims[i] == 1)
+      smaller_ignore_size++;
+    else
+      break;
+  }
+  for (int i = 0; i < bigger_dims_size; i++) {
+    if (bigger_dims[i] == 1)
+      bigger_ignore_size++;
+    else
+      break;
+  }
+
+  int smaller_real_size = smaller_dims.size() - smaller_ignore_size;
+  int bigger_real_size = bigger_dims.size() - bigger_ignore_size;
+
+  if (smaller_real_size == bigger_real_size) return false;
+
+  if (bigger_real_size < smaller_real_size) {
+    smaller_dims = dy_dims;
+    bigger_dims = dx_dims;
+    std::swap(smaller_real_size, bigger_real_size);
+  }
+  int big_size = bigger_dims.size();
+  int small_size = smaller_dims.size();
+  for (int i = 1; i <= smaller_real_size; i++) {
+    if (bigger_dims[big_size - i] != smaller_dims[small_size - i])
+      return false;
+  }
+
+  if (axis != -1 && (axis != (bigger_real_size - smaller_real_size))) {
+    return false;
+  }
+
+  return true;
+}
+
 #ifdef PADDLE_WITH_CUDA
 // cuda definition
 template <typename DeviceContext, typename T>
@@ -144,6 +315,100 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
     // skip out
     auto *out = dout;
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+
+    int axis = ctx.Attr<int>("axis");
+    if (ctx.GetPlace() == platform::CUDAPlace() && dx != nullptr &&
+        dy != nullptr && dout != nullptr && dx->numel() != dy->numel() &&
+        RunSpecialDims(dx->dims(), dy->dims(), dout->dims(), axis)) {
+      auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
+      auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
+      auto *dout_data = dout->data<T>();
+      auto stream = ctx.cuda_device_context().stream();
+      auto *out_data = dx_data;
+      int width = dx->numel();
+      int height = dout->numel() / width;
+      if (dx->dims() == dout->dims()) {
+        width = dy->numel();
+        height = dout->numel() / width;
+        out_data = dy_data;
+        framework::TensorCopy(
+            *dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dx);
+      } else {
+        framework::TensorCopy(
+            *dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dy);
+      }
+      // special optimization using cub
+      if (width == 1) {
+        int nums = height;
+        size_t temp_storage_bytes = 0;
+        auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes,
+                                          dout_data, out_data, nums, stream);
+        PADDLE_ENFORCE_CUDA_SUCCESS(err);
+        framework::Tensor tmp;
+        auto *temp_storage = tmp.mutable_data<uint8_t>(
+            framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
+            ctx.GetPlace());
+        err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
+                                     dout_data, out_data, nums, stream);
+        PADDLE_ENFORCE_CUDA_SUCCESS(err);
+        return;
+      }
+
+      constexpr int block_x = 32;
+      constexpr int block_y = 32;
+      dim3 blocks(block_x, block_y);
+
+      int max_physical_threads =
+          ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+      int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
+      int theory_block = (width + blocks.x - 1) / blocks.x;
+      dim3 grids(std::min(theory_block, max_blocks));
+#if CUDA_VERSION >= 10000
+      if (std::is_same<T, paddle::platform::float16>::value && width < 2048 &&
+          width % 2 == 0 && height % 64 == 0) {
+        auto &dev_ctx =
+            ctx.template device_context<platform::CUDADeviceContext>();
+        math::SetConstant<platform::CUDADeviceContext, T> functor;
+        if (dout->dims() == dx->dims())
+          functor(dev_ctx, dy, static_cast<T>(0));
+        else
+          functor(dev_ctx, dx, static_cast<T>(0));
+        const __half2 *ptr1 = reinterpret_cast<const __half2 *>(dout_data);
+        __half2 *ptr2 = reinterpret_cast<__half2 *>(out_data);
+        const int threads = 128;
+        dim3 grid(1, (height + 64 - 1) / 64);
+        VecFP16MatrixColReduce<64><<<grid, threads, 0, stream>>>(ptr1, ptr2,
+                                                                 width, height);
+        return;
+      }
+#endif
+
+      if (width / height < 32) {
+        MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
+            dout_data, out_data, width, height);
+      } else {
+        size_t thread_nums = 1024;
+        size_t block_nums = (width + thread_nums - 1) / thread_nums;
+        int vec_size = VectorizedSize<T>(dout_data);
+        if (vec_size == 4 && width % 4 == 0) {
+          block_nums = (width / vec_size + thread_nums - 1) / thread_nums;
+          VecMatrixReduceLongWidth<T, 4><<<block_nums, thread_nums, 0,
+                                           stream>>>(dout_data, out_data,
+                                                     width, height);
+        } else {
+          MatrixReduceLongWidth<T><<<block_nums, thread_nums, 0, stream>>>(
+              dout_data, out_data, width, height);
+        }
+      }
+      return;
+    }
+
+#endif
+#endif
 
     // Special case when dy is not needed and dx doesn't reduce
     if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
       VLOG(4) << "Special case when dy is not needed and dx doesn't "
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index 318ef9fd39a..6abc97fd583 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -351,6 +351,16 @@ class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
         self.axis = -1
 
 
+class TestElementwiseFP16AddOp_commonuse_add1(TestFP16ElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(20, 30, 100).astype(self.dtype)
+        self.y = np.random.rand(1, 1, 100).astype(self.dtype)
+        self.out = self.x + self.y
+
+    def init_axis(self):
+        self.axis = -1
+
+
 class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
@@ -501,4 +511,5 @@ class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
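
Note (reviewer sketch, not part of the patch): the long-width kernels added above all compute a column sum, i.e. with dout viewed as a (height, width) matrix, each output element j receives the sum of dout[row][j] over all rows, which is exactly the reduction the broadcast gradient needs when the smaller input only spans the trailing dimensions. The standalone CUDA program below illustrates that column-sum idea in isolation; the kernel name ColumnSum, the problem sizes, and the host-side check are illustrative assumptions, not code from this patch.

// Minimal column-sum sketch (illustrative only; not part of the patch).
// Build: nvcc column_sum_demo.cu -o column_sum_demo
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

// One thread per output column, grid-stride loop over columns:
// out[j] = sum over rows of in[row * width + j] -- the same reduction the
// long-width path of the elementwise_add gradient performs on dout.
__global__ void ColumnSum(const float *__restrict__ in, float *out,
                          size_t width, size_t height) {
  for (size_t col = blockIdx.x * blockDim.x + threadIdx.x; col < width;
       col += blockDim.x * gridDim.x) {
    float sum = 0.f;
    for (size_t row = 0; row < height; ++row) {
      sum += in[row * width + col];
    }
    out[col] = sum;
  }
}

int main() {
  const size_t width = 1024, height = 256;
  std::vector<float> h_in(width * height, 1.f);  // each column sums to height

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, width * height * sizeof(float));
  cudaMalloc(&d_out, width * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), width * height * sizeof(float),
             cudaMemcpyHostToDevice);

  const int threads = 256;
  const int blocks = static_cast<int>((width + threads - 1) / threads);
  ColumnSum<<<blocks, threads>>>(d_in, d_out, width, height);

  std::vector<float> h_out(width);
  cudaMemcpy(h_out.data(), d_out, width * sizeof(float),
             cudaMemcpyDeviceToHost);
  std::printf("out[0] = %.1f (expected %zu)\n", h_out[0], height);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}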