未验证 提交 4558d395 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

fix Norm op error (#26771)


* fix frobenius_norm error, rm p=0 2-axis support. test=develop
上级 4d7d6612
......@@ -105,6 +105,12 @@ class PnormOp : public framework::OperatorWithKernel {
bool asvector = ctx->Attrs().Get<bool>("asvector");
if (asvector) {
reduce_dims.emplace_back(1);
if (keepdim) {
for (int i = 1; i < x_dim.size(); ++i) {
reduce_dims.emplace_back(1);
}
x_dim = framework::make_ddim(reduce_dims);
}
} else {
if (axis < 0) axis = x_dim.size() + axis;
for (int i = 0; i < x_dim.size(); ++i) {
......
......@@ -26,11 +26,11 @@ def p_norm(x, axis, porder, keepdims=False):
if axis is None:
x = x.flatten()
if porder == np.inf:
r = np.amax(np.abs(x))
r = np.amax(np.abs(x), keepdims=keepdims)
elif porder == -np.inf:
r = np.amin(np.abs(x))
r = np.amin(np.abs(x), keepdims=keepdims)
else:
r = np.linalg.norm(x, ord=porder)
r = np.linalg.norm(x, ord=porder, keepdims=keepdims)
elif isinstance(axis, list or tuple) and len(axis) == 2:
if porder == np.inf:
axis = tuple(axis)
......@@ -41,10 +41,10 @@ def p_norm(x, axis, porder, keepdims=False):
elif porder == 0:
axis = tuple(axis)
r = x.astype(bool)
r = np.sum(r, axis)
r = np.sum(r, axis, keepdims=keepdims)
elif porder == 1:
axis = tuple(axis)
r = np.sum(np.abs(x), axis)
r = np.sum(np.abs(x), axis, keepdims=keepdims)
else:
axis = tuple(axis)
xp = np.power(np.abs(x), porder)
......@@ -61,7 +61,7 @@ def p_norm(x, axis, porder, keepdims=False):
def frobenius_norm(x, axis=None, keepdims=False):
if isinstance(axis, list): axis = tuple(axis)
if axis is None: axis = (-2, -1)
if axis is None: x = x.reshape(1, x.size)
r = np.linalg.norm(
x, ord='fro', axis=axis, keepdims=keepdims).astype(x.dtype)
return r
......@@ -217,28 +217,37 @@ class TestPnormOp5(TestPnormOp):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
def run_fro(self, p, axis, shape_x, dtype):
def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(x=data, p=p, axis=axis)
out = paddle.norm(x=data, p=p, axis=axis, keepdim=keep_dim)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
expected_result = frobenius_norm(np_input, axis=axis)
expected_result = frobenius_norm(np_input, axis=axis, keepdims=keep_dim)
result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
if keep_dim and check_dim:
self.assertEqual(
(np.abs(np.array(result.shape) - np.array(expected_result.shape)) <
1e-6).all(), True)
def run_pnorm(self, p, axis, shape_x, dtype):
def run_pnorm(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(x=data, p=p, axis=axis)
out = paddle.norm(x=data, p=p, axis=axis, keepdim=keep_dim)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype)
expected_result = p_norm(
np_input, porder=p, axis=axis, keepdims=keep_dim).astype(dtype)
result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
if keep_dim and check_dim:
self.assertEqual(
(np.abs(np.array(result.shape) - np.array(expected_result.shape)) <
1e-6).all(), True)
def run_graph(self, p, axis, shape_x, dtype):
......@@ -253,6 +262,7 @@ def run_graph(self, p, axis, shape_x, dtype):
# compute frobenius norm along last two dimensions.
out_fro = paddle.norm(x, p='fro')
out_fro = paddle.norm(x, p='fro', axis=0)
out_fro = paddle.norm(x, p='fro', axis=[0, 1])
# compute 2-order norm along [0,1] dimension.
out_pnorm = paddle.norm(x, p=2, axis=[0, 1])
......@@ -274,27 +284,133 @@ def run_graph(self, p, axis, shape_x, dtype):
class API_NormTest(unittest.TestCase):
def test_basic(self):
run_fro(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
run_fro(self, p='fro', axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=np.inf, axis=0, shape_x=[2, 3, 4], dtype="float32")
run_pnorm(self, p=np.inf, axis=None, shape_x=[2, 3, 4], dtype="float32")
run_pnorm(self, p=-np.inf, axis=0, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=-np.inf, axis=None, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=1, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=0, axis=None, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=2, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=2, axis=-1, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=1, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=0, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=-np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
keep_dims = {False, True}
for keep in keep_dims:
run_fro(
self,
p='fro',
axis=None,
shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep)
run_fro(
self,
p='fro',
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=2,
axis=None,
shape_x=[3, 4],
dtype="float32",
keep_dim=keep)
run_pnorm(
self,
p=2,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=np.inf,
axis=0,
shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=np.inf,
axis=None,
shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep)
run_pnorm(
self,
p=-np.inf,
axis=0,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=-np.inf,
axis=None,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep)
run_pnorm(
self,
p=0,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=1,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=0,
axis=None,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=2,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=2,
axis=-1,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=1,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=np.inf,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=-np.inf,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
def test_dygraph(self):
run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
......@@ -315,6 +431,7 @@ class API_NormTest(unittest.TestCase):
paddle.norm(data, p=p, out=out)
self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "int64")
self.assertRaises(ValueError, paddle.norm, "inf", [2], "int64")
out = fluid.data(name="out", shape=[1], dtype="int64")
self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "float64",
out)
......@@ -325,6 +442,7 @@ class API_NormTest(unittest.TestCase):
self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
self.assertRaises(ValueError, paddle.norm, data, p=[1])
self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
self.assertRaises(ValueError, paddle.norm, 0, [1, 0], "float64")
data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
self.assertRaises(
ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])
......
......@@ -183,12 +183,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
x (Tensor): The input tensor could be N-D tensor, and the input data
type could be float32 or float64.
p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
`inf`,`-inf` and any positive real number yielding the corresponding p-norm.
Not supported: ord < 0, nuclear norm.
`inf`, `-inf` and any positive real number yielding the corresponding p-norm. Not supported: ord < 0 and nuclear norm.
Default value is `fro`.
axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
If `axis < 0`, the dimension to norm operation is rank(input) + axis.
If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis.
Defalut value is `None`.
keepdim (bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have fewer dimension
than the :attr:`input` unless :attr:`keepdim` is true, default
......@@ -197,13 +198,9 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor, results of norm operation on the specified axis of input tensor,
Tensor: results of norm operation on the specified axis of input tensor,
it's data type is the same as input's Tensor.
Raises:
TypeError, if out data type is different with the input data type.
ValueError, If `p` or `axis` is invalid.
Examples:
.. code-block:: python
......@@ -256,15 +253,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
"The dim of frobenius norm op should be None or two elements list!"
)
if in_dygraph_mode():
if dim is None: dim = [-1]
return core.ops.frobenius_norm(input, 'dim', dim, 'keepdim',
keepdim)
attrs = {
'dim': dim if dim != None else [-2, -1],
'keep_dim': keepdim,
'reduce_all': False
}
if len(attrs['dim']) == len(input.shape):
if dim is None:
return core.ops.frobenius_norm(input, 'keep_dim', keepdim,
'reduce_all', True)
return core.ops.frobenius_norm(input, 'dim', dim, 'keep_dim',
keepdim, 'reduce_all', False)
attrs = {'dim': dim, 'keep_dim': keepdim, 'reduce_all': False}
if dim is None:
attrs['reduce_all'] = True
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'frobenius_norm')
......@@ -351,42 +346,6 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
return reduce_out
def p0_matrix_norm(input, porder=0., axis=axis, keepdim=False, name=None):
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
cast_out = block.create_variable_for_type_inference(dtype=bool)
block.append_op(
type='cast',
inputs={'X': input},
outputs={'Out': cast_out},
attrs={
'in_dtype': input.dtype,
'out_dtype': int(core.VarDesc.VarType.BOOL)
})
cast_out2 = block.create_variable_for_type_inference(dtype=bool)
block.append_op(
type='cast',
inputs={'X': cast_out},
outputs={'Out': cast_out2},
attrs={
'in_dtype': cast_out.dtype,
'out_dtype': int(core.VarDesc.VarType.FP32)
})
sum_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
block.append_op(
type='reduce_sum',
inputs={'X': cast_out2},
outputs={'Out': sum_out},
attrs={
'dim': axis,
'keep_dim': keepdim,
'reduce_all': True if axis is None else False
})
return sum_out
def p_matrix_norm(input, porder=1., axis=axis, keepdim=False, name=None):
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
......@@ -448,7 +407,20 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
#calculate vector norm, where axis is int or list with only one integer
if isinstance(axis, int):
if isinstance(p, (int, float)):
if isinstance(p, str):
if p == "fro":
return vector_norm(
x,
porder=2,
axis=axis,
keepdim=keepdim,
asvector=False,
name=name)
else:
raise ValueError(
"only valid string values are 'fro', found {}".format(p))
elif isinstance(p, (int, float)):
return vector_norm(
x,
axis=axis,
......@@ -464,10 +436,12 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
elif isinstance(axis, list) and len(axis) == 2:
if p == "fro":
return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
elif p == 0:
return p0_matrix_norm(x, axis=axis, keepdim=keepdim, name=name)
elif p == np.inf or p == -np.inf:
return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name)
elif p == 0:
raise ValueError(
"just suport axis type int or list (length of list <=1) if p = 0, found {}".
format(axis))
else:
return p_matrix_norm(
x, porder=p, axis=axis, keepdim=keepdim, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册