From e4cc6a28b022e1f1a02aae2b90e2ab49fc13e0fb Mon Sep 17 00:00:00 2001
From: yongqiangma
Date: Thu, 27 Aug 2020 19:11:58 +0800
Subject: [PATCH] Norm op support 2-axis (#26492)

---
 paddle/fluid/operators/p_norm_op.cc           |  16 +-
 paddle/fluid/operators/p_norm_op.cu           |   6 +-
 paddle/fluid/operators/p_norm_op.h            |  22 +-
 paddle/fluid/operators/top_k_v2_op.cu         |   1 -
 paddle/fluid/operators/top_k_v2_op.h          |  13 +
 .../fluid/tests/unittests/test_norm_all.py    | 125 +++++++---
 python/paddle/tensor/linalg.py                | 227 ++++++++++++++----
 7 files changed, 311 insertions(+), 99 deletions(-)

diff --git a/paddle/fluid/operators/p_norm_op.cc b/paddle/fluid/operators/p_norm_op.cc
index aa39821051..59035d5a8c 100644
--- a/paddle/fluid/operators/p_norm_op.cc
+++ b/paddle/fluid/operators/p_norm_op.cc
@@ -42,6 +42,11 @@ class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
         "keepdim",
         "(bool, default false) Whether to keep the dimensions as the input.")
         .SetDefault(false);
+
+    AddAttr<bool>("asvector",
+                  "(bool, default false) treat the input as a single vector "
+                  "when axis is None and the input is a matrix.")
+        .SetDefault(false);
     AddOutput("Out", "(Tensor) Output result tensor of p-norm");
     AddComment(R"DOC(
 Pnorm Operator.
@@ -96,10 +101,15 @@ class PnormOp : public framework::OperatorWithKernel {
                           "Current Input(X)'s shape is=[%s].",
                           axis, x_rank, x_dim));

-    if (axis < 0) axis = x_dim.size() + axis;
     std::vector<int> reduce_dims;
+    bool asvector = ctx->Attrs().Get<bool>("asvector");
+    if (asvector) {
+      reduce_dims.emplace_back(1);
+    } else {
+      if (axis < 0) axis = x_dim.size() + axis;
+      for (int i = 0; i < x_dim.size(); ++i) {
+        if (i != axis) reduce_dims.emplace_back(x_dim[i]);
+      }
+    }
     x_dim[axis] = 1;

diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu
index 63f2a1c56c..ba0d46f4c7 100644
--- a/paddle/fluid/operators/p_norm_op.cu
+++ b/paddle/fluid/operators/p_norm_op.cu
@@ -129,9 +129,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
     auto ndim = out_norm->dims();
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);

     auto& dev_ctx = ctx.cuda_device_context();

@@ -230,9 +231,10 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
     float porder = ctx.Attr<float>("porder");
     T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);

     auto& dev_ctx = ctx.cuda_device_context();

diff --git a/paddle/fluid/operators/p_norm_op.h b/paddle/fluid/operators/p_norm_op.h
index 7620d1421e..8fca6924a2 100644
--- a/paddle/fluid/operators/p_norm_op.h
+++ b/paddle/fluid/operators/p_norm_op.h
@@ -20,15 +20,19 @@ namespace paddle {
 namespace operators {

 inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
-                    int* post) {
+                    int* post, bool asvector) {
   *pre = 1;
   *post = 1;
   *n = dim[axis];
-  for (int i = 0; i < axis; ++i) {
-    (*pre) *= dim[i];
-  }
-  for (int i = axis + 1; i < dim.size(); ++i) {
-    (*post) *= dim[i];
+  if (asvector) {
+    *n = product(dim);
+  } else {
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= dim[i];
+    }
+    for (int i = axis + 1; i < dim.size(); ++i) {
+      (*post) *= dim[i];
+    }
   }
 }

@@ -43,9 +47,10 @@ class PnormKernel : public framework::OpKernel<T> {
     auto xdim = in_x->dims();
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);

     auto* place = ctx.template device_context<DeviceContext>().eigen_device();

@@ -91,9 +96,10 @@ class PnormGradKernel : public framework::OpKernel<T> {
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);

     Eigen::DSizes<int, 3> shape(pre, n, post);
     Eigen::DSizes<int, 3> rshape(pre, 1, post);

diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu
index 5154503292..2c94dca1e3 100644
--- a/paddle/fluid/operators/top_k_v2_op.cu
+++ b/paddle/fluid/operators/top_k_v2_op.cu
@@ -14,7 +14,6 @@

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/p_norm_op.h"
 #include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/fluid/operators/top_k_v2_op.h"

diff --git a/paddle/fluid/operators/top_k_v2_op.h b/paddle/fluid/operators/top_k_v2_op.h
index a77285d123..89b5d36b1b 100644
--- a/paddle/fluid/operators/top_k_v2_op.h
+++ b/paddle/fluid/operators/top_k_v2_op.h
@@ -33,6 +33,19 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
+                    int* post) {
+  *pre = 1;
+  *post = 1;
+  *n = dim[axis];
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= dim[i];
+  }
+  for (int i = axis + 1; i < dim.size(); ++i) {
+    (*post) *= dim[i];
+  }
+}
+
 template <typename T, typename Type>
 static void FullTopK(Type input_height, Type input_width, int input_dim,
                      const framework::Tensor* input, T* t_out, Type* t_indices,

diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py
index 0d083038c6..c047cf6ddf 100644
--- a/python/paddle/fluid/tests/unittests/test_norm_all.py
+++ b/python/paddle/fluid/tests/unittests/test_norm_all.py
@@ -22,9 +22,40 @@ import paddle.fluid as fluid


 def p_norm(x, axis, porder, keepdims=False):
-    if axis is None: axis = -1
-    r = np.linalg.norm(
-        x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
+    r = []
+    if axis is None:
+        x = x.flatten()
+        if porder == np.inf:
+            r = np.amax(np.abs(x))
+        elif porder == -np.inf:
+            r = np.amin(np.abs(x))
+        else:
+            r = np.linalg.norm(x, ord=porder)
+    elif isinstance(axis, (list, tuple)) and len(axis) == 2:
+        if porder == np.inf:
+            axis = tuple(axis)
+            r = np.amax(np.abs(x), axis=axis, keepdims=keepdims)
+        elif porder == -np.inf:
+            axis = tuple(axis)
+            r = np.amin(np.abs(x), axis=axis, keepdims=keepdims)
+        elif porder == 0:
+            axis = tuple(axis)
+            r = x.astype(bool)
+            r = np.sum(r, axis)
+        elif porder == 1:
+            axis = tuple(axis)
+            r = np.sum(np.abs(x), axis)
+        else:
+            axis = tuple(axis)
+            xp = np.power(np.abs(x), porder)
+            s = np.sum(xp, axis=axis, keepdims=keepdims)
+            r = np.power(s, 1.0 / porder)
+    else:
+        if isinstance(axis, list):
+            axis = tuple(axis)
+        r = np.linalg.norm(
+            x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
+
     return r

@@ -186,22 +217,10 @@ class TestPnormOp5(TestPnormOp):
         self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)


-def run_out(self, p, axis, shape_x, shape_y, dtype):
-    with fluid.program_guard(fluid.Program()):
-        data1 = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        data2 = fluid.data(name="Y", shape=shape_y, dtype=dtype)
-        out = paddle.norm(input=data1, p=p, axis=axis, out=data2)
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        result = exe.run(feed={"X": np.random.rand(*shape_x).astype(dtype)},
-                         fetch_list=[data2, out])
-        self.assertEqual((result[0] == result[1]).all(), True)
-
-
 def run_fro(self, p, axis, shape_x, dtype):
     with fluid.program_guard(fluid.Program()):
         data = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        out = paddle.norm(input=data, p=p, axis=axis)
+        out = paddle.norm(x=data, p=p, axis=axis)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
@@ -213,35 +232,73 @@ def run_fro(self, p, axis, shape_x, dtype):
 def run_pnorm(self, p, axis, shape_x, dtype):
     with fluid.program_guard(fluid.Program()):
         data = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        out = paddle.norm(input=data, p=p, axis=axis)
+        out = paddle.norm(x=data, p=p, axis=axis)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
         expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype)
         result, = exe.run(feed={"X": np_input}, fetch_list=[out])
-    self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
+    self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
+
+
+def run_graph(self, p, axis, shape_x, dtype):
+    paddle.disable_static()
+    shape = [2, 3, 4]
+    np_input = np.arange(24).astype('float32') - 12
+    np_input = np_input.reshape(shape)
+    x = paddle.to_tensor(np_input)
+    #[[[-12. -11. -10.  -9.] [ -8.  -7.  -6.  -5.] [ -4.  -3.  -2.  -1.]]
+    # [[  0.   1.   2.   3.] [  4.   5.   6.   7.] [  8.   9.  10.  11.]]]
+    out_pnorm = paddle.norm(x, p=2, axis=-1)
+
+    # compute frobenius norm along last two dimensions.
+    out_fro = paddle.norm(x, p='fro')
+    out_fro = paddle.norm(x, p='fro', axis=[0, 1])
+    # compute 2-order norm along [0, 1] dimension.
+    out_pnorm = paddle.norm(x, p=2, axis=[0, 1])
+    #out_pnorm = [17.43559577 16.91153453 16.73320053 16.91153453]
+    out_pnorm = paddle.norm(x, p=2)
+
+    # compute inf-order norm
+    out_pnorm = paddle.norm(x, p=np.inf)
+    #out_pnorm = [12.]
+    out_pnorm = paddle.norm(x, p=np.inf, axis=0)
+    #out_pnorm = [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
+
+    # compute -inf-order norm
+    out_pnorm = paddle.norm(x, p=-np.inf)
+    #out_pnorm = [0.]
+    out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
+    #out_pnorm = [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
+    paddle.enable_static()


 class API_NormTest(unittest.TestCase):
-    def test_output_result(self):
-        run_out(self, p=2, axis=1, shape_x=[3, 4], shape_y=[3], dtype="float32")
-        run_out(
-            self,
-            p='fro',
-            axis=None,
-            shape_x=[3, 4],
-            shape_y=[1],
-            dtype="float32")
-
     def test_basic(self):
-        run_fro(self, p='fro', axis=None, shape_x=[3, 3, 4], dtype="float32")
-        run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64")
+        run_fro(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
+        run_fro(self, p='fro', axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
         run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
         run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")
-        run_pnorm(self, p=np.inf, axis=1, shape_x=[3, 4], dtype="float32")
-        run_pnorm(self, p=-np.inf, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=np.inf, axis=0, shape_x=[2, 3, 4], dtype="float32")
+        run_pnorm(self, p=np.inf, axis=None, shape_x=[2, 3, 4], dtype="float32")
+        run_pnorm(self, p=-np.inf, axis=0, shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=-np.inf, axis=None, shape_x=[2, 3, 4], dtype="float64")
         run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=1, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=0, axis=None, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=2, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=2, axis=-1, shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=1, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=0, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=-np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+
+    def test_dygraph(self):
+        run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
+
     def test_name(self):
         with fluid.program_guard(fluid.Program()):
             x = fluid.data(name="x", shape=[10, 10], dtype="float32")
@@ -268,11 +325,7 @@ class API_NormTest(unittest.TestCase):
             self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
             self.assertRaises(ValueError, paddle.norm, data, p=[1])
             self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
-            self.assertRaises(
-                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
             data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
-            self.assertRaises(
-                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
             self.assertRaises(
                 ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])

diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index a7bf2272a5..b5b528325c 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 from paddle.common_ops_import import *
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type

@@ -170,7 +171,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out


-def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
+def norm(x, p='fro', axis=None, keepdim=False, name=None):
     """
     :alias_main: paddle.norm
     :alias: paddle.norm,paddle.tensor.norm,paddle.tensor.linalg.norm
@@ -179,20 +180,19 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
     or 2-norm, and in general the p-norm for p > 0) of a given tensor.

     Args:
-        input (Variable): The input tensor could be N-D tensor, and the input data
+        x (Tensor): The input tensor could be N-D tensor, and the input data
             type could be float32 or float64.
-        p (float|string, optional): Order of the norm. Supported values are `fro`, `1`, `2`,
-            and any positive real number yielding the corresponding p-norm.
-        axis (int|list, optional): The axis on which to apply norm operation. If axis is int
-            or list with only one element, the vector norm is computed over the axis.
-            If axis is a list with two elements, the matrix norm is computed over the axis.
+        p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
+            `inf`, `-inf` and any positive real number yielding the corresponding p-norm.
+            Not supported: ord < 0, nuclear norm.
+        axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
+            or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
+            If `axis < 0`, the dimension to norm operation is rank(input) + axis.
+            If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis.
         keepdim (bool, optional): Whether to reserve the reduced dimension in the
             output Tensor. The result tensor will have fewer dimension
             than the :attr:`input` unless :attr:`keepdim` is true, default
            value is False.
-        out (Variable, optional): The output tensor, default value is None. It's data type
-            must be the same as the input Tensor.
         name (str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -208,29 +208,57 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
        .. code-block:: python

            import paddle
-           import paddle.fluid as fluid
-           x = fluid.data(name='x', shape=[2, 3, 5], dtype='float64')
-
+           import numpy as np
+           paddle.disable_static()
+           shape=[2, 3, 4]
+           np_input = np.arange(24).astype('float32') - 12
+           np_input = np_input.reshape(shape)
+           x = paddle.to_tensor(np_input)
+           #[[[-12. -11. -10.  -9.] [ -8.  -7.  -6.  -5.] [ -4.  -3.  -2.  -1.]]
+           # [[  0.   1.   2.   3.] [  4.   5.   6.   7.] [  8.   9.  10.  11.]]]
+
            # compute frobenius norm along last two dimensions.
-           out_fro = paddle.norm(x, p='fro', axis=[1,2])
-
+           out_fro = paddle.norm(x, p='fro', axis=[0,1])
+           # out_fro.numpy() [17.435596 16.911535 16.7332   16.911535]
+
            # compute 2-order vector norm along last dimension.
            out_pnorm = paddle.norm(x, p=2, axis=-1)
+           #out_pnorm.numpy(): [[21.118711  13.190906   5.477226 ]
+           #                    [ 3.7416575 11.224972  19.131126 ]]
+
+           # compute 2-order norm along [0,1] dimension.
+           out_pnorm = paddle.norm(x, p=2, axis=[0,1])
+           #out_pnorm.numpy(): [17.435596 16.911535 16.7332   16.911535]
+
+           # compute inf-order norm
+           out_pnorm = paddle.norm(x, p=np.inf)
+           #out_pnorm.numpy()  = [12.]
+           out_pnorm = paddle.norm(x, p=np.inf, axis=0)
+           #out_pnorm.numpy(): [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
+
+           # compute -inf-order norm
+           out_pnorm = paddle.norm(x, p=-np.inf)
+           #out_pnorm.numpy(): [0.]
+           out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
+           #out_pnorm.numpy(): [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
     """

-    def frobenius_norm(input, dim=None, keepdim=False, out=None, name=None):
+    def frobenius_norm(input, dim=None, keepdim=False, name=None):
        """
        The frobenius norm OP is to calculate the frobenius norm of certain two dimensions of Tensor `input`.
        Args:
            input (Variable): Tensor, data type float32, float64.
            dim (list, optional): None for last two dimensions.
            keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
-           out (Variable, optional): The tensor variable storing the output.
        """
        if dim is not None and not (isinstance(dim, list) and len(dim) == 2):
            raise ValueError(
                "The dim of frobenius norm op should be None or two elements list!"
            )
+        if in_dygraph_mode():
+            if dim is None: dim = [-1]
+            return core.ops.frobenius_norm(input, 'dim', dim, 'keepdim',
+                                           keepdim)
        attrs = {
            'dim': dim if dim != None else [-2, -1],
            'keep_dim': keepdim,
@@ -242,16 +270,8 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
                                     'frobenius_norm')
        helper = LayerHelper('frobenius_norm', **locals())
-        if out is None:
-            out = helper.create_variable_for_type_inference(
-                dtype=helper.input_dtype())
-        else:
-            check_type(out, 'out', (Variable), 'frobenius_norm')
-            check_dtype(
-                out.dtype, out.name,
-                convert_dtype(input.dtype), 'frobenius_norm',
-                '(The out data type in frobenius_norm must be the same with input data type.)'
-            )
+        out = helper.create_variable_for_type_inference(
+            dtype=helper.input_dtype())

        helper.append_op(
            type='frobenius_norm',
@@ -264,7 +284,7 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
                    porder=None,
                    axis=None,
                    keepdim=False,
-                    out=None,
+                    asvector=False,
                    name=None):
        """
        Calculate the p-order vector norm for certain dimension of Tensor `input`.
@@ -273,32 +293,28 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
            porder (float, optional): None for porder=2.0.
            axis (int, optional): None for last dimension.
            keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
-           out (Variable, optional): The tensor variable storing the output.
""" + if in_dygraph_mode(): + if axis is None: axis = -1 + return core.ops.p_norm(input, 'porder', porder, 'axis', axis, + 'keepdim', keepdim, 'asvector', asvector) if porder is not None: check_type(porder, 'porder', (float, int), 'p_norm') if axis is not None: check_type(axis, 'axis', (int), 'p_norm') + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'p_norm') + attrs = { 'axis': axis if axis is not None else -1, 'porder': float(porder) if porder is not None else 2.0, 'keepdim': keepdim, + 'asvector': asvector, 'epsilon': 1e-12, } - check_variable_and_dtype(input, 'input', ['float32', 'float64'], - 'p_norm') - helper = LayerHelper('p_norm', **locals()) - if out is None: - out = helper.create_variable_for_type_inference( - dtype=helper.input_dtype()) - else: - check_type(out, 'out', (Variable), 'p_norm') - check_dtype( - out.dtype, out.name, - convert_dtype(input.dtype), 'p_norm', - '(The out data type in p_norm must be the same with input data type.)' - ) + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) helper.append_op( type='p_norm', @@ -307,21 +323,126 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None): attrs=attrs) return out + def inf_norm(input, + porder=None, + axis=axis, + keepdim=False, + asvector=False, + name=None): + helper = LayerHelper('frobenius_norm', **locals()) + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + helper.append_op(type='abs', inputs={'X': input}, outputs={'Out': out}) + reduce_out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + + reduce_all = True if axis == None or axis == [] or asvector == True else False + axis = axis if axis != None and axis != [] else [0] + + reduce_type = 'reduce_max' if porder == np.float( + 'inf') else 'reduce_min' + helper.append_op( + type=reduce_type, + inputs={'X': out}, + outputs={'Out': reduce_out}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + + return reduce_out + + def p0_matrix_norm(input, porder=0., axis=axis, keepdim=False, name=None): + block = LayerHelper('norm', **locals()) + out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + + cast_out = block.create_variable_for_type_inference(dtype=bool) + block.append_op( + type='cast', + inputs={'X': input}, + outputs={'Out': cast_out}, + attrs={ + 'in_dtype': input.dtype, + 'out_dtype': int(core.VarDesc.VarType.BOOL) + }) + cast_out2 = block.create_variable_for_type_inference(dtype=bool) + block.append_op( + type='cast', + inputs={'X': cast_out}, + outputs={'Out': cast_out2}, + attrs={ + 'in_dtype': cast_out.dtype, + 'out_dtype': int(core.VarDesc.VarType.FP32) + }) + sum_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + block.append_op( + type='reduce_sum', + inputs={'X': cast_out2}, + outputs={'Out': sum_out}, + attrs={ + 'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': True if axis is None else False + }) + return sum_out + + def p_matrix_norm(input, porder=1., axis=axis, keepdim=False, name=None): + block = LayerHelper('norm', **locals()) + out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + abs_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + block.append_op( + type='abs', inputs={'X': input}, outputs={'Out': abs_out}) + pow_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + + block.append_op( + type='pow', + inputs={'X': abs_out}, + outputs={'Out': pow_out}, + 
+            attrs={'factor': porder})
+        sum_out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        block.append_op(
+            type='reduce_sum',
+            inputs={'X': pow_out},
+            outputs={'Out': sum_out},
+            attrs={
+                'dim': axis,
+                'keep_dim': keepdim,
+                'reduce_all': True if axis is None else False
+            })
+        block.append_op(
+            type='pow',
+            inputs={'X': sum_out},
+            outputs={'Out': out},
+            attrs={'factor': float(1. / porder)})
+        return out
+
    if axis is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
-                return frobenius_norm(
-                    input, dim=axis, keepdim=keepdim, out=out, name=name)
+                return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
            else:
                raise ValueError(
                    "only valid string values are 'fro', found {}".format(p))
        elif isinstance(p, (int, float)):
            return vector_norm(
-                input, porder=p, axis=axis, keepdim=keepdim, out=out, name=name)
+                x,
+                porder=p,
+                axis=axis,
+                keepdim=keepdim,
+                asvector=True,
+                name=name)
        else:
            raise ValueError("only valid p type is string or float, found {}".
                             format(type(p)))

+    if isinstance(axis, tuple):
+        axis = list(axis)
    if isinstance(axis, list) and len(axis) == 1:
        axis = axis[0]

@@ -329,7 +450,12 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
    if isinstance(axis, int):
        if isinstance(p, (int, float)):
            return vector_norm(
-                input, axis=axis, porder=p, keepdim=keepdim, out=out, name=name)
+                x,
+                axis=axis,
+                porder=p,
+                keepdim=keepdim,
+                asvector=False,
+                name=name)
        else:
            raise ValueError(
                "unspport p for p-order vector norm. except float, found {}".
@@ -337,11 +463,14 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
                format(type(p)))
    #calculate matrix norm, where axis is list with two integers
    elif isinstance(axis, list) and len(axis) == 2:
        if p == "fro":
-            return frobenius_norm(
-                input, dim=axis, keepdim=keepdim, out=out, name=name)
+            return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
+        elif p == 0:
+            return p0_matrix_norm(x, axis=axis, keepdim=keepdim, name=name)
+        elif p == np.inf or p == -np.inf:
+            return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name)
        else:
-            raise ValueError(
-                "unspport p for matrix norm, expcept 'fro', found {}".format(p))
+            return p_matrix_norm(
+                x, porder=p, axis=axis, keepdim=keepdim, name=name)
    else:
        raise ValueError(
            "except axis type int or list (length of list <=2), found {}".
-- 
GitLab
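A minimal usage sketch of the new two-axis behaviour follows. It is an illustration only, not part of the patch, and assumes a Paddle build that already includes this change; the reference value mirrors the element-wise formula used by the patched test helper p_norm() in test_norm_all.py (for two axes, numpy.linalg.norm with ord=2 would instead compute a spectral norm, so it is not a suitable cross-check here).

    import numpy as np
    import paddle

    paddle.disable_static()

    a = (np.random.rand(2, 3, 4) + 0.5).astype('float32')
    x = paddle.to_tensor(a)

    # p-norm reduced over two axes at once (the new 2-axis support).
    out = paddle.norm(x, p=2, axis=[0, 1])

    # Reference: entrywise p-norm over the same axes, as in the test helper.
    ref = np.power(np.sum(np.power(np.abs(a), 2), axis=(0, 1)), 1.0 / 2)
    assert np.allclose(out.numpy(), ref, atol=1e-6)

    # With a numeric p and axis=None the whole tensor is treated as one vector
    # (the new `asvector` path), matching the norm of the flattened array.
    out_all = paddle.norm(x, p=2)
    assert np.allclose(out_all.numpy(), np.linalg.norm(a.flatten(), ord=2), atol=1e-5)

    paddle.enable_static()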