From 3825b40ff3d4ac636c09cdff9747b1c63666e1d9 Mon Sep 17 00:00:00 2001 From: Noel Date: Tue, 25 Jan 2022 11:54:22 +0800 Subject: [PATCH] [pnorm] fix bug in fp16 & optimize memory (#39011) --- paddle/fluid/operators/p_norm_op.cu | 92 ++++++------------- .../fluid/operators/reduce_ops/logsumexp_op.h | 9 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 29 +++--- .../operators/reduce_ops/reduce_op_function.h | 3 +- .../fluid/tests/unittests/test_norm_all.py | 87 +++++++++++++++--- 5 files changed, 123 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index 88e94ba039a..c0bd906685d 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -76,22 +76,13 @@ struct AbsFunctor { } }; -template +template struct UnsignedPowFunctor { HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { this->porder = porder; } - HOSTDEVICE inline Ty operator()(const Tx x) const { - return static_cast(inline_pow(inline_abs(x), static_cast(porder))); - } - float porder; -}; - -template -struct PowFunctor { - HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; } - HOSTDEVICE inline Ty operator()(const Tx x) const { - return static_cast(inline_pow(x, static_cast(porder))); + HOSTDEVICE inline T operator()(const T x) const { + return static_cast(inline_pow(inline_abs(x), static_cast(porder))); } float porder; }; @@ -105,13 +96,11 @@ class PnormCUDAKernel : public framework::OpKernel { const T* x = in_x->data(); T* norm = out_norm->mutable_data(ctx.GetPlace()); auto xdim = in_x->dims(); - auto ndim = out_norm->dims(); float porder = ctx.Attr("porder"); bool asvector = ctx.Attr("asvector"); int axis = ctx.Attr("axis"); std::vector reduce_axis = {axis}; reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector); - auto stream = ctx.cuda_device_context().stream(); using MT = typename details::MPTypeTrait::Type; @@ -125,29 +114,17 @@ class PnormCUDAKernel : public framework::OpKernel { TensorReduceFunctorImpl>( *in_x, out_norm, AbsFunctor(), reduce_axis, stream); } else { - framework::Tensor tmp_x; - tmp_x.mutable_data(xdim, ctx.GetPlace()); - std::vector ins = {in_x}; - std::vector outs = {&tmp_x}; - auto func = UnsignedPowFunctor(porder); + TensorReduceFunctorImpl>( + *in_x, out_norm, UnsignedPowFunctor(porder), reduce_axis, stream); + + const framework::Tensor* tmp_norm = out_norm; + std::vector ins = {tmp_norm}; + std::vector outs = {out_norm}; const auto& cuda_ctx = ctx.template device_context(); - - paddle::operators::LaunchSameDimsElementwiseCudaKernel< - ElementwiseType::kUnary, MT, T, UnsignedPowFunctor>( - cuda_ctx, ins, &outs, func); - framework::Tensor tmp_y; - tmp_y.mutable_data(ndim, ctx.GetPlace()); - TensorReduceFunctorImpl>( - tmp_x, &tmp_y, kps::IdentityFunctor(), reduce_axis, stream); - const framework::Tensor* tmp_norm = &tmp_y; - ins = {tmp_norm}; - outs = {out_norm}; - auto func_inverse = UnsignedPowFunctor(1. / porder); - paddle::operators::LaunchSameDimsElementwiseCudaKernel< - ElementwiseType::kUnary, MT, T, UnsignedPowFunctor>( - cuda_ctx, ins, &outs, func_inverse); + ElementwiseType::kUnary, T, T, UnsignedPowFunctor>( + cuda_ctx, ins, &outs, UnsignedPowFunctor(1. 
/ porder)); } } }; @@ -158,29 +135,25 @@ struct AbsMaxAndMinGradFunctor { typename DY, typename Dim> void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - auto equals = ((*x).abs() == y->broadcast(dim)); - auto ones = dx->constant(static_cast(1.)); - auto negs = dx->constant(static_cast(-1.)); - auto zeros = dx->constant(static_cast(0.)); - auto positives = (*x) > zeros; - dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) * - positives.select(ones, negs); + dx->device(place) = dy->broadcast(dim) * (*x).sign() * + ((*x).abs() == y->broadcast(dim)).template cast(); } }; template -struct PNormPostGradFunctor { +struct PNormGradFunctor { + HOSTDEVICE explicit inline PNormGradFunctor(float porder) { + this->porder = static_cast(porder - 1.); + } template void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - auto ones = dx->constant(static_cast(1.)); - auto negs = dx->constant(static_cast(-1.)); - auto zeros = dx->constant(static_cast(0.)); - auto positives = (*x) > zeros; - dx->device(place) = (*dx) * dy->broadcast(dim) * y->broadcast(dim) * - positives.select(ones, negs); + dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() * + dy->broadcast(dim) * + (*y).pow(-this->porder).broadcast(dim); } + T porder; }; template @@ -207,26 +180,13 @@ class PnormGradCUDAKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(cuda_ctx, out_dx, static_cast(0)); } else if (porder == INFINITY || porder == -INFINITY) { + AbsMaxAndMinGradFunctor functor; LaunchReduceGradKernel>( - ctx, in_x, in_norm, in_norm_dy, out_dx, dims, reduce_all); + ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); } else { - framework::Tensor tmp_norm; - tmp_norm.mutable_data(in_norm->dims(), ctx.GetPlace()); - std::vector ins = {in_norm}; - std::vector outs = {&tmp_norm}; - auto pow_functor = PowFunctor(1. 
- porder); - paddle::operators::LaunchSameDimsElementwiseCudaKernel< - ElementwiseType::kUnary, T, T, PowFunctor>(cuda_ctx, ins, &outs, - pow_functor); - ins = {in_x}; - outs = {out_dx}; - auto unsigned_pow = UnsignedPowFunctor(porder - 1.); - paddle::operators::LaunchSameDimsElementwiseCudaKernel< - ElementwiseType::kUnary, T, T, UnsignedPowFunctor>( - cuda_ctx, ins, &outs, unsigned_pow); - const framework::Tensor* tmp_norm_const = &tmp_norm; - LaunchReduceGradKernel>( - ctx, in_x, tmp_norm_const, in_norm_dy, out_dx, dims, reduce_all); + auto functor = PNormGradFunctor(porder); + LaunchReduceGradKernel>( + ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); } } }; diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.h b/paddle/fluid/operators/reduce_ops/logsumexp_op.h index 06c9f23dd2c..4490f08b212 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.h +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.h @@ -139,26 +139,27 @@ class LogsumexpGradKernel : public framework::OpKernel { broadcast_dim[0]); } else { int rank = input->dims().size(); + LogsumexpGradFunctor functor; switch (rank) { case 1: ReduceGradFunctor( context.template device_context(), *input, *output, - *output_grad, input_grad, axis); + *output_grad, input_grad, functor, axis); break; case 2: ReduceGradFunctor( context.template device_context(), *input, *output, - *output_grad, input_grad, axis); + *output_grad, input_grad, functor, axis); break; case 3: ReduceGradFunctor( context.template device_context(), *input, *output, - *output_grad, input_grad, axis); + *output_grad, input_grad, functor, axis); break; case 4: ReduceGradFunctor( context.template device_context(), *input, *output, - *output_grad, input_grad, axis); + *output_grad, input_grad, functor, axis); break; } } diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 2e5bd7a42b1..661fb772f1c 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -143,7 +143,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context, const framework::Tensor* x, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, - const std::vector& dims) { + Functor functor, const std::vector& dims) { const int64_t unreduced = out->numel(); const int64_t reduced = x->numel() / unreduced; DDim out_dim(out->dims()); @@ -157,7 +157,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context, dx->Resize({unreduced, reduced}); ReduceGradFunctor( context.template device_context(), shuffled_x, *out, *dout, - dx, {1}); + dx, functor, {1}); // transpose dX std::vector origin_axis(x_dim.size()); GetOriginDimFromShuffled(x_dim, dims, &origin_axis); @@ -333,7 +333,7 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context, const framework::Tensor* input0, const framework::Tensor* input1, const framework::Tensor* input2, - paddle::framework::Tensor* output, + paddle::framework::Tensor* output, Functor functor, const std::vector& dims, bool reduce_all = false) { if (reduce_all) { @@ -345,7 +345,6 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context, *context.template device_context().eigen_device(); auto broadcast_dim = Eigen::array({{static_cast(input0->numel())}}); - Functor functor; functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, broadcast_dim[0]); } else { @@ -354,36 +353,36 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& 
context, case 1: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; case 2: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; case 3: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; case 4: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; case 5: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; case 6: ReduceGradFunctor( context.template device_context(), *input0, *input1, - *input2, output, dims); + *input2, output, functor, dims); break; default: - HandleLargeDimGrad(context, input0, input1, - input2, output, dims); + HandleLargeDimGrad( + context, input0, input1, input2, output, functor, dims); break; } } @@ -430,8 +429,10 @@ class ReduceGradKernel : public framework::OpKernel { // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and // not be set as Input in grad Maker, use Out_grad to replace here if (!input1) input1 = input2; - LaunchReduceGradKernel( - context, input0, input1, input2, output, const_dims, reduce_all); + Functor functor; + LaunchReduceGradKernel(context, input0, input1, + input2, output, functor, + const_dims, reduce_all); } void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/reduce_ops/reduce_op_function.h b/paddle/fluid/operators/reduce_ops/reduce_op_function.h index 3da27bc8ac8..1f3839c8dc7 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op_function.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op_function.h @@ -74,7 +74,7 @@ void ReduceGradFunctor(const DeviceContext& context, const framework::Tensor& input0, const framework::Tensor& input1, const framework::Tensor& input2, - framework::Tensor* output, + framework::Tensor* output, Functor functor, const std::vector& dims) { auto x = EigenTensor::From(input0); auto x_grad = EigenTensor::From(*output); @@ -100,7 +100,6 @@ void ReduceGradFunctor(const DeviceContext& context, auto& place = *context.eigen_device(); - Functor functor; functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, broad_cats_times); } diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index 352089e1fb7..b20305b78ef 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -19,11 +19,12 @@ import numpy as np from op_test import OpTest import paddle import paddle.fluid as fluid +import paddle.fluid.core as core -def p_norm(x, axis, porder, keepdims=False): +def p_norm(x, axis, porder, keepdims=False, reduce_all=False): r = [] - if axis is None: + if axis is None or reduce_all: x = x.flatten() if porder == np.inf: r = np.amax(np.abs(x), keepdims=keepdims) @@ -53,8 +54,8 @@ def p_norm(x, axis, porder, keepdims=False): else: if isinstance(axis, list): axis = tuple(axis) - r = np.linalg.norm( - x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype) + r = np.linalg.norm(x, ord=porder, axis=axis, keepdims=keepdims) + r = r.astype(x.dtype) return r @@ -111,13 +112,14 @@ class TestPnormOp(OpTest): self.op_type = "p_norm" self.init_test_case() x = (np.random.random(self.shape) 
+ 0.5).astype(self.dtype) - norm = p_norm(x, self.axis, self.porder, self.keepdim) + norm = p_norm(x, self.axis, self.porder, self.keepdim, self.asvector) self.inputs = {'X': x} self.attrs = { 'epsilon': self.epsilon, 'axis': self.axis, 'keepdim': self.keepdim, - 'porder': float(self.porder) + 'porder': float(self.porder), + 'asvector': self.asvector } self.outputs = {'Out': norm} self.gradient = self.calc_gradient() @@ -135,34 +137,42 @@ class TestPnormOp(OpTest): self.porder = 2.0 self.keepdim = False self.dtype = "float64" + self.asvector = False def calc_gradient(self): self.attrs = { 'epsilon': self.epsilon, 'axis': self.axis, 'keepdim': self.keepdim, - 'porder': float(self.porder) + 'porder': float(self.porder), + 'asvector': self.asvector } x = self.inputs["X"] porder = self.attrs["porder"] axis = self.attrs["axis"] + asvector = self.attrs["asvector"] + x_dtype = x.dtype + x = x.astype(np.float32) if x.dtype == np.float16 else x if porder == 0: grad = np.zeros(x.shape).astype(x.dtype) elif porder in [float("inf"), float("-inf")]: - norm = p_norm(x, axis=axis, porder=porder, keepdims=True) + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) x_abs = np.abs(x) grad = np.sign(x) grad[x_abs != norm] = 0.0 else: - norm = p_norm(x, axis=axis, porder=porder, keepdims=True) + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) grad = np.power(norm, 1 - porder) * np.power( np.abs(x), porder - 1) * np.sign(x) numel = 1 for s in x.shape: numel *= s - numel /= x.shape[axis] - return [grad.astype(x.dtype) * 1 / numel] + divisor = numel if asvector else x.shape[axis] + numel /= divisor + return [grad.astype(x_dtype) * 1 / numel] class TestPnormOp2(TestPnormOp): @@ -173,6 +183,7 @@ class TestPnormOp2(TestPnormOp): self.porder = 2.0 self.keepdim = True self.dtype = "float32" + self.asvector = False def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -186,6 +197,7 @@ class TestPnormOp3(TestPnormOp): self.porder = np.inf self.keepdim = True self.dtype = "float32" + self.asvector = False def test_check_grad(self): self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) @@ -199,6 +211,7 @@ class TestPnormOp4(TestPnormOp): self.porder = -np.inf self.keepdim = True self.dtype = "float32" + self.asvector = False def test_check_grad(self): self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) @@ -212,11 +225,63 @@ class TestPnormOp5(TestPnormOp): self.porder = 0 self.keepdim = True self.dtype = "float32" + self.asvector = False def test_check_grad(self): self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) +class TestPnormOp6(TestPnormOp): + def init_test_case(self): + self.shape = [3, 20, 3] + self.axis = -1 + self.epsilon = 1e-12 + self.porder = 2 + self.keepdim = False + self.dtype = "float32" + self.asvector = True + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestPnormOpFP16(TestPnormOp): + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = 1 + self.epsilon = 1e-12 + self.porder = 2.0 + self.keepdim = False + self.dtype = "float16" + self.asvector = False + + def test_check_output(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, 
['X'], 'Out', user_defined_grads=self.gradient)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestPnormOpFP161(TestPnormOpFP16):
+    def init_test_case(self):
+        self.shape = [2, 3, 4, 5]
+        self.axis = -1
+        self.epsilon = 1e-12
+        self.porder = 2.0
+        self.keepdim = False
+        self.dtype = "float16"
+        self.asvector = True
+
+
 def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
     with fluid.program_guard(fluid.Program()):
         data = fluid.data(name="X", shape=shape_x, dtype=dtype)
-- 
GitLab
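
The reworked forward path in p_norm_op.cu reduces |x|**porder straight into out_norm and then raises that reduced tensor to 1/porder in place, rather than first materialising a full-size temporary holding |x|**porder. A minimal numpy sketch of that data flow, for illustration only (the function name is mine, not the CUDA kernel):

import numpy as np


def p_norm_forward_sketch(x, porder, axis):
    # Step 1: one reduction produces sum(|x|**p) at the reduced shape,
    # so no temporary the size of x is needed.
    powered_sum = np.sum(np.abs(x) ** porder, axis=axis)
    # Step 2: the follow-up 1/p power only touches the small reduced tensor.
    return powered_sum ** (1.0 / porder)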
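
The new PNormGradFunctor and the simplified AbsMaxAndMinGradFunctor encode the standard p-norm gradients, and the same formulas appear in calc_gradient in test_norm_all.py. A hedged numpy restatement (helper names are mine, not Paddle API):

import numpy as np


def p_norm_grad_sketch(x, norm, dnorm, porder, axis):
    # dL/dx = dL/dnorm * sign(x) * |x|**(p-1) * norm**(1-p)
    norm_b = np.expand_dims(norm, axis)    # broadcast reduced tensors back over `axis`
    dnorm_b = np.expand_dims(dnorm, axis)
    return np.abs(x) ** (porder - 1) * np.sign(x) * dnorm_b * norm_b ** (1 - porder)


def inf_norm_grad_sketch(x, norm, dnorm, axis):
    # Only entries attaining the max (or min) absolute value receive gradient.
    norm_b = np.expand_dims(norm, axis)
    dnorm_b = np.expand_dims(dnorm, axis)
    return np.sign(x) * (np.abs(x) == norm_b).astype(x.dtype) * dnorm_b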
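
The reduce_op.h, reduce_op_function.h and logsumexp_op.h changes all make the same move: ReduceGradFunctor and LaunchReduceGradKernel now take a caller-constructed functor instead of default-constructing one internally, which is what allows PNormGradFunctor to carry porder as state. A rough Python analogue of the design change, purely illustrative and not Paddle code:

# Before: the helper instantiated the functor itself, so the functor had to be
# default-constructible and therefore stateless.
def reduce_grad_old(x, norm, dnorm, functor_cls):
    return functor_cls()(x, norm, dnorm)


# After: the caller builds the functor, so it can hold parameters such as porder.
def reduce_grad_new(x, norm, dnorm, functor):
    return functor(x, norm, dnorm)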
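
On the fp16 side, calc_gradient now upcasts float16 inputs to float32 before building the reference gradient, and the CUDA kernel keeps its MPType accumulator. The snippet below only illustrates the general hazard that motivates such upcasting, a float16 accumulator overflowing where a float32 one stays finite; it is an assumption for illustration and does not claim to reproduce the exact bug this patch fixes.

import numpy as np

x = np.full(4096, 8.0, dtype=np.float16)
# Accumulating |x|**2 in float16 overflows: 4096 * 64 = 262144 > 65504 (float16 max).
naive = np.sqrt(np.sum(x ** 2, dtype=np.float16))                   # inf
# Accumulating in float32 and casting back at the end stays representable (~512).
wide = np.sqrt(np.sum(x.astype(np.float32) ** 2)).astype(np.float16)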
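
The test helper p_norm also gains a reduce_all flag mirroring the op's asvector attribute: when it is set, the input is flattened and treated as a single vector, so axis no longer selects a dimension. A small numpy illustration of the two behaviours:

import numpy as np

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
per_axis = np.linalg.norm(x, ord=2, axis=1)       # shape (2, 4): 2-norm along axis 1
as_vector = np.linalg.norm(x.flatten(), ord=2)    # scalar: 2-norm over all 24 entries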