未验证 提交 3fc4fa29 编写于 作者: J Jiabin Yang 提交者: GitHub

[New AD] Fix p_norm n=1 issue (#46514)

* fix p_norm n=1 issue

* fix p norm test error
上级 9c01eaed
......@@ -462,9 +462,7 @@ class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
}
self.orig2prim_args = (X, )
self.all_ops = [
'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p'
]
self.all_ops = ['p_norm', 'reshape_p', 'abs_p', 'reduce_sum_p']
self.out_map = {0: self.output['Out']}
......
......@@ -344,7 +344,7 @@ def p_norm_orig2prim(op, x):
if abs(op.attr('porder') - 2.0) < 1e-5:
return sqrt(reduce_sum(mul(x, x), axis=[0]))
elif abs(op.attr('porder') - 1.0) < 1e-5:
return reduce_sum(sqrt(mul(x, x)), axis=[0])
return reduce_sum(primops.abs(x), axis=[0])
else:
raise RuntimeError('Only support lower l2/l1 norm currently')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册