From ac4bae8ee936bdf3dbe6ba95178757ca4807540a Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 14 Dec 2020 21:30:59 +0800
Subject: [PATCH] elementwise_add_grad Op optimization (#29575)

---
Review notes (commentary only; text before the first "diff --git" is not part
of the change):

* elementwise_add_op.h
  - MatrixColReduce<T, BLOCK_W, BLOCK_H> and FP16MatrixColReduce<BLOCK_W,
    BLOCK_H>: CUDA kernels that accumulate `in` over `height` rows into a
    __shared__ tile, combine the partial sums with CudaShuffleXorSync, and
    store one result per column in out[w] for w < width.
  - RunSpecialDims: after skipping leading dims equal to 1, returns true only
    if the two shapes differ in remaining rank, the last `smaller_real_size`
    dims of both shapes are equal, and axis is -1 or equals the rank
    difference. dout_dims is accepted but not read.
  - ElementwiseAddGradKernel (NVCC builds only): when dx, dy and dout are all
    non-null, dx and dy differ in element count and RunSpecialDims passes,
    dout is copied into dx if dx's dims equal dout's (otherwise into dy) and
    the other gradient is produced by one of the kernels above; the method
    then returns early.
  - NOTE(review): in FP16MatrixColReduce the r-th pass starts `offset` at row
    (r * BLOCK_W + threadIdx.y) and then indexes
    in[offset + r * BLOCK_W * width], so r * BLOCK_W appears to be applied
    twice when repeats > 1 (the <32, 64> launch) -- confirm the intended row
    coverage.
  - NOTE(review): width and height are `int` but are initialised from
    numel(); confirm they cannot overflow for supported shapes.
  - NOTE(review): the fast path is guarded by
    `ctx.GetPlace() == platform::CUDAPlace()`, i.e. a comparison with a
    default-constructed CUDAPlace -- confirm this is intended for every
    device.
* test_elementwise_add_op.py
  - Adds an FP16 case with x of shape (20, 30, 100), y of shape (1, 1, 100)
    and axis = -1, and calls paddle.enable_static() before unittest.main().

 .../elementwise/elementwise_add_op.h          | 188 ++++++++++++++++++
 .../unittests/test_elementwise_add_op.py      |  11 +
 2 files changed, 199 insertions(+)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index acda31e0f23..0e8d202a9aa 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
+#include <utility>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
@@ -116,6 +118,135 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
   default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
 }
 
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+
+template <typename T, int BLOCK_W, int BLOCK_H>
+__global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
+                                size_t width, size_t height) {
+  __shared__ T sdata[BLOCK_H][BLOCK_W + 1];
+  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  size_t width_stride = gridDim.x * blockDim.x;
+  size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
+                      ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
+
+#pragma unroll
+  for (size_t w = idx; w < full_width; w += width_stride) {
+    sdata[threadIdx.y][threadIdx.x] = 0;
+    __syncthreads();
+    size_t offset = w + threadIdx.y * width;
+#pragma unroll
+    for (size_t h = threadIdx.y; h < height;
+         h += BLOCK_H) {  // block-stride loop across matrix height
+      sdata[threadIdx.y][threadIdx.x] +=
+          (w < width) ? in[offset] : (static_cast<T>(0));
+      offset += width * BLOCK_H;
+    }
+    __syncthreads();
+
+    T val = sdata[threadIdx.x][threadIdx.y];
+    for (int i = warpSize >> 1; i > 0; i >>= 1)
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+
+    __syncthreads();
+    if (threadIdx.x == 0) sdata[0][threadIdx.y] = val;
+    __syncthreads();
+    if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
+  }
+}
+
+template <int BLOCK_W, int BLOCK_H>
+__global__ void FP16MatrixColReduce(
+    const paddle::platform::float16 *__restrict__ in,
+    paddle::platform::float16 *__restrict__ out, size_t width, size_t height) {
+  constexpr int repeats = BLOCK_H / BLOCK_W;
+  __shared__ paddle::platform::float16 sdata[BLOCK_H][BLOCK_W + 1];
+  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  size_t width_stride = gridDim.x * blockDim.x;
+  size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
+                      ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
+
+#pragma unroll
+  for (size_t w = idx; w < full_width; w += width_stride) {
+    for (int r = 0; r < repeats; r++) {
+      sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
+    }
+    __syncthreads();
+    for (int r = 0; r < repeats; r++) {
+      size_t offset = w + (r * BLOCK_W + threadIdx.y) * width;
+#pragma unroll
+      for (size_t h = r * BLOCK_H + threadIdx.y; h < height;
+           h += BLOCK_H) {  // block-stride loop across matrix height
+        sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] +=
+            (w < width) ? in[offset + r * BLOCK_W * width]
+                        : (static_cast<paddle::platform::float16>(0));
+        offset += width * BLOCK_H;
+      }
+    }
+    __syncthreads();
+
+    paddle::platform::float16 result =
+        static_cast<paddle::platform::float16>(0);
+    for (int r = 0; r < repeats; r++) {
+      paddle::platform::float16 val =
+          sdata[threadIdx.x + r * BLOCK_W][threadIdx.y];
+      for (int i = warpSize >> 1; i > 0; i >>= 1)
+        val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+      __syncthreads();
+      result += val;
+    }
+    if (threadIdx.x == 0) sdata[0][threadIdx.y] = result;
+    __syncthreads();
+    if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
+  }
+}
+#endif
+#endif
+bool static RunSpecialDims(const framework::DDim &dx_dims,
+                           const framework::DDim &dy_dims,
+                           const framework::DDim &dout_dims, int axis) {
+  auto smaller_dims = dx_dims;
+  auto bigger_dims = dy_dims;
+  auto smaller_dims_size = smaller_dims.size();
+  auto bigger_dims_size = bigger_dims.size();
+  int smaller_ignore_size = 0;
+  int bigger_ignore_size = 0;
+  for (int i = 0; i < smaller_dims_size; i++) {
+    if (smaller_dims[i] == 1)
+      smaller_ignore_size++;
+    else
+      break;
+  }
+  for (int i = 0; i < bigger_dims_size; i++) {
+    if (bigger_dims[i] == 1)
+      bigger_ignore_size++;
+    else
+      break;
+  }
+
+  int smaller_real_size = smaller_dims.size() - smaller_ignore_size;
+  int bigger_real_size = bigger_dims.size() - bigger_ignore_size;
+
+  if (smaller_real_size == bigger_real_size) return false;
+
+  if (bigger_real_size < smaller_real_size) {
+    smaller_dims = dy_dims;
+    bigger_dims = dx_dims;
+    std::swap(smaller_real_size, bigger_real_size);
+  }
+  int big_size = bigger_dims.size();
+  int small_size = smaller_dims.size();
+  for (int i = 1; i <= smaller_real_size; i++) {
+    if (bigger_dims[big_size - i] != smaller_dims[small_size - i]) return false;
+  }
+
+  if (axis != -1 && (axis != (bigger_real_size - smaller_real_size))) {
+    return false;
+  }
+
+  return true;
+}
+
 #ifdef PADDLE_WITH_CUDA
 // cuda definition
 template <typename DeviceContext, typename T>
@@ -144,6 +275,63 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
     // skip out
     auto *out = dout;
 
+#ifdef PADDLE_WITH_CUDA
+#ifdef __NVCC__
+
+    int axis = ctx.Attr<int>("axis");
+    if (ctx.GetPlace() == platform::CUDAPlace() && dx != nullptr &&
+        dy != nullptr && dout != nullptr && dx->numel() != dy->numel() &&
+        RunSpecialDims(dx->dims(), dy->dims(), dout->dims(), axis)) {
+      auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
+      auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
+      auto *dout_data = dout->data<T>();
+      auto stream = ctx.cuda_device_context().stream();
+      auto *out_data = dx_data;
+      int width = dx->numel();
+      int height = dout->numel() / width;
+      if (dx->dims() == dout->dims()) {
+        width = dy->numel();
+        height = dout->numel() / width;
+        out_data = dy_data;
+        framework::TensorCopy(
+            *dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dx);
+      } else {
+        framework::TensorCopy(
+            *dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dy);
+      }
+
+      constexpr int block_x = 32;
+      constexpr int block_y = 32;
+      dim3 blocks(block_x, block_y);
+
+      int max_physical_threads =
+          ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+      int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
+      int theory_block = (width + blocks.x - 1) / blocks.x;
+      dim3 grids(std::min(theory_block, max_blocks));
+      if (std::is_same<T, paddle::platform::float16>::value) {
+        const paddle::platform::float16 *ptr1 =
+            reinterpret_cast<const paddle::platform::float16 *>(dout_data);
+        paddle::platform::float16 *ptr2 =
+            reinterpret_cast<paddle::platform::float16 *>(out_data);
+        if (height <= 32) {
+          FP16MatrixColReduce<32, 32><<<grids, blocks, 0, stream>>>(
+              ptr1, ptr2, width, height);
+        } else {
+          FP16MatrixColReduce<32, 64><<<grids, blocks, 0, stream>>>(
+              ptr1, ptr2, width, height);
+        }
+        return;
+      }
+      MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
+          dout_data, out_data, width, height);
+      return;
+    }
+
+#endif
+#endif
     // Special case when dy is not needed and dx doesn't reduce
     if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
       VLOG(4) << "Special case when dy is not needed and dx doesn't "
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index c941d7c5f34..49c2467c9ff 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -351,6 +351,16 @@ class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
         self.axis = -1
 
 
+class TestElementwiseFP16AddOp_commonuse_add1(TestFP16ElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(20, 30, 100).astype(self.dtype)
+        self.y = np.random.rand(1, 1, 100).astype(self.dtype)
+        self.out = self.x + self.y
+
+    def init_axis(self):
+        self.axis = -1
+
+
 class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
@@ -429,4 +439,5 @@ class TestAddOp(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
-- 
GitLab