From 3fc4fa299eb6c237e2e8e9bb0a1b565f46d9b149 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Wed, 28 Sep 2022 19:46:54 +0800 Subject: [PATCH] [New AD] Fix p_norm n=1 issue (#46514) * fix p_norm n=1 issue * fix p norm test error --- .../paddle/fluid/tests/unittests/autograd/test_orig2prim.py | 4 +--- python/paddle/incubate/autograd/primrules.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 275e5f1bee8..fe7b37941b4 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -462,9 +462,7 @@ class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim): } self.orig2prim_args = (X, ) - self.all_ops = [ - 'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p' - ] + self.all_ops = ['p_norm', 'reshape_p', 'abs_p', 'reduce_sum_p'] self.out_map = {0: self.output['Out']} diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 73058912761..3be0816864f 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -344,7 +344,7 @@ def p_norm_orig2prim(op, x): if abs(op.attr('porder') - 2.0) < 1e-5: return sqrt(reduce_sum(mul(x, x), axis=[0])) elif abs(op.attr('porder') - 1.0) < 1e-5: - return reduce_sum(sqrt(mul(x, x)), axis=[0]) + return reduce_sum(primops.abs(x), axis=[0]) else: raise RuntimeError('Only support lower l2/l1 norm currently') -- GitLab