diff --git a/paddle/fluid/distributed/test/graph_node_split_test.cc b/paddle/fluid/distributed/test/graph_node_split_test.cc
index 3fcddde787f69f22a2a742528431fffdc28840e8..714fbb1e4aa2d8abb10eebe464cd8ac11ad1dc18 100644
--- a/paddle/fluid/distributed/test/graph_node_split_test.cc
+++ b/paddle/fluid/distributed/test/graph_node_split_test.cc
@@ -272,4 +272,4 @@ void RunGraphSplit() {
   worker_ptr_->finalize_worker();
 }
 
-TEST(RunGraphSplit, Run) { RunGraphSplit(); }
\ No newline at end of file
+TEST(RunGraphSplit, Run) { RunGraphSplit(); }
diff --git a/paddle/fluid/operators/renorm_op.cc b/paddle/fluid/operators/renorm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b15193e0e99d8ffa8450889bbbc499fcdcc7d929
--- /dev/null
+++ b/paddle/fluid/operators/renorm_op.cc
@@ -0,0 +1,117 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/renorm_op.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+class RenormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  using DDim = paddle::framework::DDim;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Renorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Renorm");
+
+    auto in_dims = ctx->GetInputDim("X");
+
+    ctx->SetOutputDim("Out", in_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class RenormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of renorm op.");
+    AddOutput("Out", "(Tensor), The output tensor of renorm op.");
+    AddAttr<float>("p", "(float) The power of the norm operation.");
+    AddAttr<int>("axis",
+                 "(int) The dimension to slice over to get the sub-tensors.");
+    AddAttr<float>("max_norm", "(float) The norm upper-bound.");
+    AddAttr<bool>("use_cudnn",
+                  "(bool, default false) Only used in cudnn kernel, need "
+                  "install cudnn")
+        .SetDefault(false);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Renorm Operator.
+
+This operator rescales the sub-tensors sliced along `axis` whose p-norm exceeds `max_norm`.
+
+)DOC");
+  }
+};
+
+class RenormGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@Grad", "RenormGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@Grad", "RenormGrad");
+
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(dtype, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class RenormGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("renorm_grad");
+    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    retv->SetInput("X", this->Input("X"));
+    retv->SetAttrMap(this->Attrs());
+    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(renorm, ops::RenormOp, ops::RenormOpMaker,
+                  ops::RenormGradMaker<paddle::framework::OpDesc>,
+                  ops::RenormGradMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(renorm_grad, ops::RenormGradOp);
+
+REGISTER_OP_CPU_KERNEL(renorm, ops::CPURenormKernel<float>,
+                       ops::CPURenormKernel<double>);
+
+REGISTER_OP_CPU_KERNEL(renorm_grad, ops::CPURenormGradKernel<float>,
+                       ops::CPURenormGradKernel<double>);
\ No newline at end of file
diff --git a/paddle/fluid/operators/renorm_op.cu b/paddle/fluid/operators/renorm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c8471fffa859809a64ad89395936e39bb9a7b4eb
--- /dev/null
+++ b/paddle/fluid/operators/renorm_op.cu
@@ -0,0 +1,237 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
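+
+// Implementation note: the CUDA forward pass below views X as a 3-D tensor of
+// shape [pre_mul, dimension_each, dim_divisor] (everything before `axis`, the
+// `axis` dimension itself, everything after it), computes |x|^p elementwise,
+// reduces over axes {0, 2} to obtain each slice's sum(|x|^p), turns that into
+// a per-slice scale factor min(1, max_norm / norm_p) in RenormKernelFunc3,
+// and finally rescales the input in RenormKernelFunc4.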
+
+#include <algorithm>  // std::min
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include "paddle/fluid/operators/renorm_op.h"
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "stdio.h"
+
+namespace paddle {
+namespace operators {
+
+__device__ __forceinline__ float inline_pow(float base, float exponent) {
+  return pow(base, exponent);
+}
+
+__device__ __forceinline__ double inline_pow(double base, double exponent) {
+  return pow(base, exponent);
+}
+
+__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
+__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
+
+template <typename Tx, typename Ty = Tx>
+struct UnsignedPowFunctor {
+  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
+    this->porder = porder;
+  }
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
+  }
+  float porder;
+};
+
+template <typename T>
+__global__ void RenormKernelFunc3(int64_t size, T* dim_value, float p,
+                                  float max_norm) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  if (i < size) {
+    T temp = pow(dim_value[i], (T)(1.0 / p));
+    dim_value[i] = 1.0;
+    if (temp > max_norm) dim_value[i] = max_norm / temp;
+  }
+}
+
+template <typename T>
+__global__ void RenormKernelFunc4(T* x_data, T* out_data, int64_t size,
+                                  T* dim_value, int64_t dimension_each,
+                                  int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  auto dim_index = i / dim_divisor % dimension_each;
+  if (i < size) {
+    if (dim_value[dim_index] < 1.0)
+      out_data[i] = dim_value[dim_index] * x_data[i];
+    else
+      out_data[i] = x_data[i];
+  }
+}
+
+template <typename T>
+__global__ void RenormGradKernelFunc1(T* x_data, T* dout_data, T* pow_value,
+                                      T* mul_value, int64_t size,
+                                      int64_t dimension_each, float p,
+                                      int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  if (i < size) {
+    pow_value[i] = pow(abs(x_data[i]), (T)p);
+    mul_value[i] = x_data[i] * dout_data[i];
+  }
+}
+
+template <typename T>
+__global__ void RenormGradKernelFunc2(T* x_data, T* dout_data, T* dx_data,
+                                      int64_t size, T* dim_value,
+                                      T* dim_power_sum, T* weight_derivative,
+                                      int64_t dimension_each, float p,
+                                      float max_norm, int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  auto dim_index = i / dim_divisor % dimension_each;
+  if (i < dimension_each) {
+    dim_power_sum[i] = 0;
+    auto temp = pow(dim_value[i], (T)(1.0 / p));
+    if (temp > max_norm) {
+      dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
+      dim_value[i] = max_norm / temp;
+    } else
+      dim_value[i] = 1.0;
+  }
+  __syncthreads();
+  if (i < size) {
+    dx_data[i] = dim_value[dim_index] * dout_data[i];
+    dx_data[i] = dx_data[i] +
+                 weight_derivative[dim_index] * dim_power_sum[dim_index] *
+                     pow(abs(x_data[i]), T(p - 1.0)) *
+                     (x_data[i] >= 0 ? 1 : -1);
+  }
+}
+
+template <typename T>
+class CUDARenormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    auto numel = x->numel();
+    T* x_data = (T*)x->data<T>();
+    auto input_dims = x->dims();
+    float max_norm = context.Attr<float>("max_norm");
+    float p = context.Attr<float>("p");
+    int dim = context.Attr<int>("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    framework::Tensor pow_value, dim_value;
+    int64_t dim_divisor = 1, pre_mul = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+    for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+    pow_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    dim_value.Resize(framework::make_ddim({dimension_each}));
+    pow_value.mutable_data<T>(context.GetPlace());
+    out->Resize(framework::make_ddim(framework::vectorize(input_dims)));
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    auto stream = context.cuda_device_context().stream();
+    int block = std::min(numel, static_cast<int64_t>(256));
+    using MT = typename details::MPTypeTrait<T>::Type;
+    int grid = (numel + block - 1) / block;
+
+    int block2 = std::min(dimension_each, static_cast<int64_t>(256));
+    int grid2 = (dimension_each + block2 - 1) / block2;
+    std::vector<const framework::Tensor*> ins = {x};
+    std::vector<framework::Tensor*> outs = {&pow_value};
+    auto func = UnsignedPowFunctor<MT, T>(p);
+    const auto& cuda_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T>(
+        cuda_ctx, ins, &outs, func);
+    std::vector<int> reduce_axis = {0, 2};
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis, stream);
+    RenormKernelFunc3<T><<<grid2, block2, 0, stream>>>(
+        dimension_each, dim_value.mutable_data<T>(context.GetPlace()), p,
+        max_norm);
+    RenormKernelFunc4<T><<<grid, block, 0, stream>>>(
+        x_data, out_data, numel, dim_value.mutable_data<T>(context.GetPlace()),
+        dimension_each, dim_divisor);
+    // platform::GpuStreamSync(stream);
+  }
+};
+
+template <typename T>
+class CUDAGradRenormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor* d_out =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
+    framework::Tensor* d_x =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    T* dout_data = (T*)d_out->data<T>();
+    T* x_data = (T*)x->data<T>();
+    auto input_dims = x->dims();
+    float max_norm = ctx.Attr<float>("max_norm");
+    float p = ctx.Attr<float>("p");
+    int dim = ctx.Attr<int>("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    int64_t dim_divisor = 1, pre_mul = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+    for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+    d_x->Resize(framework::make_ddim(framework::vectorize(input_dims)));
+    T* dx_data = d_x->mutable_data<T>(ctx.GetPlace());
+    framework::Tensor pow_value, mul_value, dim_value, dim_power_sum,
+        weight_derivative;
+    pow_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    mul_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    dim_value.Resize(framework::make_ddim({dimension_each}));
+    dim_power_sum.Resize(framework::make_ddim({dimension_each}));
+    weight_derivative.Resize(framework::make_ddim({dimension_each}));
+    auto stream = ctx.cuda_device_context().stream();
+    int block = std::min(numel, static_cast<int64_t>(256));
+    int grid = (numel + block - 1) / block;
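+    // Gradient outline: RenormGradKernelFunc1 fills pow_value = |x|^p and
+    // mul_value = x * dout elementwise; the two reductions below collapse
+    // them over axes {0, 2} into per-slice sum(|x|^p) and sum(x * dout);
+    // RenormGradKernelFunc2 then combines the direct term scale * dout with
+    // the term coming from differentiating the per-slice scale factor.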
+    pow_value.mutable_data<T>(ctx.GetPlace());
+    mul_value.mutable_data<T>(ctx.GetPlace());
+    dim_value.mutable_data<T>(ctx.GetPlace());
+    dim_power_sum.mutable_data<T>(ctx.GetPlace());
+    weight_derivative.mutable_data<T>(ctx.GetPlace());
+    RenormGradKernelFunc1<T><<<grid, block, 0, stream>>>(
+        x_data, dout_data, pow_value.mutable_data<T>(ctx.GetPlace()),
+        mul_value.mutable_data<T>(ctx.GetPlace()), numel, dimension_each, p,
+        dim_divisor);
+    std::vector<int> reduce_axis = {0, 2};
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        mul_value, &weight_derivative, kps::IdentityFunctor<T>(), reduce_axis,
+        stream);
+    RenormGradKernelFunc2<T><<<grid, block, 0, stream>>>(
+        x_data, dout_data, dx_data, numel,
+        dim_value.mutable_data<T>(ctx.GetPlace()),
+        dim_power_sum.mutable_data<T>(ctx.GetPlace()),
+        weight_derivative.mutable_data<T>(ctx.GetPlace()), dimension_each, p,
+        max_norm, dim_divisor);
+    // platform::GpuStreamSync(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(renorm, ops::CUDARenormKernel<float>,
+                        ops::CUDARenormKernel<double>);
+
+REGISTER_OP_CUDA_KERNEL(renorm_grad, ops::CUDAGradRenormKernel<float>,
+                        ops::CUDAGradRenormKernel<double>);
\ No newline at end of file
diff --git a/paddle/fluid/operators/renorm_op.h b/paddle/fluid/operators/renorm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..461f383ad25639fe2db9b64eb490ad1e7a769a4a
--- /dev/null
+++ b/paddle/fluid/operators/renorm_op.h
@@ -0,0 +1,191 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
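+
+// Implementation note: the CPU kernels below traverse the flattened input
+// once per pass, tracking `index` / `dim_index` counters instead of doing a
+// div/mod per element; they accumulate each slice's sum(|x|^p) along `axis`,
+// turn it into a per-slice scale factor min(1, max_norm / norm_p), and then
+// rescale the input (forward) or assemble the gradient (backward).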
+
+#pragma once
+
+#include "math.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/platform/for_range.h"
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+// template <typename T>
+// struct NormDimValueFunctor {
+//   NormDimValueFunctor(T* input, T* output, int64_t dim_divisor,
+//                       int64_t dimension_each, float p)
+//       : input_(input), output_(output), dim_divisor_(dim_divisor),
+//         dimension_each_(dimension_each), p_(p) {}
+
+//   HOSTDEVICE void operator()(int64_t i) const {
+//     auto dim_index = i / dim_divisor_ % dimension_each_;
+//     dim_value[dim_index] += std::pow(std::abs(input_[i]), p_);
+//   }
+
+//   T* input_;
+//   T* output_;
+//   int64_t dimension_each_, dim_divisor_;
+//   float p_, max_norm_;
+// };
+// template <typename DeviceContext, typename T>
+template <typename T>
+class CPURenormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    auto numel = x->numel();
+    auto* x_data = x->data<T>();
+    auto input_dims = x->dims();
+    float max_norm = context.Attr<float>("max_norm");
+    float p = context.Attr<float>("p");
+    int dim = context.Attr<int>("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    int64_t dim_divisor = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+
+    // auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    // std::vector<int64_t> dim_index(dim_size, 0);
+    std::vector<T> dim_value(dimension_each,
+                             0);  // dim_value = (x1^p + x2^p + x3^p....)^(1/p)
+
+    auto* out_data =
+        out->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
+
+    int64_t index = 0, dim_index = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      // auto dim_index = i / dim_divisor % dimension_each;
+      dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
+      index++;
+      if (index == dim_divisor) {
+        dim_index++;
+        if (dim_index == dimension_each) {
+          dim_index = 0;
+        }
+        index = 0;
+      }
+    }
+    for (int64_t i = 0; i < dimension_each; i++) {
+      dim_value[i] = std::pow(dim_value[i], 1.0 / p);
+      if (dim_value[i] > max_norm)
+        dim_value[i] = max_norm / dim_value[i];
+      else
+        dim_value[i] = 1.0;
+      // dim_index[i] = 0;
+    }
+    index = dim_index = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      // auto dim_index = i / dim_divisor % dimension_each;
+      out_data[i] = dim_value[dim_index] < 1.0
+                        ? dim_value[dim_index] * x_data[i]
+                        : x_data[i];
+      index++;
+      if (index == dim_divisor) {
+        dim_index++;
+        if (dim_index == dimension_each) {
+          dim_index = 0;
+        }
+        index = 0;
+      }
+    }
+  }
+};
+
+// template <typename DeviceContext, typename T>
+template <typename T>
+class CPURenormGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor* d_out =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
+    framework::Tensor* d_x =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    auto* dout_data = d_out->data<T>();
+    auto* x_data = x->data<T>();
+    auto input_dims = x->dims();
+    float max_norm = ctx.Attr<float>("max_norm");
+    float p = ctx.Attr<float>("p");
+    int dim = ctx.Attr<int>("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    int64_t dim_divisor = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+    auto* dx_data = d_x->mutable_data<T>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
+    std::vector<T> dim_value(dimension_each, 0),
+        dim_power_sum(dimension_each, 0),
+        weight_derivative(dimension_each, 0.0);
+    int64_t index = 0, dim_index = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      // auto dim_index = i / dim_divisor % dimension_each;
+      dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
+      index++;
+      if (index == dim_divisor) {
+        dim_index++;
+        if (dim_index == dimension_each) {
+          dim_index = 0;
+        }
+        index = 0;
+      }
+    }
+    for (int64_t i = 0; i < dimension_each; i++) {
+      auto temp = std::pow(dim_value[i], 1.0 / p);
+      if (temp > max_norm) {
+        dim_power_sum[i] =
+            std::pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
+        dim_value[i] = max_norm / temp;
+      } else
+        dim_value[i] = 1.0;
+    }
+    index = dim_index = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      // auto dim_index = i / dim_divisor % dimension_each;
+      dx_data[i] = dim_value[dim_index] * dout_data[i];
+      weight_derivative[dim_index] += x_data[i] * dout_data[i];
+      index++;
+      if (index == dim_divisor) {
+        dim_index++;
+        if (dim_index == dimension_each) {
+          dim_index = 0;
+        }
+        index = 0;
+      }
+    }
+    index = dim_index = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      // auto dim_index = i / dim_divisor % dimension_each;
+      dx_data[i] += weight_derivative[dim_index] * dim_power_sum[dim_index] *
+                    std::pow(std::abs(x_data[i]), p - 1.0) *
+                    (x_data[i] >= 0 ? 1 : -1);
+      index++;
+      if (index == dim_divisor) {
+        dim_index++;
+        if (dim_index == dimension_each) {
+          dim_index = 0;
+        }
+        index = 0;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 7791755644b6657a7ed2c63cb06219d6e382ee5e..7d0ec4981b8257bdf9dc24854d21d10150ac7de3 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -211,6 +211,7 @@ from .tensor.math import remainder  # noqa: F401
 from .tensor.math import mod  # noqa: F401
 from .tensor.math import floor_mod  # noqa: F401
 from .tensor.math import multiply  # noqa: F401
+from .tensor.math import renorm  # noqa: F401
 from .tensor.math import add  # noqa: F401
 from .tensor.math import subtract  # noqa: F401
 from .tensor.math import logsumexp  # noqa: F401
diff --git a/python/paddle/fluid/tests/unittests/test_renorm_op.py b/python/paddle/fluid/tests/unittests/test_renorm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea2002a9786fdd3f6c034e84176d0cae46ca591
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_renorm_op.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+paddle.set_device('cpu')
+
+
+class TestRenormAPI(unittest.TestCase):
+    def input_data(self):
+        self.data_x = np.array(
+            [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]])
+        self.p = 1.0
+        self.dim = 2
+        self.max_norm = 2.05
+
+    def test_renorm_api(self):
+        paddle.enable_static()
+        self.input_data()
+
+        # case 1:
+        with program_guard(Program(), Program()):
+            # x = fluid.layers.data(name='x', shape=[-1, 2, 3])
+            x = paddle.static.data(name="x", shape=[-1, 2, 3], dtype='float64')
+            z = paddle.renorm(x, self.p, self.dim, self.max_norm)
+            exe = fluid.Executor(fluid.CPUPlace())
+            res, = exe.run(feed={"x": self.data_x},
+                           fetch_list=[z],
+                           return_numpy=False)
+            expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
+                                  [0.60891086, 0.04392857, 0.61500001]],
+                                 [[0.40594056, -1.17142856, 0.41000000],
+                                  [0.62920785, 0.54178572, 0.61500001]]])
+            self.assertTrue(np.allclose(expected, np.array(res)))
+
+    def test_dygraph_api(self):
+        self.input_data()
+        # case: axis = 2
+        with fluid.dygraph.guard():
+            input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
+            x = paddle.to_tensor(input, stop_gradient=False)
+            y = paddle.renorm(x, 1.0, 2, 2.05)
+            expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
+                                  [0.60891086, 0.04392857, 0.61500001]],
+                                 [[0.40594056, -1.17142856, 0.41000000],
+                                  [0.62920785, 0.54178572, 0.61500001]]])
+            self.assertTrue(np.allclose(expected, np.array(y)))
+            z = paddle.mean(y)
+            z.backward(retain_graph=True)
+            expected_grad = np.array(
+                [[[0, 0.01394558, 0.02733333], [0, 0.01394558, 0.00683333]],
+                 [[0, 0.01045918, 0.00683333], [0, 0.01394558, 0.00683333]]])
+            self.assertTrue(np.allclose(expected_grad, np.array(x.grad)))
+        # test exception:
+        with fluid.dygraph.guard():
+            input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
+            x = paddle.to_tensor(input, stop_gradient=False)
+            exp = False
+            try:
+                paddle.renorm(x, 1.0, 8, 2.05)
+            except:
+                exp = True
+            self.assertTrue(exp)
+            exp = False
+            try:
+                paddle.renorm(x, 1.0, -4, 2.05)
+            except:
+                exp = True
+            self.assertTrue(exp)
+            y = paddle.renorm(x, 1.0, -1, 2.05)
+            expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
+                                  [0.60891086, 0.04392857, 0.61500001]],
+                                 [[0.40594056, -1.17142856, 0.41000000],
+                                  [0.62920785, 0.54178572, 0.61500001]]])
+            self.assertTrue(np.allclose(expected, np.array(y)))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 47ca3fed0bdff8e02b070398aa523e4fa7c9f4fc..dcc241b583055f0e60c35055989a2e3304314757 100755
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1194,6 +1194,62 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
         type="addmm", inputs=inputs, attrs=attrs, outputs={"Out": out})
     return out
 
+def renorm(x, p, axis, max_norm):
+    """
+    **renorm**
+
+    This operator computes the p-norm along the given axis. Suppose the input
+    tensor has size T on the `axis` dimension; the tensor is then viewed as T
+    sub-tensors, one per index of `axis`, and the p-norm of each sub-tensor is
+    calculated. If the p-norm of sub-tensor i is larger than max_norm, every
+    element in sub-tensor i is rescaled by the same factor so that its p-norm
+    equals max_norm exactly; otherwise sub-tensor i stays unchanged.
+
+    Args:
+        x (Tensor): The input Tensor.
+        p (float): The power of the norm operation.
+        axis (int): The dimension to slice the tensor over.
+        max_norm (float): The maximum norm limit.
+
+    Returns:
+        Tensor: the renormalized Tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
+            x = paddle.to_tensor(input, dtype='float32')
+            y = paddle.renorm(x, 1.0, 2, 2.05)
+            print(y)
+            # [[[ 0.40594056,  0.29285714, -0.41000000],
+            #   [ 0.60891086,  0.04392857,  0.61500001]],
+            #  [[ 0.40594056, -1.17142856,  0.41000000],
+            #   [ 0.62920785,  0.54178572,  0.61500001]]]
+
+    """
+    input_shape = x.shape
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'renorm')
+    if not axis < len(input_shape):
+        raise ValueError(
+            "the axis:{} should be less than the rank of the input:{}, whose shape is {}".
+            format(axis, len(input_shape), input_shape))
+    if not axis >= 0:
+        if not axis >= -1 * len(input_shape):
+            raise ValueError(
+                "the axis:{} should not be less than -1 * rank of the input:{}".
+                format(axis, -1 * len(input_shape)))
+        axis = axis + len(input_shape)
+    if in_dygraph_mode():
+        out = core.ops.renorm(x, 'p', p, 'axis', axis, 'max_norm', max_norm)
+        return out
+
+    inputs = {'X': x}
+    attrs = {'p': p, 'axis': axis, 'max_norm': max_norm}
+
+    helper = LayerHelper("renorm", **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type="renorm", inputs=inputs, attrs=attrs, outputs={"Out": out})
+    return out
+
 def inner(x, y, name=None):
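
A quick reference for the semantics of the new op, useful when checking the expected values in test_renorm_op.py. This is not part of the patch: it is a NumPy sketch that assumes the behaviour described in the docstring above, and `renorm_reference` is a hypothetical helper name used only for illustration.

import numpy as np


def renorm_reference(x, p, axis, max_norm):
    # View the input as one slice per index of `axis`, flattened into rows.
    x = np.asarray(x, dtype='float64')
    moved = np.moveaxis(x, axis, 0)
    flat = moved.reshape(moved.shape[0], -1)
    # p-norm of every slice taken along `axis`.
    norms = np.power(np.abs(flat), p).sum(axis=1) ** (1.0 / p)
    # Slices whose p-norm exceeds max_norm are scaled down to exactly max_norm;
    # the small epsilon only guards the division for all-zero slices.
    scale = np.where(norms > max_norm, max_norm / np.maximum(norms, 1e-12), 1.0)
    out = flat * scale[:, None]
    return np.moveaxis(out.reshape(moved.shape), 0, axis)


x = np.array([[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]])
print(renorm_reference(x, p=1.0, axis=2, max_norm=2.05))
# Reproduces the `expected` array used in test_renorm_op.py, e.g. 2.0 * 2.05 / 10.1 = 0.40594056.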