From a39eba77eb44f7b56159b5ae71bb37760f955a83 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Mon, 10 Sep 2018 03:28:43 -0500
Subject: [PATCH] Implement norm_op by CUDA instead of Eigen. (#13273)

* Implement norm_op by CUDA instead of Eigen.

* Remove the commented code.
---
 paddle/fluid/operators/norm_op.cu          | 149 +++++++++++++++++-
 paddle/fluid/operators/norm_op.h           |   5 +-
 .../fluid/tests/unittests/test_norm_op.py  |  22 +++
 3 files changed, 169 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/norm_op.cu b/paddle/fluid/operators/norm_op.cu
index 1d0021d33f..67449aa4c6 100644
--- a/paddle/fluid/operators/norm_op.cu
+++ b/paddle/fluid/operators/norm_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -11,14 +11,151 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
+#include <algorithm>
+#include "cub/cub.cuh"
 #include "paddle/fluid/operators/norm_op.h"
 
+namespace paddle {
+namespace operators {
+
+__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
+
+__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
+
+template <typename T, int BlockDim>
+__global__ void Normalize(const T* x, const int pre,
+                          const int axis_n,  // dim in axis
+                          const int post, const T eps, T* y, T* out_norm) {
+  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  int num = pre * post;
+  for (int i = blockIdx.x; i < num; i += gridDim.x) {
+    int base = (i / post) * post * axis_n + (i % post);
+
+    T sum = 0.0;
+    __shared__ T norm;
+    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
+      const T x_ij = x[base + j * post];
+      sum += x_ij * x_ij;
+    }
+    T reduce_result = BlockReduce(temp_storage).Sum(sum);
+
+    if (threadIdx.x == 0) {
+      norm = square_root(reduce_result + eps);
+      out_norm[i] = norm;
+    }
+    __syncthreads();
+    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
+      const int index = base + j * post;
+      y[index] = x[index] / norm;
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+class NormCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_x = ctx.Input<framework::Tensor>("X");
+    auto* out_y = ctx.Output<framework::Tensor>("Out");
+    auto* out_norm = ctx.Output<framework::Tensor>("Norm");
+    const T* x = in_x->data<T>();
+    T* y = out_y->mutable_data<T>(ctx.GetPlace());
+    T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
+
+    auto xdim = in_x->dims();
+    auto ndim = out_norm->dims();
+    int axis = ctx.Attr<int>("axis");
+    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
+    if (axis < 0) axis = xdim.size() + axis;
+    int pre, n, post;
+    GetDims(xdim, axis, &pre, &n, &post);
+
+    auto& dev_ctx = ctx.cuda_device_context();
+
+    const int block = 512;
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int max_blocks = std::max(max_threads / block, 1);
+    int grid = std::min(max_blocks, pre * post);
+    Normalize<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
+                                                              eps, y, norm);
+  }
+};
+
+template <typename T, int BlockDim>
+__global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
+                                  const int pre, const int axis_n,
+                                  const int post, T* x_grad) {
+  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage_sum;
+  int num = pre * post;
+  for (int i = blockIdx.x; i < num; i += gridDim.x) {
+    T sum = 0.0;
+    __shared__ T row_sum;
+    __shared__ T row_sqrt_norm;
+    __shared__ T row_norm;
+
+    auto base = (i / post) * post * axis_n + (i % post);
+
+    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
+      int index = base + j * post;
+      sum += x[index] * y_grad[index];
+    }
+    T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
+
+    if (threadIdx.x == 0) {
+      row_sum = reduce_result;
+      row_sqrt_norm = x_norm[i];
+      row_norm = row_sqrt_norm * row_sqrt_norm;
+    }
+    __syncthreads();
+    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
+      int index = base + j * post;
+      const T x_ij = x[index];
+      const T dy_ij = y_grad[index];
+      x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+class NormGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_x = ctx.Input<framework::Tensor>("X");
+    auto* in_norm = ctx.Input<framework::Tensor>("Norm");
+    auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
+    const T* x = in_x->data<T>();
+    const T* x_norm = in_norm->data<T>();
+    const T* dy = in_dy->data<T>();
+
+    auto xdim = in_x->dims();
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis = xdim.size() + axis;
+    int pre, n, post;
+    GetDims(xdim, axis, &pre, &n, &post);
+
+    auto& dev_ctx = ctx.cuda_device_context();
+
+    const int block = 512;
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int max_blocks = std::max(max_threads / block, 1);
+    int grid = std::min(max_blocks, pre * post);
+    NormalizeGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
+        x, x_norm, dy, pre, n, post, dx);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 
-REGISTER_OP_CUDA_KERNEL(norm, ops::NormKernel<CUDA, float>,
-                        ops::NormKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradKernel<CUDA, float>,
-                        ops::NormGradKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(norm, ops::NormCUDAKernel<CUDA, float>,
+                        ops::NormCUDAKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradCUDAKernel<CUDA, float>,
+                        ops::NormGradCUDAKernel<CUDA, double>);
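
For reference, the math the two kernels above implement can be sketched in a few
lines of NumPy. This sketch is not part of the patch, and the helper names
norm_forward / norm_backward are invented here. GetDims factors the input shape
into (pre, n, post) around axis, so x is treated as pre * post independent rows
of length n; each CUDA block walks one row with stride post (starting at
base = (i / post) * post * axis_n + (i % post)) and reduces it with
cub::BlockReduce, whereas NumPy expresses the same reduction with axis/keepdims:

    import numpy as np

    def norm_forward(x, axis, eps):
        # Norm = sqrt(sum(x * x, axis) + eps);  Out = x / Norm  (kernel: Normalize)
        norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + eps)
        return x / norm, norm

    def norm_backward(x, norm, dy, axis):
        # row_sum = sum(x * dy, axis); norm * norm recovers sum(x * x) + eps
        row_sum = np.sum(x * dy, axis=axis, keepdims=True)
        # dX = (dY - x * row_sum / norm^2) / norm  (kernel: NormalizeGradient)
        return (dy - x * row_sum / (norm * norm)) / norm

Note that the forward pass stores Norm = sqrt(sum(x * x) + eps), which is why
NormalizeGradient can recover norm^2 as row_sqrt_norm * row_sqrt_norm instead of
re-reducing x.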
diff --git a/paddle/fluid/operators/norm_op.h b/paddle/fluid/operators/norm_op.h
index 3167bdc8ac..d0224177ec 100644
--- a/paddle/fluid/operators/norm_op.h
+++ b/paddle/fluid/operators/norm_op.h
@@ -65,14 +65,17 @@ class NormKernel : public framework::OpKernel<T> {
     Eigen::DSizes<int, 1> rdim(1);
     // y = x / sqrt((sum(x * x) + epsilon))
     // norm = sqrt(sum(x * x) + epsilon)
-    auto sum = x.pow(2).sum(rdim) + eps;
+    auto x2 = x * x;
+    auto sum = x2.sum(rdim) + eps;
     norm.device(*place) = sum.sqrt();
+
     // y = x / norm
     Eigen::DSizes<int, 3> rshape(pre, 1, post);
     Eigen::DSizes<int, 3> bcast(1, n, 1);
     y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
   }
 };
+
 template <typename DeviceContext, typename T>
 class NormGradKernel : public framework::OpKernel<T> {
  public:
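
The large-shape test cases added below skip check_grad because finite
differencing over millions of elements is slow, not because the formula is in
doubt. For a cheap sanity check, the analytic gradient can be compared against
central differences on a small tensor; a self-contained NumPy sketch (again,
not part of the patch):

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.rand(2, 3, 4)
    dy = rng.rand(2, 3, 4)
    axis, eps, delta = 1, 1e-8, 1e-6

    def forward(x):
        return x / np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + eps)

    # Analytic gradient, exactly the formula in NormalizeGradient above.
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + eps)
    row_sum = np.sum(x * dy, axis=axis, keepdims=True)
    dx = (dy - x * row_sum / (norm * norm)) / norm

    # Numerical gradient of the scalar loss sum(forward(x) * dy).
    num = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += delta
        xm[idx] -= delta
        num[idx] = np.sum((forward(xp) - forward(xm)) * dy) / (2 * delta)

    assert np.allclose(dx, num, atol=1e-5)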
diff --git a/python/paddle/fluid/tests/unittests/test_norm_op.py b/python/paddle/fluid/tests/unittests/test_norm_op.py
index 22bc45ff1e..a424260312 100644
--- a/python/paddle/fluid/tests/unittests/test_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_norm_op.py
@@ -63,5 +63,27 @@ class TestNormOp3(TestNormOp):
         self.epsilon = 1e-8
 
 
+class TestNormOp4(TestNormOp):
+    def init_test_case(self):
+        self.shape = [128, 1024, 14, 14]
+        self.axis = 2
+        self.epsilon = 1e-8
+
+    def test_check_grad(self):
+        # The gradient check is very slow for this large shape, so skip it.
+        pass
+
+
+class TestNormOp5(TestNormOp):
+    def init_test_case(self):
+        self.shape = [2048, 2048]
+        self.axis = 1
+        self.epsilon = 1e-8
+
+    def test_check_grad(self):
+        # The gradient check is very slow for this large shape, so skip it.
+        pass
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab