diff --git a/paddle/fluid/operators/p_norm_op.cc b/paddle/fluid/operators/p_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1da6b11935b3aa09884b40ebfb3448581d9b34f9
--- /dev/null
+++ b/paddle/fluid/operators/p_norm_op.cc
@@ -0,0 +1,139 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/p_norm_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) A tensor of rank >= axis.");
+    AddAttr<float>("porder",
+                   "The porder is the p order vector norm to calculate.")
+        .SetDefault(2.0f);
+    AddAttr<int>("axis",
+                 "The axis on which to apply normalization. If axis < 0, "
+                 "the dimension to pnorm is rank(X) + axis. -1 is "
+                 "the last dimension.")
+        .SetDefault(-1);
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-12) The epsilon value is used "
+                   "to avoid division by zero.")
+        .SetDefault(1.0e-12f);
+    AddAttr<bool>(
+        "keepdim",
+        "(bool, default false) Whether to keep the dimensions as the input.")
+        .SetDefault(false);
+    AddOutput(
+        "Out",
+        "(Tensor) Output tensor for the `(sum(x.pow(p)) + epsilon).pow(1/p)`.");
+    AddComment(R"DOC(
+
+Given a tensor, compute the p-norm along the provided axis.
+
+$$
+pnorm = \left( \sum_i {\left| x_i \right|^p} \right)^{1/p}
+$$
+
+where $\sum_i{\left| x_i \right|^p}$ is calculated along the `axis` dimension.
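+
+For example (an illustrative case): with porder = 2 and axis = -1, an input
+X = [[1, 2], [3, 4]] gives Out = [sqrt(1 + 4), sqrt(9 + 16)] = [2.2360679, 5.0].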
+
+)DOC");
+  }
+};
+
+class PnormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm");
+    auto porder = ctx->Attrs().Get<float>("porder");
+    PADDLE_ENFORCE_NE(porder, 0,
+                      platform::errors::InvalidArgument(
+                          "The porder of p_norm does not support "
+                          "0, INFINITY or -INFINITY for now."));
+    PADDLE_ENFORCE_NE(porder, INFINITY,
+                      platform::errors::InvalidArgument(
+                          "The porder of p_norm does not support "
+                          "0, INFINITY or -INFINITY for now."));
+    PADDLE_ENFORCE_NE(porder, -INFINITY,
+                      platform::errors::InvalidArgument(
+                          "The porder of p_norm does not support "
+                          "0, INFINITY or -INFINITY for now."));
+    auto xdim = ctx->GetInputDim("X");
+    int axis = ctx->Attrs().Get<int>("axis");
+    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
+    if (axis < 0) axis = xdim.size() + axis;
+    std::vector<int> reduce_dims;
+    for (int i = 0; i < xdim.size(); ++i) {
+      if (i != axis) reduce_dims.emplace_back(xdim[i]);
+    }
+    xdim[axis] = 1;
+    if (keepdim) {
+      ctx->SetOutputDim("Out", xdim);
+    } else {
+      ctx->SetOutputDim("Out", framework::make_ddim(reduce_dims));
+    }
+  }
+};
+
+class PnormOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "p_norm");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "p_norm");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@GRAD", "p_norm");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+
+template <typename T>
+class PnormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("p_norm_grad");
+    op->SetAttrMap(this->Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(p_norm, ops::PnormOp, ops::PnormOpMaker,
+                  ops::PnormOpGradOpMaker<paddle::framework::OpDesc>,
+                  ops::PnormOpGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(p_norm_grad, ops::PnormOpGrad);
+REGISTER_OP_CPU_KERNEL(p_norm, ops::PnormKernel<CPU, float>,
+                       ops::PnormKernel<CPU, double>);
+REGISTER_OP_CPU_KERNEL(p_norm_grad, ops::PnormGradKernel<CPU, float>,
+                       ops::PnormGradKernel<CPU, double>);
diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5e94d87b85197b69fb47643168facb8de6375884
--- /dev/null
+++ b/paddle/fluid/operators/p_norm_op.cu
@@ -0,0 +1,180 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "cub/cub.cuh" +#include "paddle/fluid/operators/p_norm_op.h" + +namespace paddle { +namespace operators { + +template +__device__ __forceinline__ int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + +__device__ __forceinline__ float inline_abs(float x) { return abs(x); } +__device__ __forceinline__ double inline_abs(double x) { return abs(x); } + +__device__ __forceinline__ int inline_sign(float x) { return sgn(x); } +__device__ __forceinline__ int inline_sign(double x) { return sgn(x); } + +__device__ __forceinline__ float inline_pow(float base, float exponent) { + return pow(base, exponent); +} +__device__ __forceinline__ double inline_pow(double base, double exponent) { + return pow(base, exponent); +} + +template +__global__ void Pnorm(const T* x, const int pre, + const int axis_n, // dim in axis + const int post, float porder, T* out_norm) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + int base = (i / post) * post * axis_n + (i % post); + + T sum = 0.0; + __shared__ T norm; + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + const T x_ij = x[base + j * post]; + sum += inline_pow(inline_abs(x_ij), porder); + } + T reduce_result = BlockReduce(temp_storage).Sum(sum); + + if (threadIdx.x == 0) { + norm = inline_pow(reduce_result, 1.0f / porder); + out_norm[i] = norm; + } + __syncthreads(); + } +} + +template +class PnormCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_x = ctx.Input("X"); + auto* out_norm = ctx.Output("Out"); + const T* x = in_x->data(); + T* norm = out_norm->mutable_data(ctx.GetPlace()); + + auto xdim = in_x->dims(); + auto ndim = out_norm->dims(); + float porder = ctx.Attr("porder"); + int axis = ctx.Attr("axis"); + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post); + + auto& dev_ctx = ctx.cuda_device_context(); + + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid = std::min(max_blocks, pre * post); + Pnorm<<>>(x, pre, n, post, + porder, norm); + } +}; + +template +__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad, + const float porder, const int pre, + const int axis_n, const int post, const T eps, + T* x_grad) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage_sum; + // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x) + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + T sum = 0.0; + __shared__ T row_sum; + __shared__ T row_sqrt_norm; + __shared__ T row_norm; + + auto base = (i / post) * post * axis_n + (i % post); + + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + int index = base + j * post; + sum += x[index] * y_grad[index]; + } + T reduce_result = BlockReduce(temp_storage_sum).Sum(sum); + + if (threadIdx.x == 0) { + row_sum = 
reduce_result; + row_sqrt_norm = x_norm[i]; + row_norm = row_sqrt_norm * row_sqrt_norm; + } + __syncthreads(); + + const T pnorm_i = x_norm[i]; + const T yout_i = y_grad[i]; + + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + int index = base + j * post; + const T x_ij = inline_abs(x[index]); + const T dy_ij = y_grad[index]; + x_grad[index] = inline_pow(x_ij, porder - 1.0f) / + (inline_pow(pnorm_i, porder - 1.0f) + eps) * yout_i * + inline_sign(x[index]); + } + } +} + +template +class PnormGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_x = ctx.Input("X"); + auto* in_norm = ctx.Input("Out"); + auto* in_norm_dy = + ctx.Input(framework::GradVarName("Out")); + auto* out_dx = ctx.Output(framework::GradVarName("X")); + T* dx = out_dx->mutable_data(ctx.GetPlace()); + const T* x = in_x->data(); + const T* x_norm = in_norm->data(); + const T* norm_dy = in_norm_dy->data(); + + auto xdim = in_x->dims(); + float porder = ctx.Attr("porder"); + T eps = static_cast(ctx.Attr("epsilon")); + int axis = ctx.Attr("axis"); + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post); + + auto& dev_ctx = ctx.cuda_device_context(); + + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid = std::min(max_blocks, pre * post); + PnormGradient<<>>( + x, x_norm, norm_dy, porder, pre, n, post, eps, dx); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; + +REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel, + ops::PnormCUDAKernel); +REGISTER_OP_CUDA_KERNEL(p_norm_grad, ops::PnormGradCUDAKernel, + ops::PnormGradCUDAKernel); diff --git a/paddle/fluid/operators/p_norm_op.h b/paddle/fluid/operators/p_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c5bdfe352723b55f80376d6644922af5de099e90 --- /dev/null +++ b/paddle/fluid/operators/p_norm_op.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n, + int* post) { + *pre = 1; + *post = 1; + *n = dim[axis]; + for (int i = 0; i < axis; ++i) { + (*pre) *= dim[i]; + } + for (int i = axis + 1; i < dim.size(); ++i) { + (*post) *= dim[i]; + } +} + +template +class PnormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_x = ctx.Input("X"); + auto* out_norm = ctx.Output("Out"); + out_norm->mutable_data(ctx.GetPlace()); + + auto xdim = in_x->dims(); + float porder = ctx.Attr("porder"); + int axis = ctx.Attr("axis"); + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post); + + auto* place = ctx.template device_context().eigen_device(); + + Eigen::DSizes shape(pre, n, post); + Eigen::DSizes norm_shape(pre, post); + + auto x_e = framework::EigenVector::Flatten(*in_x); + auto norm_e = framework::EigenVector::Flatten(*out_norm); + + auto x = x_e.reshape(shape); + auto norm = norm_e.reshape(norm_shape); + + Eigen::DSizes rdim(1); + auto xp = (x.abs()).pow(porder); + auto sum = xp.sum(rdim); + norm.device(*place) = sum.pow(1.0f / porder); + } +}; + +template +class PnormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_x = ctx.Input("X"); + auto* in_norm = ctx.Input("Out"); + auto* in_norm_dy = + ctx.Input(framework::GradVarName("Out")); + auto* out_dx = ctx.Output(framework::GradVarName("X")); + out_dx->mutable_data(ctx.GetPlace()); + + T eps = static_cast(ctx.Attr("epsilon")); + auto xdim = in_x->dims(); + float porder = ctx.Attr("porder"); + + int axis = ctx.Attr("axis"); + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post); + Eigen::DSizes shape(pre, n, post); + Eigen::DSizes rshape(pre, 1, post); + + auto* place = ctx.template device_context().eigen_device(); + + auto x_e = framework::EigenVector::Flatten(*in_x); + auto dx_e = framework::EigenVector::Flatten(*out_dx); + auto norm_e = framework::EigenVector::Flatten(*in_norm); + auto norm_dy_e = framework::EigenVector::Flatten(*in_norm_dy); + + auto x = x_e.reshape(shape); + auto dx = dx_e.reshape(shape); + auto norm = norm_e.reshape(rshape); + auto norm_dy = norm_dy_e.reshape(rshape); + + Eigen::DSizes rdim(1); + Eigen::DSizes bcast(1, n, 1); + + dx.device(*place) = (x.abs()).pow(porder - 1.0f); + dx.device(*place) = + dx / ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps)); + dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign(); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cc b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..57df2664824d478503fce04f09c5a7f1e02eb080 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h" +#include +#include + +namespace paddle { +namespace operators { + +template +class FrobeniusNormOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("frobenius_norm_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +class FrobeniusNormOpMaker : public ops::ReduceOpMaker { + protected: + virtual std::string GetName() const { return "frobenius_norm"; } + virtual std::string GetOpType() const { return "Reduce frobenius_norm"; } +}; + +REGISTER_OPERATOR(frobenius_norm, ops::ReduceOp, FrobeniusNormOpMaker, + ops::FrobeniusNormOpGradMaker, + ops::FrobeniusNormOpGradMaker); + +REGISTER_OPERATOR(frobenius_norm_grad, ops::ReduceGradOp); + +REGISTER_OP_CPU_KERNEL(frobenius_norm, + ops::ReduceKernel, + ops::ReduceKernel); + +template +using CPUFrobeniusNormGradKernel = + ops::FrobeniusNormGradKernel; + +REGISTER_OP_CPU_KERNEL(frobenius_norm_grad, CPUFrobeniusNormGradKernel, + CPUFrobeniusNormGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cu b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ff645dfeb653c5fafa2ae2ca058e780a93a0764 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
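+
+// Note: the forward and backward functors for this op are defined in
+// frobenius_norm_op.h; they implement Out = sqrt(sum(X * X)) over `dim`
+// and dX = X / (||X||_F + 1e-12) * dOut (broadcast over `dim`),
+// matching \partial ||X||_F = X / ||X||_F.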
+ +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h" + +template +using CUDAFrobeniusNormKernel = + ops::ReduceKernel; + +REGISTER_OP_CUDA_KERNEL(frobenius_norm, CUDAFrobeniusNormKernel, + CUDAFrobeniusNormKernel); + +template +using CUDAFrobeniusNormGradKernel = + ops::ReduceGradKernel; + +REGISTER_OP_CUDA_KERNEL(frobenius_norm_grad, CUDAFrobeniusNormGradKernel, + CUDAFrobeniusNormGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/frobenius_norm_op.h b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0b6b87d99ecd98e65c492fb96f3a1e886b7bfa4b --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/frobenius_norm_op.h @@ -0,0 +1,54 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +namespace paddle { +namespace operators { + +// \partial \| X \|_F = \frac{X}{ \| X \|_F } +template +class FrobeniusNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // default use Eigen broadcast + ReduceGradKernel kernel; + kernel.Compute(context); + } +}; + +struct FrobeniusNormFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = ((x->square()).sum(dim)).sqrt(); + } +}; + +struct FrobeniusNormGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = y->broadcast(dim); + dx->device(place) = *dx + dx->constant(1e-12f); + dx->device(place) = (*x / *dx) * (dy->broadcast(dim)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5587fd795b106838baeb3d85804713c51c777a5f..9e2bf8c5eea626f78c18b473db6f88a1ba8b7fd2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -150,7 +150,7 @@ from .tensor.math import addmm #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS -# from .tensor.linalg import morm #DEFINE_ALIAS +from .tensor.linalg import norm #DEFINE_ALIAS # from .tensor.linalg import transpose #DEFINE_ALIAS from .tensor.linalg import dist #DEFINE_ALIAS # from .tensor.linalg import t #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b7a3e7603f53d78052d5de309d6ed7d84c4660 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -0,0 +1,210 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid + + +def p_norm(x, axis, porder, keepdims=False): + if axis is None: axis = -1 + xp = np.power(np.abs(x), porder) + s = np.sum(xp, axis=axis, keepdims=keepdims) + r = np.power(s, 1.0 / porder) + return r + + +def frobenius_norm(x, axis=None, keepdims=False): + if isinstance(axis, list): axis = tuple(axis) + if axis is None: axis = (-2, -1) + r = np.linalg.norm(x, ord='fro', axis=axis, keepdims=keepdims) + return r + + +class TestFrobeniusNormOp(OpTest): + def setUp(self): + self.op_type = "frobenius_norm" + self.init_test_case() + x = (np.random.random(self.shape) + 1.0).astype(self.dtype) + norm = frobenius_norm(x, self.axis, self.keepdim) + self.reduce_all = (len(self.axis) == len(self.shape)) + self.inputs = {'X': x} + self.attrs = { + 'dim': list(self.axis), + 'keep_dim': self.keepdim, + 'reduce_all': self.reduce_all + } + self.outputs = {'Out': norm} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = (1, 2) + self.keepdim = False + self.dtype = "float64" + + +class TestFrobeniusNormOp2(TestFrobeniusNormOp): + def init_test_case(self): + self.shape = [5, 5, 5] + self.axis = (0, 1) + self.keepdim = True + self.dtype = "float32" + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestPnormOp(OpTest): + def setUp(self): + self.op_type = "p_norm" + self.init_test_case() + x = (np.random.random(self.shape) + 0.5).astype(self.dtype) + norm = p_norm(x, self.axis, self.porder, self.keepdim) + self.inputs = {'X': x} + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder) + } + self.outputs = {'Out': norm} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = 1 + self.epsilon = 1e-12 + self.porder = 2.0 + self.keepdim = False + self.dtype = "float64" + + +class TestPnormOp2(TestPnormOp): + def init_test_case(self): + self.shape = [3, 20, 3] + self.axis = 2 + self.epsilon = 1e-12 + self.porder = 2.0 + self.keepdim = True + self.dtype = "float32" + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +def run_out(self, p, axis, shape_x, shape_y, dtype): + with fluid.program_guard(fluid.Program()): + data1 = fluid.data(name="X", shape=shape_x, dtype=dtype) + data2 = fluid.data(name="Y", shape=shape_y, dtype=dtype) + out = paddle.norm(input=data1, p=p, axis=axis, out=data2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result = exe.run(feed={"X": np.random.rand(*shape_x).astype(dtype)}, + fetch_list=[data2, out]) + self.assertEqual((result[0] == result[1]).all(), True) + + +def run_fro(self, p, axis, shape_x, dtype): + with 
fluid.program_guard(fluid.Program()): + data = fluid.data(name="X", shape=shape_x, dtype=dtype) + out = paddle.norm(input=data, p=p, axis=axis) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype) + expected_result = frobenius_norm(np_input, axis=axis) + result, = exe.run(feed={"X": np_input}, fetch_list=[out]) + self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True) + + +def run_pnorm(self, p, axis, shape_x, dtype): + with fluid.program_guard(fluid.Program()): + data = fluid.data(name="X", shape=shape_x, dtype=dtype) + out = paddle.norm(input=data, p=p, axis=axis) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype) + expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype) + result, = exe.run(feed={"X": np_input}, fetch_list=[out]) + self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True) + + +class API_NormTest(unittest.TestCase): + def test_output_result(self): + run_out(self, p=2, axis=1, shape_x=[3, 4], shape_y=[3], dtype="float32") + run_out( + self, + p='fro', + axis=None, + shape_x=[3, 4], + shape_y=[1], + dtype="float32") + + def test_basic(self): + run_fro(self, p='fro', axis=None, shape_x=[3, 3, 4], dtype="float32") + run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64") + run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32") + run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64") + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name="x", shape=[10, 10], dtype="float32") + y_1 = paddle.norm(x, p='fro', name='frobenius_name') + y_2 = paddle.norm(x, p=2, name='pnorm_name') + self.assertEqual(('frobenius_name' in y_1.name), True) + self.assertEqual(('pnorm_name' in y_2.name), True) + + def test_errors(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + + def err_dtype(p, shape_x, xdtype, out=None): + data = fluid.data(shape=shape_x, dtype=xdtype) + paddle.norm(data, p=p, out=out) + + self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "int64") + out = fluid.data(name="out", shape=[1], dtype="int64") + self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "float64", + out) + self.assertRaises(TypeError, err_dtype, 2, [10], "int64") + self.assertRaises(TypeError, err_dtype, 2, [10], "float64", out) + + data = fluid.data(name="data_2d", shape=[2, 2], dtype="float64") + self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm") + self.assertRaises(ValueError, paddle.norm, data, p=[1]) + self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1) + self.assertRaises( + ValueError, paddle.norm, data, p='unspport', axis=[-2, -1]) + data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64") + self.assertRaises( + ValueError, paddle.norm, data, p='unspport', axis=[-2, -1]) + self.assertRaises( + ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 15a1607ac875477f407626121a9e300ff8d31ffd..0f1accd51ae1cf0a71dd13ec3c0f888f62fa3801 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -125,7 +125,7 @@ from .math import addmm #DEFINE_ALIAS from .linalg import matmul #DEFINE_ALIAS from .linalg import dot #DEFINE_ALIAS # from .linalg import einsum #DEFINE_ALIAS -# from .linalg import morm #DEFINE_ALIAS +from .linalg import 
norm #DEFINE_ALIAS # from .linalg import transpose #DEFINE_ALIAS from .linalg import dist #DEFINE_ALIAS # from .linalg import t #DEFINE_ALIAS diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4d0d99edf4272534dace4ce40128aa40a390222c..d23c474d3077322695a99ed3359bb01ccc6d8d48 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -20,7 +20,7 @@ __all__ = [ 'matmul', 'dot', # 'einsum', - # 'morm', + 'norm', # 'transpose', 'dist', # 't', @@ -160,6 +160,181 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): return out +def norm(input, p='fro', axis=None, keepdim=False, out=None, name=None): + """ + Returns the matrix norm (Frobenius) or vector norm (the 1-norm, the Euclidean + or 2-norm, and in general the p-norm for p > 0) of a given tensor. + + Args: + input (Variable): The input tensor could be N-D tensor, and the input data + type could be float32 or float64. + p (float|string, optional): Order of the norm. Supported values are `fro`, `1`, `2`, + and any positive real number yielding the corresponding p-norm. + axis (int|list, optional): The axis on which to apply norm operation. If axis is int + or list with only one element, the vector norm is computed over the axis. + If axis is a list with two elements, the matrix norm is computed over the axis. + If `axis < 0`, the dimension to norm operation is rank(input) + axis. + keepdim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have fewer dimension + than the :attr:`input` unless :attr:`keepdim` is true, default + value is False. + out (Variable, optional): The output tensor, default value is None. It's data type + must be the same as the input Tensor. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: Tensor, results of norm operation on the specified axis of input tensor, + it's data type is the same as input's Tensor. + + Raises: + TypeError, if out data type is different with the input data type. + ValueError, If `p` or `axis` is invalid. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + x = fluid.data(name='x', shape=[2, 3, 5], dtype='float64') + + # compute frobenius norm along last two dimensions. + out_fro = paddle.norm(x, p='fro', axis=[1,2]) + + # compute 2-order vector norm along last dimension. + out_pnorm = paddle.norm(x, p=2, axis=-1) + """ + + def frobenius_norm(input, dim=None, keepdim=False, out=None, name=None): + """ + The frobenius norm OP is to calculate the frobenius norm of certain two dimensions of Tensor `input`. + Args: + input (Variable): Tensor, data type float32, float64. + dim (list, optional): None for last two dimensions. + keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False. + out (Variable, optional): The tensor variable storing the output. + """ + if dim is not None and not (isinstance(dim, list) and len(dim) == 2): + raise ValueError( + "The dim of frobenius norm op should be None or two elements list!" 
+ ) + attrs = { + 'dim': dim if dim != None else [-2, -1], + 'keep_dim': keepdim, + 'reduce_all': False + } + if len(attrs['dim']) == len(input.shape): + attrs['reduce_all'] = True + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'frobenius_norm') + + helper = LayerHelper('frobenius_norm', **locals()) + if out is None: + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + else: + check_type(out, 'out', (Variable), 'frobenius_norm') + check_dtype( + out.dtype, out.name, + convert_dtype(input.dtype), 'frobenius_norm', + '(The out data type in frobenius_norm must be the same with input data type.)' + ) + + helper.append_op( + type='frobenius_norm', + inputs={'X': input}, + outputs={'Out': out}, + attrs=attrs) + return out + + def vector_norm(input, + porder=None, + axis=None, + keepdim=False, + out=None, + name=None): + """ + Calculate the p-order vector norm for certain dimension of Tensor `input`. + Args: + input (Variable): Tensor, data type float32, float64. + porder (float, optional): None for porder=2.0. + axis (int, optional): None for last dimension. + keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False. + out (Variable, optional): The tensor variable storing the output. + """ + if porder is not None: + check_type(porder, 'porder', (float, int), 'p_norm') + if axis is not None: + check_type(axis, 'axis', (int), 'p_norm') + attrs = { + 'axis': axis if axis is not None else -1, + 'porder': float(porder) if porder is not None else 2.0, + 'keepdim': keepdim, + 'epsilon': 1e-12, + } + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'p_norm') + + helper = LayerHelper('p_norm', **locals()) + if out is None: + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + else: + check_type(out, 'out', (Variable), 'p_norm') + check_dtype( + out.dtype, out.name, + convert_dtype(input.dtype), 'p_norm', + '(The out data type in p_norm must be the same with input data type.)' + ) + + helper.append_op( + type='p_norm', + inputs={'X': input}, + outputs={'Out': out}, + attrs=attrs) + return out + + if axis is None and p is not None: + if isinstance(p, str): + if p == "fro": + return frobenius_norm( + input, dim=axis, keepdim=keepdim, out=out, name=name) + else: + raise ValueError( + "only valid string values are 'fro', found {}".format(p)) + elif isinstance(p, (int, float)): + return vector_norm( + input, porder=p, axis=axis, keepdim=keepdim, out=out, name=name) + else: + raise ValueError("only valid p type is string or float, found {}". + format(type(p))) + + if isinstance(axis, list) and len(axis) == 1: + axis = axis[0] + + #calculate vector norm, where axis is int or list with only one integer + if isinstance(axis, int): + if isinstance(p, (int, float)): + return vector_norm( + input, axis=axis, porder=p, keepdim=keepdim, out=out, name=name) + else: + raise ValueError( + "unspport p for p-order vector norm. except float, found {}". + format(p)) + #calculate matrix norm, where axis is list with two integers + elif isinstance(axis, list) and len(axis) == 2: + if p == "fro": + return frobenius_norm( + input, dim=axis, keepdim=keepdim, out=out, name=name) + else: + raise ValueError( + "unspport p for matrix norm, expcept 'fro', found {}".format(p)) + else: + raise ValueError( + "except axis type int or list (length of list <=2), found {}". + format(axis)) + + def dist(x, y, p=2): """ This OP returns the p-norm of (x - y). It is not a norm in a strict sense, only as a measure