diff --git a/paddle/fluid/operators/renorm_op.cc b/paddle/fluid/operators/renorm_op.cc
index 98f65b9dce0be40360cc2100db176d94028499e9..9ed911f8f69a075323d5183e99bf8147d4b3ec50 100644
--- a/paddle/fluid/operators/renorm_op.cc
+++ b/paddle/fluid/operators/renorm_op.cc
@@ -12,15 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/renorm_op.h"
-
 #include
 #include
 #include
 #include
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -29,15 +28,6 @@ class RenormOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   using DDim = paddle::framework::DDim;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "abs");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "abs");
-
-    auto in_dims = ctx->GetInputDim("X");
-
-    ctx->SetOutputDim("Out", in_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
 };
 
 class RenormOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -70,26 +60,6 @@ This operator is used to scale tensor sliced by axis if its p-norm execeeds maxn
 class RenormGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@Grad",
-                   "AbsGrad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output",
-                   "X@Grad",
-                   "AbsGrad");
-
-    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
-  }
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(dtype, ctx.GetPlace());
-  }
 };
 
 template <typename T>
@@ -110,18 +80,19 @@ class RenormGradMaker : public framework::SingleGradOpMaker<T> {
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(renorm,
+                            RenormInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
+
+DECLARE_INFER_SHAPE_FUNCTOR(renorm_grad,
+                            RenormGradInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
+
 REGISTER_OPERATOR(renorm,
                   ops::RenormOp,
                   ops::RenormOpMaker,
                   ops::RenormGradMaker<paddle::framework::OpDesc>,
-                  ops::RenormGradMaker<paddle::imperative::OpBase>);
-
-REGISTER_OPERATOR(renorm_grad, ops::RenormGradOp);
-
-REGISTER_OP_CPU_KERNEL(renorm,
-                       ops::CPURenormKernel<float>,
-                       ops::CPURenormKernel<double>);
+                  ops::RenormGradMaker<paddle::imperative::OpBase>,
+                  RenormInferShapeFunctor)
 
-REGISTER_OP_CPU_KERNEL(renorm_grad,
-                       ops::CPURenormGradKernel<float>,
-                       ops::CPURenormGradKernel<double>);
+REGISTER_OPERATOR(renorm_grad, ops::RenormGradOp, RenormGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/renorm_op.cu b/paddle/fluid/operators/renorm_op.cu
deleted file mode 100644
index ea21b985e7f7b0bf067bc448c382cf3cf9a13fbe..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/renorm_op.cu
+++ /dev/null
@@ -1,278 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/operators/renorm_op.h" -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -__device__ __forceinline__ float inline_pow(float base, float exponent) { - return pow(base, exponent); -} - -__device__ __forceinline__ double inline_pow(double base, double exponent) { - return pow(base, exponent); -} - -__device__ __forceinline__ float inline_abs(float x) { return abs(x); } -__device__ __forceinline__ double inline_abs(double x) { return abs(x); } - -template -struct UnsignedPowFunctor { - HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { - this->porder = porder; - } - HOSTDEVICE inline Ty operator()(const Tx x) const { - return static_cast(inline_pow(inline_abs(x), static_cast(porder))); - } - float porder; -}; - -template -__global__ void RenormKernelFunc3(int64_t size, - T* dim_value, - float p, - float max_norm) { - int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; - if (i < size) { - T temp = pow(dim_value[i], (T)(1.0 / p)); - dim_value[i] = 1.0; - if (temp > max_norm) dim_value[i] = max_norm / temp; - } -} - -template -__global__ void RenormKernelFunc4(const T* x_data, - T* out_data, - int64_t size, - T* dim_value, - int64_t dimension_each, - int64_t dim_divisor) { - int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; - auto dim_index = i / dim_divisor % dimension_each; - if (i < size) { - if (dim_value[dim_index] < 1.0) - out_data[i] = dim_value[dim_index] * x_data[i]; - else - out_data[i] = x_data[i]; - } -} - -template -__global__ void RenormGradKernelFunc1(const T* x_data, - const T* dout_data, - T* pow_value, - T* mul_value, - int64_t size, - int64_t dimension_each, - float p, - int64_t dim_divisor) { - int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; - auto dim_index = i / dim_divisor % dimension_each; - if (i < size) { - pow_value[i] = pow(abs(x_data[i]), (T)p); - mul_value[i] = x_data[i] * dout_data[i]; - } -} - -template -__global__ void RenormGradKernelFunc2(const T* x_data, - const T* dout_data, - T* dx_data, - int64_t size, - T* dim_value, - T* dim_power_sum, - T* weight_derivative, - int64_t dimension_each, - float p, - float max_norm, - int64_t dim_divisor) { - int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; - auto dim_index = i / dim_divisor % dimension_each; - if (i < dimension_each) { - dim_power_sum[i] = 0; - auto temp = pow(dim_value[i], (T)(1.0 / p)); - if (temp > max_norm) { - dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm; - dim_value[i] = max_norm / temp; - } else { - dim_value[i] = 1.0; - } - } - __syncthreads(); - if (i < size) { - dx_data[i] = dim_value[dim_index] * dout_data[i]; - dx_data[i] = dx_data[i] + weight_derivative[dim_index] * - dim_power_sum[dim_index] * - pow(abs(x_data[i]), T(p - 1.0)) * - (x_data[i] >= 0 ? 
1 : -1); - } -} - -template -class CUDARenormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - auto numel = x->numel(); - const T* x_data = x->data(); - auto input_dims = x->dims(); - float max_norm = context.Attr("max_norm"); - float p = context.Attr("p"); - int dim = context.Attr("axis"); - auto dimension_each = input_dims[dim]; - auto dim_size = input_dims.size(); - framework::Tensor pow_value, dim_value; - int64_t dim_divisor = 1, pre_mul = 1; - for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i]; - for (int i = 0; i < dim; i++) pre_mul *= input_dims[i]; - pow_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor})); - dim_value.Resize(phi::make_ddim({dimension_each})); - pow_value.mutable_data(context.GetPlace()); - out->Resize(phi::make_ddim(phi::vectorize(input_dims))); - T* out_data = out->mutable_data(context.GetPlace()); - auto stream = context.cuda_device_context().stream(); - int block = std::min(numel, static_cast(256)); - using MT = typename details::MPTypeTrait::Type; - int grid = (numel + block - 1) / block; - - int block2 = std::min(dimension_each, static_cast(256)); - int grid2 = (dimension_each + block2 - 1) / block2; - std::vector ins = {x}; - std::vector outs = {&pow_value}; - auto func = UnsignedPowFunctor(p); - const auto& cuda_ctx = context.template device_context(); - - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, &outs, func); - std::vector reduce_axis = {0, 2}; - TensorReduceImpl>( - cuda_ctx, - pow_value, - &dim_value, - kps::IdentityFunctor(), - reduce_axis, - stream); - RenormKernelFunc3<<>>( - numel, dim_value.mutable_data(context.GetPlace()), p, max_norm); - RenormKernelFunc4<<>>( - x_data, - out_data, - numel, - dim_value.mutable_data(context.GetPlace()), - dimension_each, - dim_divisor); - // platform::GpuStreamSync(stream); - } -}; - -template -class CUDAGradRenormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - const T* dout_data = d_out->data(); - const T* x_data = x->data(); - auto input_dims = x->dims(); - float max_norm = ctx.Attr("max_norm"); - float p = ctx.Attr("p"); - int dim = ctx.Attr("axis"); - auto dimension_each = input_dims[dim]; - auto dim_size = input_dims.size(); - int64_t dim_divisor = 1, pre_mul = 1; - for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i]; - for (int i = 0; i < dim; i++) pre_mul *= input_dims[i]; - d_x->Resize(phi::make_ddim(phi::vectorize(input_dims))); - T* dx_data = d_x->mutable_data(ctx.GetPlace()); - framework::Tensor pow_value, mul_value, dim_value, dim_power_sum, - weight_derivative; - pow_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor})); - mul_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor})); - dim_value.Resize(phi::make_ddim({dimension_each})); - dim_power_sum.Resize(phi::make_ddim({dimension_each})); - weight_derivative.Resize(phi::make_ddim({dimension_each})); - auto stream = ctx.cuda_device_context().stream(); - int block = std::min(numel, static_cast(256)); - int grid = (numel + block - 1) / block; - pow_value.mutable_data(ctx.GetPlace()); - 
mul_value.mutable_data(ctx.GetPlace()); - dim_value.mutable_data(ctx.GetPlace()); - dim_power_sum.mutable_data(ctx.GetPlace()); - weight_derivative.mutable_data(ctx.GetPlace()); - RenormGradKernelFunc1 - <<>>(x_data, - dout_data, - pow_value.mutable_data(ctx.GetPlace()), - mul_value.mutable_data(ctx.GetPlace()), - numel, - dimension_each, - p, - dim_divisor); - std::vector reduce_axis = {0, 2}; - TensorReduceImpl>( - ctx.cuda_device_context(), - pow_value, - &dim_value, - kps::IdentityFunctor(), - reduce_axis, - stream); - TensorReduceImpl>( - ctx.cuda_device_context(), - mul_value, - &weight_derivative, - kps::IdentityFunctor(), - reduce_axis, - stream); - RenormGradKernelFunc2<<>>( - x_data, - dout_data, - dx_data, - numel, - dim_value.mutable_data(ctx.GetPlace()), - dim_power_sum.mutable_data(ctx.GetPlace()), - weight_derivative.mutable_data(ctx.GetPlace()), - dimension_each, - p, - max_norm, - dim_divisor); - // platform::GpuStreamSync(stream); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(renorm, - ops::CUDARenormKernel, - ops::CUDARenormKernel); - -REGISTER_OP_CUDA_KERNEL(renorm_grad, - ops::CUDAGradRenormKernel, - ops::CUDAGradRenormKernel); diff --git a/paddle/fluid/operators/renorm_op.h b/paddle/fluid/operators/renorm_op.h deleted file mode 100644 index 8a4d3bd3258be70291d42234f0fbd63b081387c1..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/renorm_op.h +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "math.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -// template -// struct NormDimValueFunctor { -// NormDimValueFunctor(T* input, T* output, int64_t dim_divisor, int64_t -// dimension_each, float p) -// : input_(input), output_(output),dim_divisor_(dim_divisor), -// dimension_each_(dimension_each),p_(p) {} - -// HOSTDEVICE void operator()(int64_t i) const { -// auto dim_index = i / dim_divsor % dimension_each; -// dim_value[dim_index] += std::pow(std::abs(input[i]), p); -// } - -// T* input_; -// T* output_; -// int64_t dimension_each_, dim_divisor_; -// float p_,max_norm_; - -// }; -// template -template -class CPURenormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - auto numel = x->numel(); - auto* x_data = x->data(); - auto input_dims = x->dims(); - float max_norm = context.Attr("max_norm"); - float p = context.Attr("p"); - int dim = context.Attr("axis"); - auto dimension_each = input_dims[dim]; - auto dim_size = input_dims.size(); - int64_t dim_divisor = 1; - for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i]; - - // auto& dev_ctx = ctx.template device_context(); - // std::vector dim_index(dim_size, 0); - std::vector dim_value(dimension_each, - 0); // dim_value = (x1^p + x2^p + x3^p....)^(1/p) - - auto* out_data = - out->mutable_data(context.GetPlace(), size_t(numel * sizeof(T))); - - int64_t index = 0, dim_index = 0; - for (int64_t i = 0; i < numel; i++) { - // auto dim_index = i / dim_divsor % dimension_each; - dim_value[dim_index] += std::pow(std::abs(x_data[i]), p); - index++; - if (index == dim_divisor) { - dim_index++; - if (dim_index == dimension_each) { - dim_index = 0; - } - index = 0; - } - } - for (int64_t i = 0; i < dimension_each; i++) { - dim_value[i] = std::pow(dim_value[i], 1.0 / p); - if (dim_value[i] > max_norm) - dim_value[i] = max_norm / dim_value[i]; - else - dim_value[i] = 1.0; - // dim_index[i] = 0; - } - index = dim_index = 0; - for (int64_t i = 0; i < numel; i++) { - // auto dim_index = i / dim_divsor % dimension_each; - out_data[i] = dim_value[dim_index] < 1.0 - ? 
dim_value[dim_index] * x_data[i] - : x_data[i]; - index++; - if (index == dim_divisor) { - dim_index++; - if (dim_index == dimension_each) { - dim_index = 0; - } - index = 0; - } - } - } -}; - -// template -template -class CPURenormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - auto* dout_data = d_out->data(); - auto* x_data = x->data(); - auto input_dims = x->dims(); - float max_norm = ctx.Attr("max_norm"); - float p = ctx.Attr("p"); - int dim = ctx.Attr("axis"); - auto dimension_each = input_dims[dim]; - auto dim_size = input_dims.size(); - int64_t dim_divisor = 1; - for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i]; - auto* dx_data = d_x->mutable_data( - ctx.GetPlace(), static_cast(numel * sizeof(T))); - std::vector dim_value(dimension_each, 0), - dim_power_sum(dimension_each, 0), - weight_derivative(dimension_each, 0.0); - int64_t index = 0, dim_index = 0; - for (int64_t i = 0; i < numel; i++) { - // auto dim_index = i / dim_divsor % dimension_each; - dim_value[dim_index] += std::pow(std::abs(x_data[i]), p); - index++; - if (index == dim_divisor) { - dim_index++; - if (dim_index == dimension_each) { - dim_index = 0; - } - index = 0; - } - } - for (int64_t i = 0; i < dimension_each; i++) { - auto temp = std::pow(dim_value[i], 1.0 / p); - if (temp > max_norm) { - dim_power_sum[i] = - std::pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm; - dim_value[i] = max_norm / temp; - } else - dim_value[i] = 1.0; - } - index = dim_index = 0; - for (int64_t i = 0; i < numel; i++) { - // auto dim_index = i / dim_divsor % dimension_each; - dx_data[i] = dim_value[dim_index] * dout_data[i]; - weight_derivative[dim_index] += x_data[i] * dout_data[i]; - index++; - if (index == dim_divisor) { - dim_index++; - if (dim_index == dimension_each) { - dim_index = 0; - } - index = 0; - } - } - index = dim_index = 0; - for (int64_t i = 0; i < numel; i++) { - // auto dim_index = i / dim_divsor % dimension_each; - dx_data[i] += weight_derivative[dim_index] * dim_power_sum[dim_index] * - std::pow(std::abs(x_data[i]), p - 1.0) * - (x_data[i] >= 0 ? 
1 : -1); - index++; - if (index == dim_divisor) { - dim_index++; - if (dim_index == dimension_each) { - dim_index = 0; - } - index = 0; - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 8b74c62cf435123e69f1e46bfb258e5c095c9077..51be045ce4cf8efb7f641f7e2e6b1344a80e7f86 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1931,6 +1931,16 @@ func : relu6 backward : relu6_grad +- api : renorm + args : (Tensor x, float p, int axis, float max_norm) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : renorm + backward : renorm_grad + - api : reshape args : (Tensor x, IntArray shape) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 909f2488386b228a085b7e4a19733b6e73bce61a..4f5fdbc32c5bcce03961c32d4565d221f1721021 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1780,6 +1780,16 @@ backward: relu_double_grad inplace : (out_grad -> x_grad) +- backward_api : renorm_grad + forward : renorm (Tensor x, float p, int axis, float max_norm) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float p, int axis, float max_norm) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : renorm_grad + - backward_api : reshape_double_grad forward : reshape_grad (Tensor xshape, Tensor grad_out) -> Tensor(grad_x) args : (Tensor grad_out, Tensor grad_x_grad) diff --git a/paddle/phi/kernels/cpu/renorm_grad_kernel.cc b/paddle/phi/kernels/cpu/renorm_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..f61bceb2cf05f2a3ec456c2f8d2770103693e871 --- /dev/null +++ b/paddle/phi/kernels/cpu/renorm_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/renorm_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/renorm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + renorm_grad, CPU, ALL_LAYOUT, phi::RenormGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/renorm_kernel.cc b/paddle/phi/kernels/cpu/renorm_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..04f990a643cc4e21e7bf7ee2f1e563b21e7267d4 --- /dev/null +++ b/paddle/phi/kernels/cpu/renorm_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/renorm_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/renorm_kernel_impl.h" + +PD_REGISTER_KERNEL(renorm, CPU, ALL_LAYOUT, phi::RenormKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/renorm_grad_kernel.cu b/paddle/phi/kernels/gpu/renorm_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..48d6e46ccf1b655ef2e00e3a3cd281bd5cc07b03 --- /dev/null +++ b/paddle/phi/kernels/gpu/renorm_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/renorm_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/renorm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + renorm_grad, GPU, ALL_LAYOUT, phi::RenormGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/renorm_kernel.cu b/paddle/phi/kernels/gpu/renorm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..752882610c71b0426564354de368669d5dee3910 --- /dev/null +++ b/paddle/phi/kernels/gpu/renorm_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/renorm_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/renorm_kernel_impl.h" + +PD_REGISTER_KERNEL(renorm, GPU, ALL_LAYOUT, phi::RenormKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/renorm_grad_kernel_impl.h b/paddle/phi/kernels/impl/renorm_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9ae5b30296bef175def14d2c33526f24ff3a1b89 --- /dev/null +++ b/paddle/phi/kernels/impl/renorm_grad_kernel_impl.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/impl/renorm_impl.h"
+#include "paddle/phi/kernels/renorm_grad_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RenormGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& dout,
+                      float p,
+                      int axis,
+                      float max_norm,
+                      DenseTensor* dx) {
+  int64_t numel = dout.numel();
+  const T* dout_data = dout.template data<T>();
+  const T* x_data = x.template data<T>();
+  auto input_dims = x.dims();
+  int dim = axis;
+  auto dimension_each = input_dims[dim];
+  dx->Resize(x.dims());
+  dev_ctx.template Alloc<T>(dx);
+  phi::funcs::RenormGradFunc(dev_ctx,
+                             x_data,
+                             dout_data,
+                             dx->data<T>(),
+                             p,
+                             dim,
+                             max_norm,
+                             dimension_each,
+                             input_dims,
+                             numel);
+}
+}  // namespace phi
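For reference, this is the math the forward and backward kernels below compute; the notation is reconstructed from the code, not taken from the PR. Per slice j along `axis`, with g the incoming gradient dL/dy:

```latex
S_j = \sum_i |x_{ij}|^p, \qquad
c_j = \min\!\left(1,\; \mathrm{max\_norm} \cdot S_j^{-1/p}\right), \qquad
y_{ij} = c_j \, x_{ij}

\frac{\partial L}{\partial x_{ij}}
  = c_j \, g_{ij}
  \;-\; \mathrm{max\_norm} \cdot S_j^{-1 - 1/p} \,
    |x_{ij}|^{p-1} \operatorname{sign}(x_{ij})
    \sum_k x_{kj} \, g_{kj}
    \quad \text{(clipped slices only)}
```

In the code, the sum over k is `weight_derivative` and the factor `-max_norm * S^(-1-1/p)` is `dim_power_sum`; for unclipped slices c_j = 1 and the second term vanishes.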
diff --git a/paddle/phi/kernels/impl/renorm_impl.h b/paddle/phi/kernels/impl/renorm_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..edd32473085807059e0925746a3e4694415f79a9
--- /dev/null
+++ b/paddle/phi/kernels/impl/renorm_impl.h
@@ -0,0 +1,362 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/primitive/functor_primitives.h"
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#else
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+#endif
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+void RenormFunc(const phi::CPUContext& ctx,
+                const T* x_data,
+                T* out_data,
+                float p,
+                int dim,
+                float max_norm,
+                int64_t dimension_each,
+                phi::DDim& input_dims,
+                int64_t numel) {
+  auto dim_size = input_dims.size();
+  int64_t dim_divisor = 1;
+  for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+
+  std::vector<T> dim_value(dimension_each,
+                           0);  // dim_value = (x1^p + x2^p + x3^p....)^(1/p)
+
+  int64_t index = 0, dim_index = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
+    index++;
+    if (index == dim_divisor) {
+      dim_index++;
+      if (dim_index == dimension_each) {
+        dim_index = 0;
+      }
+      index = 0;
+    }
+  }
+  for (int64_t i = 0; i < dimension_each; i++) {
+    dim_value[i] = std::pow(dim_value[i], 1.0 / p);
+    if (dim_value[i] > max_norm)
+      dim_value[i] = max_norm / dim_value[i];
+    else
+      dim_value[i] = 1.0;
+  }
+  index = dim_index = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    out_data[i] = dim_value[dim_index] < 1.0 ? dim_value[dim_index] * x_data[i]
+                                             : x_data[i];
+    index++;
+    if (index == dim_divisor) {
+      dim_index++;
+      if (dim_index == dimension_each) {
+        dim_index = 0;
+      }
+      index = 0;
+    }
+  }
+}
+
+template <typename T>
+void RenormGradFunc(const phi::CPUContext& ctx,
+                    const T* x_data,
+                    const T* dout_data,
+                    T* dx_data,
+                    float p,
+                    int dim,
+                    float max_norm,
+                    int64_t dimension_each,
+                    phi::DDim& input_dims,
+                    int64_t numel) {
+  auto dim_size = input_dims.size();
+  int64_t dim_divisor = 1;
+  for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+  std::vector<T> dim_value(dimension_each, 0),
+      dim_power_sum(dimension_each, 0),
+      weight_derivative(dimension_each, 0.0);
+  int64_t index = 0, dim_index = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
+    index++;
+    if (index == dim_divisor) {
+      dim_index++;
+      if (dim_index == dimension_each) {
+        dim_index = 0;
+      }
+      index = 0;
+    }
+  }
+  for (int64_t i = 0; i < dimension_each; i++) {
+    auto temp = std::pow(dim_value[i], 1.0 / p);
+    if (temp > max_norm) {
+      dim_power_sum[i] =
+          std::pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
+      dim_value[i] = max_norm / temp;
+    } else
+      dim_value[i] = 1.0;
+  }
+  index = dim_index = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    dx_data[i] = dim_value[dim_index] * dout_data[i];
+    weight_derivative[dim_index] += x_data[i] * dout_data[i];
+    index++;
+    if (index == dim_divisor) {
+      dim_index++;
+      if (dim_index == dimension_each) {
+        dim_index = 0;
+      }
+      index = 0;
+    }
+  }
+  index = dim_index = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    dx_data[i] += weight_derivative[dim_index] * dim_power_sum[dim_index] *
+                  std::pow(std::abs(x_data[i]), p - 1.0) *
+                  (x_data[i] >= 0 ? 1 : -1);
+    index++;
+    if (index == dim_divisor) {
+      dim_index++;
+      if (dim_index == dimension_each) {
+        dim_index = 0;
+      }
+      index = 0;
+    }
+  }
+}
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+__device__ __forceinline__ float inline_pow(float base, float exponent) {
+  return pow(base, exponent);
+}
+
+__device__ __forceinline__ double inline_pow(double base, double exponent) {
+  return pow(base, exponent);
+}
+
+__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
+__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
+
+template <typename Tx, typename Ty = Tx>
+struct UnsignedPowFunctor {
+  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
+    this->porder = porder;
+  }
+  HOSTDEVICE inline Ty operator()(const Tx x) const {
+    return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
+  }
+  float porder;
+};
+
+template <typename T>
+__global__ void RenormKernelFunc3(int64_t size,
+                                  T* dim_value,
+                                  float p,
+                                  float max_norm) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  if (i < size) {
+    T temp = pow(dim_value[i], (T)(1.0 / p));
+    dim_value[i] = 1.0;
+    if (temp > max_norm) dim_value[i] = max_norm / temp;
+  }
+}
+
+template <typename T>
+__global__ void RenormKernelFunc4(const T* x_data,
+                                  T* out_data,
+                                  int64_t size,
+                                  T* dim_value,
+                                  int64_t dimension_each,
+                                  int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  auto dim_index = i / dim_divisor % dimension_each;
+  if (i < size) {
+    if (dim_value[dim_index] < 1.0)
+      out_data[i] = dim_value[dim_index] * x_data[i];
+    else
+      out_data[i] = x_data[i];
+  }
+}
+
+template <typename T>
+__global__ void RenormElementwisePow(const T* x_data,
+                                     T* pow_value,
+                                     int64_t size,
+                                     float p) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  if (i < size) {
+    pow_value[i] = pow(abs(x_data[i]), (T)p);
+  }
+}
+
+template <typename T>
+__global__ void RenormGradKernelFunc1(const T* x_data,
+                                      const T* dout_data,
+                                      T* pow_value,
+                                      T* mul_value,
+                                      int64_t size,
+                                      int64_t dimension_each,
+                                      float p,
+                                      int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  auto dim_index = i / dim_divisor % dimension_each;
+  if (i < size) {
+    pow_value[i] = pow(abs(x_data[i]), (T)p);
+    mul_value[i] = x_data[i] * dout_data[i];
+  }
+}
+
+template <typename T>
+__global__ void RenormGradKernelFunc2(const T* x_data,
+                                      const T* dout_data,
+                                      T* dx_data,
+                                      int64_t size,
+                                      T* dim_value,
+                                      T* dim_power_sum,
+                                      T* weight_derivative,
+                                      int64_t dimension_each,
+                                      float p,
+                                      float max_norm,
+                                      int64_t dim_divisor) {
+  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
+  auto dim_index = i / dim_divisor % dimension_each;
+  if (i < dimension_each) {
+    dim_power_sum[i] = 0;
+    auto temp = pow(dim_value[i], (T)(1.0 / p));
+    if (temp > max_norm) {
+      dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
+      dim_value[i] = max_norm / temp;
+    } else {
+      dim_value[i] = 1.0;
+    }
+  }
+  __syncthreads();
+  if (i < size) {
+    dx_data[i] = dim_value[dim_index] * dout_data[i];
+    dx_data[i] = dx_data[i] + weight_derivative[dim_index] *
+                                  dim_power_sum[dim_index] *
+                                  pow(abs(x_data[i]), T(p - 1.0)) *
+                                  (x_data[i] >= 0 ? 1 : -1);
+  }
+}
+
+template <typename T>
+void RenormFunc(const phi::GPUContext& ctx,
+                const T* x_data,
+                T* out_data,
+                float p,
+                int dim,
+                float max_norm,
+                int64_t dimension_each,
+                phi::DDim& input_dims,
+                int64_t numel) {
+  auto dim_size = input_dims.size();
+  DenseTensor pow_value, dim_value;
+  int64_t dim_divisor = 1, pre_mul = 1;
+  for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+  for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+  pow_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor}));
+  dim_value.Resize(phi::make_ddim({dimension_each}));
+  T* pow_value_data = ctx.template Alloc<T>(&pow_value);
+  T* dim_value_data = ctx.template Alloc<T>(&dim_value);
+  auto stream = ctx.stream();
+  int block = std::min(numel, static_cast<int64_t>(256));
+  int grid = (numel + block - 1) / block;
+  RenormElementwisePow<T>
+      <<<grid, block, 0, stream>>>(x_data, pow_value_data, numel, p);
+  int block2 = std::min(dimension_each, static_cast<int64_t>(256));
+  int grid2 = (dimension_each + block2 - 1) / block2;
+  std::vector<int> reduce_axis = {0, 2};
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      ctx, pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis);
+  RenormKernelFunc3<T>
+      <<<grid2, block2, 0, stream>>>(numel, dim_value_data, p, max_norm);
+  RenormKernelFunc4<T><<<grid, block, 0, stream>>>(
+      x_data, out_data, numel, dim_value_data, dimension_each, dim_divisor);
+}
+
+template <typename T>
+void RenormGradFunc(const phi::GPUContext& ctx,
+                    const T* x_data,
+                    const T* dout_data,
+                    T* dx_data,
+                    float p,
+                    int dim,
+                    float max_norm,
+                    int64_t dimension_each,
+                    phi::DDim& input_dims,
+                    int64_t numel) {
+  auto dim_size = input_dims.size();
+  int64_t dim_divisor = 1, pre_mul = 1;
+  for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+  for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+  DenseTensor pow_value, mul_value, dim_value, dim_power_sum,
+      weight_derivative;
+  pow_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor}));
+  mul_value.Resize(phi::make_ddim({pre_mul, dimension_each, dim_divisor}));
+  dim_value.Resize(phi::make_ddim({dimension_each}));
+  dim_power_sum.Resize(phi::make_ddim({dimension_each}));
+  weight_derivative.Resize(phi::make_ddim({dimension_each}));
+  auto stream = ctx.stream();
+  int block = std::min(numel, static_cast<int64_t>(256));
+  int grid = (numel + block - 1) / block;
+  T* pow_value_data = ctx.template Alloc<T>(&pow_value);
+  T* mul_value_data = ctx.template Alloc<T>(&mul_value);
+  T* dim_value_data = ctx.template Alloc<T>(&dim_value);
+  T* dim_power_sum_data = ctx.template Alloc<T>(&dim_power_sum);
+  T* weight_derivative_data = ctx.template Alloc<T>(&weight_derivative);
+  RenormGradKernelFunc1<T><<<grid, block, 0, stream>>>(x_data,
+                                                       dout_data,
+                                                       pow_value_data,
+                                                       mul_value_data,
+                                                       numel,
+                                                       dimension_each,
+                                                       p,
+                                                       dim_divisor);
+  std::vector<int> reduce_axis = {0, 2};
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      ctx, pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis);
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      ctx,
+      mul_value,
+      &weight_derivative,
+      kps::IdentityFunctor<T>(),
+      reduce_axis);
+  RenormGradKernelFunc2<T><<<grid, block, 0, stream>>>(x_data,
+                                                       dout_data,
+                                                       dx_data,
+                                                       numel,
+                                                       dim_value_data,
+                                                       dim_power_sum_data,
+                                                       weight_derivative_data,
+                                                       dimension_each,
+                                                       p,
+                                                       max_norm,
+                                                       dim_divisor);
+}
+#endif
+
+}  // namespace funcs
+}  // namespace phi
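As a framework-free cross-check of the forward pass that `RenormFunc` implements above (for both the CPU loop and the GPU pow/reduce/scale pipeline), a short numpy sketch; the function name and layout here are illustrative, not part of the PR:

```python
import numpy as np

def renorm_reference(x, p, axis, max_norm):
    # p-norm of every slice taken along `axis` (reduce over all other axes)
    norms = np.power(np.abs(x), p).sum(
        axis=tuple(i for i in range(x.ndim) if i != axis)) ** (1.0 / p)
    # per-slice scale factor min(1, max_norm / norm)
    scale = np.where(norms > max_norm, max_norm / norms, 1.0)
    # broadcast the factor back over the slice and rescale
    shape = [1] * x.ndim
    shape[axis] = -1
    return x * scale.reshape(shape)

x = np.array([[3.0, 0.0, 4.0],
              [0.0, 2.0, 0.0]])
print(renorm_reference(x, p=2.0, axis=1, max_norm=1.0))
# [[1. 0. 1.]
#  [0. 1. 0.]]  -- every column now has L2 norm <= 1
```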
diff --git a/paddle/phi/kernels/impl/renorm_kernel_impl.h b/paddle/phi/kernels/impl/renorm_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..93aaa4a711de0f6a0200dce3cfc677b8b5e6145c
--- /dev/null
+++ b/paddle/phi/kernels/impl/renorm_kernel_impl.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/impl/renorm_impl.h"
+#include "paddle/phi/kernels/renorm_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RenormKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  float p,
+                  int axis,
+                  float max_norm,
+                  DenseTensor* out) {
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+  auto x_ptr = x.template data<T>();
+  auto numel = x.numel();
+  int dim = axis;
+  auto input_dims = x.dims();
+  auto dimension_each = input_dims[dim];
+
+  phi::funcs::RenormFunc(dev_ctx,
+                         x_ptr,
+                         out->data<T>(),
+                         p,
+                         axis,
+                         max_norm,
+                         dimension_each,
+                         input_dims,
+                         numel);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/renorm_grad_kernel.h b/paddle/phi/kernels/renorm_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..508cbe371b33c598eaa4c3e3c47a1a9abbad031c
--- /dev/null
+++ b/paddle/phi/kernels/renorm_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RenormGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& dout,
+                      float p,
+                      int axis,
+                      float max_norm,
+                      DenseTensor* dx);
+}
diff --git a/paddle/phi/kernels/renorm_kernel.h b/paddle/phi/kernels/renorm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..e5db2a34a71f280a64a86330bee4fc28c95eda86
--- /dev/null
+++ b/paddle/phi/kernels/renorm_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RenormKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  float p,
+                  int axis,
+                  float max_norm,
+                  DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/renorm_sig.cc b/paddle/phi/ops/compat/renorm_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0c5198dff37b045c954ecb4b3dae5edcd2bc697f
--- /dev/null
+++ b/paddle/phi/ops/compat/renorm_sig.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature RenormOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  VLOG(3) << "in renorm arguments mapping";
+  return KernelSignature("renorm", {"X"}, {"p", "axis", "max_norm"}, {"Out"});
+}
+
+KernelSignature RenormGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  VLOG(3) << "in renorm grad arguments mapping";
+  return KernelSignature(
+      "renorm_grad", {"X", "Out@GRAD"}, {"p", "axis", "max_norm"}, {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(renorm, phi::RenormOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(renorm_grad, phi::RenormGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index a26a17d91fd03fdfef4acd0b20227808883f281a..73139c7fbbc24bd6e57f0a7b90041a05d40e1a8a 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1752,6 +1752,10 @@ py_test_modules(
 set_tests_properties(test_add_reader_dependency_for_interpretercore
                      PROPERTIES TIMEOUT 120)
 
+py_test_modules(test_renorm_op_without_eager MODULES test_renorm_op ENVS
+                FLAGS_enable_eager_mode=0)
+
+set_tests_properties(test_renorm_op_without_eager PROPERTIES TIMEOUT 120)
 py_test_modules(
   test_eager_deletion_padding_rnn_for_interpretercore MODULES
   test_eager_deletion_padding_rnn ENVS FLAGS_CONVERT_GRAPH_TO_PROGRAM=true)
diff --git a/python/paddle/fluid/tests/unittests/test_renorm_op.py b/python/paddle/fluid/tests/unittests/test_renorm_op.py
index e266800319db1c37bc305ef94a8f547df687b380..d25b0e9d2e538c50f427761d8bb239877fa59d79 100644
--- a/python/paddle/fluid/tests/unittests/test_renorm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_renorm_op.py
@@ -71,7 +71,7 @@ class TestRenormAPI(unittest.TestCase):
                                       [[0, 0.01045918, 0.00683333],
                                        [0, 0.01394558, 0.00683333]]])
             self.assertTrue(np.allclose(expected_grad, np.array(x.grad)))
-            #test exception:
+            # #test exception:
         with fluid.dygraph.guard():
             input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
             x = paddle.to_tensor(input, stop_gradient=False)
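The tests above exercise the Python API that dispatches into these kernels. A minimal usage sketch, assuming a Paddle build that includes this change:

```python
import paddle

x = paddle.to_tensor([[3.0, 0.0, 4.0],
                      [0.0, 2.0, 0.0]])
y = paddle.renorm(x, p=2.0, axis=1, max_norm=1.0)
print(y.numpy())  # every column now has L2 norm <= 1.0
```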
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index fe9d604251fc395b72317043f914a26f40813576..6d365622746e3f573c45ceca2904c5791cfeac46 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1799,7 +1799,10 @@ def renorm(x, p, axis, max_norm):
         if not axis >= -1 * len(input_shape):
             raise ValueError("the axis:{} should not be less than -1 * length of input_shape:{}".format(axis,-1 * len(input_shape)))
         axis = axis + len(input_shape)
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_renorm(x, p, axis, max_norm)
+        return out
+    elif _in_legacy_dygraph():
         out = _C_ops.renorm(x, 'p',p, 'axis',axis, 'max_norm', max_norm)
         return out
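Finally, a numeric sanity check of the backward formula that `RenormGradFunc` implements, specialized to p = 2. This is pure numpy and purely illustrative; all names are mine, and it simply compares the analytic gradient against central finite differences:

```python
import numpy as np

p, max_norm, axis = 2.0, 1.0, 1
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
g = rng.normal(size=(4, 3))  # incoming gradient dL/dy

# Analytic gradient, mirroring the kernel's intermediates.
S = np.power(np.abs(x), p).sum(axis=0)            # per-slice sum |x|^p
norm = S ** (1.0 / p)
scale = np.where(norm > max_norm, max_norm / norm, 1.0)
# d(scale)/dS for clipped slices, 0 otherwise (the kernel's dim_power_sum)
dscale = np.where(norm > max_norm, -max_norm * S ** (-1.0 - 1.0 / p), 0.0)
w = (x * g).sum(axis=0)                           # weight_derivative
dx = scale * g + w * dscale * np.abs(x) ** (p - 1.0) * np.sign(x)

def fwd(t):
    n = np.sqrt((t * t).sum(axis=0))              # p = 2 norm per column
    s = np.where(n > max_norm, max_norm / n, 1.0)
    return t * s

# Central finite differences of L = sum(fwd(x) * g).
eps = 1e-6
num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    d = np.zeros_like(x)
    d[idx] = eps
    num[idx] = ((fwd(x + d) - fwd(x - d)) * g).sum() / (2 * eps)

print(np.max(np.abs(dx - num)))  # should be tiny, ~1e-8
```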