Unverified commit e4cc6a28, authored by myq406450149, committed by GitHub

Norm op support 2-axis (#26492)

Parent dc56c898
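
For context, a minimal usage sketch of what this change enables (a hedged example, assuming the updated `paddle.norm` signature introduced in this diff, where the first argument is `x` and `axis` may be a two-element list):

```python
import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.arange(24).reshape([2, 3, 4]).astype('float32') - 12)

# vector norm over a single axis (existing behaviour)
v = paddle.norm(x, p=2, axis=-1)

# new in this PR: p-norm reduced over two axes at once
m = paddle.norm(x, p=2, axis=[0, 1])   # result shape [4]

# with no axis, the whole tensor is treated as one vector (the new `asvector` path)
s = paddle.norm(x, p=2)
```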
@@ -42,6 +42,11 @@ class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
         "keepdim",
         "(bool, default false) Whether to keep the dimensions as the input.")
         .SetDefault(false);
+    AddAttr<bool>("asvector",
+                  "(bool, default false) as vector norm when axis is None and "
+                  "input is matrix, ")
+        .SetDefault(false);
     AddOutput("Out", "(Tensor) Output result tensor of p-norm");
     AddComment(R"DOC(
 Pnorm Operator.
@@ -96,10 +101,15 @@ class PnormOp : public framework::OperatorWithKernel {
                           "Current Input(X)'s shape is=[%s].",
                           axis, x_rank, x_dim));
-    if (axis < 0) axis = x_dim.size() + axis;
     std::vector<int> reduce_dims;
-    for (int i = 0; i < x_dim.size(); ++i) {
-      if (i != axis) reduce_dims.emplace_back(x_dim[i]);
+    bool asvector = ctx->Attrs().Get<bool>("asvector");
+    if (asvector) {
+      reduce_dims.emplace_back(1);
+    } else {
+      if (axis < 0) axis = x_dim.size() + axis;
+      for (int i = 0; i < x_dim.size(); ++i) {
+        if (i != axis) reduce_dims.emplace_back(x_dim[i]);
+      }
     }
     x_dim[axis] = 1;
...
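
As a rough illustration of the shape-inference change in the hunk above, the hypothetical helper below mirrors only the `reduce_dims` logic that is visible here (the name and standalone form are mine, not the PR's):

```python
# Sketch of the visible reduce_dims logic in PnormOp::InferShape (illustrative only).
def reduce_dims_for_pnorm(x_dim, axis, asvector):
    if asvector:
        return [1]  # the whole tensor collapses to a single value
    if axis < 0:
        axis += len(x_dim)
    return [d for i, d in enumerate(x_dim) if i != axis]

print(reduce_dims_for_pnorm([2, 3, 4], axis=-1, asvector=False))  # [2, 3]
print(reduce_dims_for_pnorm([2, 3, 4], axis=1, asvector=True))    # [1]
```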
@@ -129,9 +129,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
     auto ndim = out_norm->dims();
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);
     auto& dev_ctx = ctx.cuda_device_context();
@@ -230,9 +231,10 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
     float porder = ctx.Attr<float>("porder");
     T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);
     auto& dev_ctx = ctx.cuda_device_context();
...
@@ -20,15 +20,19 @@ namespace paddle {
 namespace operators {
 inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
-                    int* post) {
+                    int* post, bool asvector) {
   *pre = 1;
   *post = 1;
   *n = dim[axis];
-  for (int i = 0; i < axis; ++i) {
-    (*pre) *= dim[i];
-  }
-  for (int i = axis + 1; i < dim.size(); ++i) {
-    (*post) *= dim[i];
+  if (asvector) {
+    *n = product(dim);
+  } else {
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= dim[i];
+    }
+    for (int i = axis + 1; i < dim.size(); ++i) {
+      (*post) *= dim[i];
+    }
   }
 }
@@ -43,9 +47,10 @@ class PnormKernel : public framework::OpKernel<T> {
     auto xdim = in_x->dims();
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);
     auto* place = ctx.template device_context<DeviceContext>().eigen_device();
@@ -91,9 +96,10 @@ class PnormGradKernel : public framework::OpKernel<T> {
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
+    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
     int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post);
+    GetDims(xdim, axis, &pre, &n, &post, asvector);
     Eigen::DSizes<int, 3> shape(pre, n, post);
     Eigen::DSizes<int, 3> rshape(pre, 1, post);
...
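
For intuition about the `GetDims` change above, a small NumPy sketch of the `(pre, n, post)` decomposition it computes (the helper name here is illustrative, not part of the PR):

```python
import numpy as np

def get_dims(shape, axis, asvector=False):
    # pre: product of dims before `axis`; n: size being reduced; post: product of dims after `axis`.
    if asvector:
        # asvector collapses the whole tensor into one flat vector
        return 1, int(np.prod(shape)), 1
    return int(np.prod(shape[:axis])), shape[axis], int(np.prod(shape[axis + 1:]))

print(get_dims((2, 3, 4), axis=1))                 # (2, 3, 4)
print(get_dims((2, 3, 4), axis=1, asvector=True))  # (1, 24, 1)
```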
@@ -14,7 +14,6 @@
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/p_norm_op.h"
 #include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/fluid/operators/top_k_v2_op.h"
...
@@ -33,6 +33,19 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
+                    int* post) {
+  *pre = 1;
+  *post = 1;
+  *n = dim[axis];
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= dim[i];
+  }
+  for (int i = axis + 1; i < dim.size(); ++i) {
+    (*post) *= dim[i];
+  }
+}
+
 template <typename T, typename Type>
 static void FullTopK(Type input_height, Type input_width, int input_dim,
                      const framework::Tensor* input, T* t_out, Type* t_indices,
...
@@ -22,9 +22,40 @@ import paddle.fluid as fluid
 def p_norm(x, axis, porder, keepdims=False):
-    if axis is None: axis = -1
-    r = np.linalg.norm(
-        x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
+    r = []
+    if axis is None:
+        x = x.flatten()
+        if porder == np.inf:
+            r = np.amax(np.abs(x))
+        elif porder == -np.inf:
+            r = np.amin(np.abs(x))
+        else:
+            r = np.linalg.norm(x, ord=porder)
+    elif isinstance(axis, (list, tuple)) and len(axis) == 2:
+        if porder == np.inf:
+            axis = tuple(axis)
+            r = np.amax(np.abs(x), axis=axis, keepdims=keepdims)
+        elif porder == -np.inf:
+            axis = tuple(axis)
+            r = np.amin(np.abs(x), axis=axis, keepdims=keepdims)
+        elif porder == 0:
+            axis = tuple(axis)
+            r = x.astype(bool)
+            r = np.sum(r, axis)
+        elif porder == 1:
+            axis = tuple(axis)
+            r = np.sum(np.abs(x), axis)
+        else:
+            axis = tuple(axis)
+            xp = np.power(np.abs(x), porder)
+            s = np.sum(xp, axis=axis, keepdims=keepdims)
+            r = np.power(s, 1.0 / porder)
+    else:
+        if isinstance(axis, list):
+            axis = tuple(axis)
+        r = np.linalg.norm(
+            x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
     return r
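
The reference helper above can be spot-checked against plain NumPy for the two-axis case (a sketch under the assumption that `p_norm` as defined above is in scope):

```python
import numpy as np

x = np.random.rand(2, 3, 4).astype('float64')
expected = np.sqrt((x ** 2).sum(axis=(0, 1)))  # 2-norm reduced over two axes
np.testing.assert_allclose(p_norm(x, axis=[0, 1], porder=2), expected)
```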
@@ -186,22 +217,10 @@ class TestPnormOp5(TestPnormOp):
         self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)

-def run_out(self, p, axis, shape_x, shape_y, dtype):
-    with fluid.program_guard(fluid.Program()):
-        data1 = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        data2 = fluid.data(name="Y", shape=shape_y, dtype=dtype)
-        out = paddle.norm(input=data1, p=p, axis=axis, out=data2)
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        result = exe.run(feed={"X": np.random.rand(*shape_x).astype(dtype)},
-                         fetch_list=[data2, out])
-        self.assertEqual((result[0] == result[1]).all(), True)
-
 def run_fro(self, p, axis, shape_x, dtype):
     with fluid.program_guard(fluid.Program()):
         data = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        out = paddle.norm(input=data, p=p, axis=axis)
+        out = paddle.norm(x=data, p=p, axis=axis)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
@@ -213,35 +232,73 @@ def run_fro(self, p, axis, shape_x, dtype):
 def run_pnorm(self, p, axis, shape_x, dtype):
     with fluid.program_guard(fluid.Program()):
         data = fluid.data(name="X", shape=shape_x, dtype=dtype)
-        out = paddle.norm(input=data, p=p, axis=axis)
+        out = paddle.norm(x=data, p=p, axis=axis)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
         expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype)
         result, = exe.run(feed={"X": np_input}, fetch_list=[out])
         self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)

+def run_graph(self, p, axis, shape_x, dtype):
+    paddle.disable_static()
+    shape = [2, 3, 4]
+    np_input = np.arange(24).astype('float32') - 12
+    np_input = np_input.reshape(shape)
+    x = paddle.to_tensor(np_input)
+    # [[[-12. -11. -10.  -9.] [ -8.  -7.  -6.  -5.] [ -4.  -3.  -2.  -1.]]
+    #  [[  0.   1.   2.   3.] [  4.   5.   6.   7.] [  8.   9.  10.  11.]]]
+    out_pnorm = paddle.norm(x, p=2, axis=-1)
+
+    # compute frobenius norm along last two dimensions.
+    out_fro = paddle.norm(x, p='fro')
+    out_fro = paddle.norm(x, p='fro', axis=[0, 1])
+
+    # compute 2-order norm along [0, 1] dimension.
+    out_pnorm = paddle.norm(x, p=2, axis=[0, 1])
+    # out_pnorm = [17.43559577 16.91153453 16.73320053 16.91153453]
+    out_pnorm = paddle.norm(x, p=2)
+    # out_pnorm = [34.]
+
+    # compute inf-order norm
+    out_pnorm = paddle.norm(x, p=np.inf)
+    # out_pnorm = [12.]
+    out_pnorm = paddle.norm(x, p=np.inf, axis=0)
+    # out_pnorm = [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
+
+    # compute -inf-order norm
+    out_pnorm = paddle.norm(x, p=-np.inf)
+    # out_pnorm = [0.]
+    out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
+    # out_pnorm = [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
+    paddle.enable_static()
+
 class API_NormTest(unittest.TestCase):
-    def test_output_result(self):
-        run_out(self, p=2, axis=1, shape_x=[3, 4], shape_y=[3], dtype="float32")
-        run_out(
-            self,
-            p='fro',
-            axis=None,
-            shape_x=[3, 4],
-            shape_y=[1],
-            dtype="float32")
-
     def test_basic(self):
-        run_fro(self, p='fro', axis=None, shape_x=[3, 3, 4], dtype="float32")
-        run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64")
+        run_fro(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
+        run_fro(self, p='fro', axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
         run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
         run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")
-        run_pnorm(self, p=np.inf, axis=1, shape_x=[3, 4], dtype="float32")
-        run_pnorm(self, p=-np.inf, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=np.inf, axis=0, shape_x=[2, 3, 4], dtype="float32")
+        run_pnorm(self, p=np.inf, axis=None, shape_x=[2, 3, 4], dtype="float32")
+        run_pnorm(self, p=-np.inf, axis=0, shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=-np.inf, axis=None, shape_x=[2, 3, 4], dtype="float64")
         run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=1, axis=1, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=0, axis=None, shape_x=[3, 4], dtype="float64")
+        run_pnorm(self, p=2, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=2, axis=-1, shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=1, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(self, p=0, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+        run_pnorm(
+            self, p=-np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
+
+    def test_dygraph(self):
+        run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
+
     def test_name(self):
         with fluid.program_guard(fluid.Program()):
             x = fluid.data(name="x", shape=[10, 10], dtype="float32")
@@ -268,11 +325,7 @@ class API_NormTest(unittest.TestCase):
             self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
             self.assertRaises(ValueError, paddle.norm, data, p=[1])
             self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
-            self.assertRaises(
-                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
             data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
-            self.assertRaises(
-                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
             self.assertRaises(
                 ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numpy as np
 from paddle.common_ops_import import *
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type
@@ -170,7 +171,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out


-def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
+def norm(x, p='fro', axis=None, keepdim=False, name=None):
     """
     :alias_main: paddle.norm
     :alias: paddle.norm,paddle.tensor.norm,paddle.tensor.linalg.norm
@@ -179,20 +180,19 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
     or 2-norm, and in general the p-norm for p > 0) of a given tensor.

     Args:
-        input (Variable): The input tensor could be N-D tensor, and the input data
+        x (Tensor): The input tensor could be N-D tensor, and the input data
             type could be float32 or float64.
-        p (float|string, optional): Order of the norm. Supported values are `fro`, `1`, `2`,
-            and any positive real number yielding the corresponding p-norm.
-        axis (int|list, optional): The axis on which to apply norm operation. If axis is int
-            or list with only one element, the vector norm is computed over the axis.
-            If axis is a list with two elements, the matrix norm is computed over the axis.
+        p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
+            `inf`, `-inf` and any positive real number yielding the corresponding p-norm.
+            Not supported: ord < 0, nuclear norm.
+        axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
+            or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
             If `axis < 0`, the dimension to norm operation is rank(input) + axis.
+            If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis.
         keepdim (bool, optional): Whether to reserve the reduced dimension in the
             output Tensor. The result tensor will have fewer dimension
             than the :attr:`input` unless :attr:`keepdim` is true, default
             value is False.
-        out (Variable, optional): The output tensor, default value is None. It's data type
-            must be the same as the input Tensor.
         name (str, optional): The default value is None. Normally there is no need for
             user to set this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -208,29 +208,57 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
         .. code-block:: python

             import paddle
-            import paddle.fluid as fluid
-            x = fluid.data(name='x', shape=[2, 3, 5], dtype='float64')
+            import numpy as np
+            paddle.disable_static()
+            shape=[2, 3, 4]
+            np_input = np.arange(24).astype('float32') - 12
+            np_input = np_input.reshape(shape)
+            x = paddle.to_tensor(np_input)
+            #[[[-12. -11. -10.  -9.] [ -8.  -7.  -6.  -5.] [ -4.  -3.  -2.  -1.]]
+            # [[  0.   1.   2.   3.] [  4.   5.   6.   7.] [  8.   9.  10.  11.]]]
+
             # compute frobenius norm along last two dimensions.
-            out_fro = paddle.norm(x, p='fro', axis=[1,2])
+            out_fro = paddle.norm(x, p='fro', axis=[0,1])
+            # out_fro.numpy(): [17.435596 16.911535 16.7332   16.911535]
+
             # compute 2-order vector norm along last dimension.
             out_pnorm = paddle.norm(x, p=2, axis=-1)
+            #out_pnorm.numpy(): [[21.118711  13.190906   5.477226 ]
+            #                    [ 3.7416575 11.224972  19.131126 ]]
+
+            # compute 2-order norm along [0,1] dimension.
+            out_pnorm = paddle.norm(x, p=2, axis=[0,1])
+            #out_pnorm.numpy(): [17.435596 16.911535 16.7332   16.911535]
+
+            # compute inf-order norm
+            out_pnorm = paddle.norm(x, p=np.inf)
+            #out_pnorm.numpy(): [12.]
+            out_pnorm = paddle.norm(x, p=np.inf, axis=0)
+            #out_pnorm.numpy(): [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
+
+            # compute -inf-order norm
+            out_pnorm = paddle.norm(x, p=-np.inf)
+            #out_pnorm.numpy(): [0.]
+            out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
+            #out_pnorm.numpy(): [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
     """
-    def frobenius_norm(input, dim=None, keepdim=False, out=None, name=None):
+    def frobenius_norm(input, dim=None, keepdim=False, name=None):
         """
         The frobenius norm OP is to calculate the frobenius norm of certain two dimensions of Tensor `input`.
         Args:
             input (Variable): Tensor, data type float32, float64.
             dim (list, optional): None for last two dimensions.
             keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
-            out (Variable, optional): The tensor variable storing the output.
         """
         if dim is not None and not (isinstance(dim, list) and len(dim) == 2):
             raise ValueError(
                 "The dim of frobenius norm op should be None or two elements list!"
             )
+        if in_dygraph_mode():
+            if dim is None: dim = [-1]
+            return core.ops.frobenius_norm(input, 'dim', dim, 'keepdim',
+                                           keepdim)
         attrs = {
             'dim': dim if dim != None else [-2, -1],
             'keep_dim': keepdim,
@@ -242,16 +270,8 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
                     'frobenius_norm')
         helper = LayerHelper('frobenius_norm', **locals())
-        if out is None:
-            out = helper.create_variable_for_type_inference(
-                dtype=helper.input_dtype())
-        else:
-            check_type(out, 'out', (Variable), 'frobenius_norm')
-            check_dtype(
-                out.dtype, out.name,
-                convert_dtype(input.dtype), 'frobenius_norm',
-                '(The out data type in frobenius_norm must be the same with input data type.)'
-            )
+        out = helper.create_variable_for_type_inference(
+            dtype=helper.input_dtype())
         helper.append_op(
             type='frobenius_norm',
@@ -264,7 +284,7 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
                     porder=None,
                     axis=None,
                     keepdim=False,
-                    out=None,
+                    asvector=False,
                     name=None):
         """
         Calculate the p-order vector norm for certain dimension of Tensor `input`.
@@ -273,32 +293,28 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
             porder (float, optional): None for porder=2.0.
             axis (int, optional): None for last dimension.
             keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
-            out (Variable, optional): The tensor variable storing the output.
         """
+        if in_dygraph_mode():
+            if axis is None: axis = -1
+            return core.ops.p_norm(input, 'porder', porder, 'axis', axis,
+                                   'keepdim', keepdim, 'asvector', asvector)
         if porder is not None:
             check_type(porder, 'porder', (float, int), 'p_norm')
         if axis is not None:
             check_type(axis, 'axis', (int), 'p_norm')
-        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
-                                 'p_norm')
         attrs = {
             'axis': axis if axis is not None else -1,
             'porder': float(porder) if porder is not None else 2.0,
             'keepdim': keepdim,
+            'asvector': asvector,
             'epsilon': 1e-12,
         }
+        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                                 'p_norm')
         helper = LayerHelper('p_norm', **locals())
-        if out is None:
-            out = helper.create_variable_for_type_inference(
-                dtype=helper.input_dtype())
-        else:
-            check_type(out, 'out', (Variable), 'p_norm')
-            check_dtype(
-                out.dtype, out.name,
-                convert_dtype(input.dtype), 'p_norm',
-                '(The out data type in p_norm must be the same with input data type.)'
-            )
+        out = helper.create_variable_for_type_inference(
+            dtype=helper.input_dtype())
         helper.append_op(
             type='p_norm',
@@ -307,21 +323,126 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
             attrs=attrs)
         return out

+    def inf_norm(input,
+                 porder=None,
+                 axis=axis,
+                 keepdim=False,
+                 asvector=False,
+                 name=None):
+        helper = LayerHelper('frobenius_norm', **locals())
+        out = helper.create_variable_for_type_inference(
+            dtype=helper.input_dtype())
+        helper.append_op(type='abs', inputs={'X': input}, outputs={'Out': out})
+        reduce_out = helper.create_variable_for_type_inference(
+            dtype=helper.input_dtype())
+
+        reduce_all = True if axis == None or axis == [] or asvector == True else False
+        axis = axis if axis != None and axis != [] else [0]
+
+        reduce_type = 'reduce_max' if porder == np.float(
+            'inf') else 'reduce_min'
+        helper.append_op(
+            type=reduce_type,
+            inputs={'X': out},
+            outputs={'Out': reduce_out},
+            attrs={'dim': axis,
+                   'keep_dim': keepdim,
+                   'reduce_all': reduce_all})
+
+        return reduce_out
+
+    def p0_matrix_norm(input, porder=0., axis=axis, keepdim=False, name=None):
+        block = LayerHelper('norm', **locals())
+        out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        cast_out = block.create_variable_for_type_inference(dtype=bool)
+        block.append_op(
+            type='cast',
+            inputs={'X': input},
+            outputs={'Out': cast_out},
+            attrs={
+                'in_dtype': input.dtype,
+                'out_dtype': int(core.VarDesc.VarType.BOOL)
+            })
+        cast_out2 = block.create_variable_for_type_inference(dtype=bool)
+        block.append_op(
+            type='cast',
+            inputs={'X': cast_out},
+            outputs={'Out': cast_out2},
+            attrs={
+                'in_dtype': cast_out.dtype,
+                'out_dtype': int(core.VarDesc.VarType.FP32)
+            })
+        sum_out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        block.append_op(
+            type='reduce_sum',
+            inputs={'X': cast_out2},
+            outputs={'Out': sum_out},
+            attrs={
+                'dim': axis,
+                'keep_dim': keepdim,
+                'reduce_all': True if axis is None else False
+            })
+        return sum_out
+
+    def p_matrix_norm(input, porder=1., axis=axis, keepdim=False, name=None):
+        block = LayerHelper('norm', **locals())
+        out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        abs_out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        block.append_op(
+            type='abs', inputs={'X': input}, outputs={'Out': abs_out})
+        pow_out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        block.append_op(
+            type='pow',
+            inputs={'X': abs_out},
+            outputs={'Out': pow_out},
+            attrs={'factor': porder})
+        sum_out = block.create_variable_for_type_inference(
+            dtype=block.input_dtype())
+        block.append_op(
+            type='reduce_sum',
+            inputs={'X': pow_out},
+            outputs={'Out': sum_out},
+            attrs={
+                'dim': axis,
+                'keep_dim': keepdim,
+                'reduce_all': True if axis is None else False
+            })
+        block.append_op(
+            type='pow',
+            inputs={'X': sum_out},
+            outputs={'Out': out},
+            attrs={'factor': float(1. / porder)})
+        return out
+
     if axis is None and p is not None:
         if isinstance(p, str):
             if p == "fro":
-                return frobenius_norm(
-                    input, dim=axis, keepdim=keepdim, out=out, name=name)
+                return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
             else:
                 raise ValueError(
                     "only valid string values are 'fro', found {}".format(p))
         elif isinstance(p, (int, float)):
             return vector_norm(
-                input, porder=p, axis=axis, keepdim=keepdim, out=out, name=name)
+                x,
+                porder=p,
+                axis=axis,
+                keepdim=keepdim,
+                asvector=True,
+                name=name)
         else:
             raise ValueError("only valid p type is string or float, found {}".
                              format(type(p)))

+    if isinstance(axis, tuple):
+        axis = list(axis)
     if isinstance(axis, list) and len(axis) == 1:
         axis = axis[0]
@@ -329,7 +450,12 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
     if isinstance(axis, int):
         if isinstance(p, (int, float)):
             return vector_norm(
-                input, axis=axis, porder=p, keepdim=keepdim, out=out, name=name)
+                x,
+                axis=axis,
+                porder=p,
+                keepdim=keepdim,
+                asvector=False,
+                name=name)
         else:
             raise ValueError(
                 "unspport p for p-order vector norm. except float, found {}".
@@ -337,11 +463,14 @@ def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None):
     #calculate matrix norm, where axis is list with two integers
     elif isinstance(axis, list) and len(axis) == 2:
         if p == "fro":
-            return frobenius_norm(
-                input, dim=axis, keepdim=keepdim, out=out, name=name)
+            return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
+        elif p == 0:
+            return p0_matrix_norm(x, axis=axis, keepdim=keepdim, name=name)
+        elif p == np.inf or p == -np.inf:
+            return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name)
         else:
-            raise ValueError(
-                "unspport p for matrix norm, expcept 'fro', found {}".format(p))
+            return p_matrix_norm(
+                x, porder=p, axis=axis, keepdim=keepdim, name=name)
     else:
         raise ValueError(
             "except axis type int or list (length of list <=2), found {}".
...
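
The new `p_matrix_norm` path above composes abs, pow(p), reduce_sum over the two axes, and pow(1/p). A NumPy sketch of the same computation (illustrative, not code from the PR):

```python
import numpy as np

x = np.arange(24, dtype='float32').reshape(2, 3, 4) - 12
p = 3.0
# equivalent of: abs -> pow(p) -> reduce_sum over axis=[0, 1] -> pow(1 / p)
ref = np.sum(np.abs(x) ** p, axis=(0, 1)) ** (1.0 / p)
print(ref)  # what paddle.norm(x_tensor, p=3, axis=[0, 1]) is expected to return
```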