From f33763e30b0f0cd9aa3ea5fb59e1e292a1cde2e4 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Thu, 30 Jun 2022 21:29:30 +0800
Subject: [PATCH] Move apis(digamma, dist, dot) from legacy_api.yaml to api.yaml (#43956)

* move standard apis to api.yaml

* revert erfinv

* delete dot_op.h

* fix dot

* rerun ci
---
 paddle/fluid/operators/digamma_op.cc          |  98 ------------
 paddle/fluid/operators/digamma_op.h           |  18 ---
 paddle/fluid/operators/dist_op.cc             | 140 ------------------
 paddle/fluid/operators/dot_op.cc              | 139 -----------------
 paddle/fluid/operators/dot_op.cu              |  36 -----
 paddle/fluid/operators/dot_op.h               |  83 -----------
 paddle/fluid/operators/matmul_v2_op.h         |   1 -
 paddle/phi/kernels/digamma_kernel.h           |   7 +
 paddle/phi/kernels/dist_kernel.h              |  31 ++++
 paddle/phi/kernels/erfinv_kernel.h            |  12 ++
 paddle/phi/ops/compat/digamma_sig.cc          |  26 ----
 paddle/phi/ops/compat/dist_sig.cc             |  26 ----
 paddle/phi/ops/compat/dot_sig.cc              |  26 ----
 .../fluid/tests/unittests/test_dot_op.py      |  36 +++--
 python/paddle/tensor/linalg.py                |   9 +-
 python/paddle/utils/code_gen/api.yaml         |  28 ++++
 python/paddle/utils/code_gen/backward.yaml    |  31 ++++
 python/paddle/utils/code_gen/legacy_api.yaml  |  30 +---
 .../utils/code_gen/legacy_backward.yaml       |  20 ---
 19 files changed, 141 insertions(+), 656 deletions(-)
 delete mode 100644 paddle/fluid/operators/digamma_op.cc
 delete mode 100644 paddle/fluid/operators/digamma_op.h
 delete mode 100644 paddle/fluid/operators/dist_op.cc
 delete mode 100644 paddle/fluid/operators/dot_op.cc
 delete mode 100644 paddle/fluid/operators/dot_op.cu
 delete mode 100644 paddle/fluid/operators/dot_op.h
 delete mode 100644 paddle/phi/ops/compat/digamma_sig.cc
 delete mode 100644 paddle/phi/ops/compat/dist_sig.cc
 delete mode 100644 paddle/phi/ops/compat/dot_sig.cc

diff --git a/paddle/fluid/operators/digamma_op.cc b/paddle/fluid/operators/digamma_op.cc
deleted file mode 100644
index 5f17c3b3da..0000000000
--- a/paddle/fluid/operators/digamma_op.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#include "paddle/fluid/operators/digamma_op.h" - -namespace paddle { -namespace operators { - -class DigammaOp : public framework::OperatorWithKernel { - public: - DigammaOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Digamma"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Digamma"); - - auto in_dims = ctx->GetInputDim("X"); - - ctx->SetOutputDim("Out", in_dims); - ctx->ShareLoD("X", "Out"); - } -}; - -class DigammaOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of digamma operator."); - AddOutput("Out", "(Tensor), The output tensor of digamma operator."); - AddComment(R"DOC( -Digamma Operator. - -This operator is used to perform elementwise digamma for input $X$. -$$out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) }$$ - -)DOC"); - } -}; - -class DigammaGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@Grad", - "DigammaGrad"); - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DigammaGrad"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X@Grad", - "DigammaGrad"); - - auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); - ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); - } -}; - -template -class DigammaGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr retv) const override { - retv->SetType("digamma_grad"); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetInput("X", this->Input("X")); - retv->SetAttrMap(this->Attrs()); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(digamma, - ops::DigammaOp, - ops::DigammaOpMaker, - ops::DigammaGradOpMaker, - ops::DigammaGradOpMaker); -REGISTER_OPERATOR(digamma_grad, ops::DigammaGradOp); diff --git a/paddle/fluid/operators/digamma_op.h b/paddle/fluid/operators/digamma_op.h deleted file mode 100644 index 85f9094e6a..0000000000 --- a/paddle/fluid/operators/digamma_op.h +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" diff --git a/paddle/fluid/operators/dist_op.cc b/paddle/fluid/operators/dist_op.cc deleted file mode 100644 index 49f8fa75aa..0000000000 --- a/paddle/fluid/operators/dist_op.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class DistOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dist"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Dist"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Dist"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - PADDLE_ENFORCE_NE(phi::product(x_dims), - 0, - platform::errors::InvalidArgument( - "The Input(X) has not been initialized properly. The " - "shape of Input(X) = [%s].", - x_dims)); - PADDLE_ENFORCE_NE(phi::product(y_dims), - 0, - platform::errors::InvalidArgument( - "The Input(Y) has not been initialized properly. The " - "shape of Input(Y) = [%s].", - y_dims)); - ctx->SetOutputDim("Out", {1}); - } -}; - -class DistOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The input Tensor of Dist Op."); - AddInput("Y", "The Right-hand-side input Tensor of Dist Op."); - AddOutput("Out", - "The output of Dist Op, " - "which is the p-norm of (X - Y)"); - AddAttr("p", "the norm to be computed.").SetDefault(2.0f); - AddComment(R"DOC( -Dist Operator. -Given two tensors X and Y, compute Lp-norm of (X-Y). It is not a norm in a strict sense, -only as a measure of distance. The shapes of X and Y must be broadcastable. Where, Z = X - Y, - -When p = 0, defining $0^0 = 0$, the zero-norm of Z is simply the number of non-zero elements of z. -$$ -||Z||_{0} = \lim_{p \rightarrow 0} \sum_{i=1}^{m} |z_i|^p -$$ - -When p = inf, the inf-norm of Z is the maximum element of Z. -$$ -||Z||_\infty=\max_i |z_i| -$$ - -When p = -inf, the negative-inf-norm of Z is the minimum element of Z. 
-$$ -||Z||_{-\infty}=\min_i |z_i| -$$ - -Otherwise, the p-norm of Z follows the formula, -$$ -||Z||_{p} = (\sum_{i=i}^{m} |z_i|^p)^{1/p} -$$ - )DOC"); - } -}; - -class DistOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - if (ctx->HasOutput(framework::GradVarName("Y"))) { - ctx->SetOutputDim(framework::GradVarName("Y"), y_dims); - } - } -}; - -template -class DistGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType(this->ForwardOpType() + "_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput("Y", this->Input("Y")); - op->SetInput("Out", this->Output("Out")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(dist, - DistInferShapeFunctor, - PD_INFER_META(phi::DistInferMeta)); - -REGISTER_OPERATOR(dist, - ops::DistOp, - ops::DistOpMaker, - ops::DistGradOpMaker, - ops::DistGradOpMaker, - DistInferShapeFunctor); -REGISTER_OPERATOR(dist_grad, ops::DistOpGrad); diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc deleted file mode 100644 index 880186b84c..0000000000 --- a/paddle/fluid/operators/dot_op.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/dot_op.h" - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class DotOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class DotOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() final { - AddInput("X", "(Tensor) The first input tensor. "); - AddInput("Y", "(Tensor) The second input tensor. 
"); - AddOutput("Out", "(Tensor) The result tensor."); - AddComment(""); - } -}; - -class DotGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - true, - ctx->HasInput("X"), - platform::errors::PreconditionNotMet("Input(X) should not be null.")); - PADDLE_ENFORCE_EQ( - true, - ctx->HasInput("Y"), - platform::errors::PreconditionNotMet("Input(Y) should not be null.")); - PADDLE_ENFORCE_EQ(true, - ctx->HasInput(framework::GradVarName("Out")), - platform::errors::PreconditionNotMet( - "Input(Out@GRAD) should not be null.")); - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - if (ctx->HasOutput(x_grad_name)) { - ctx->ShareDim("X", /*->*/ x_grad_name); - ctx->ShareLoD("X", /*->*/ x_grad_name); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->ShareDim("Y", /*->*/ y_grad_name); - ctx->ShareLoD("Y", /*->*/ y_grad_name); - } - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); - } -}; - -template -class DotOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("dot_grad"); - - op->SetInput("X", this->Input("X")); - op->SetInput("Y", this->Input("Y")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetAttrMap(this->Attrs()); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(dot, - DotInferShapeFunctor, - PD_INFER_META(phi::DotInferMeta)); - -REGISTER_OPERATOR(dot, - ops::DotOp, - ops::DotOpMaker, - ops::DotOpGradMaker, - ops::DotOpGradMaker, - DotInferShapeFunctor); - -REGISTER_OPERATOR(dot_grad, ops::DotGradOp); - -REGISTER_OP_CPU_KERNEL( - dot, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel>, - ops::DotKernel>); -REGISTER_OP_CPU_KERNEL( - dot_grad, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel>, - ops::DotGradKernel>); diff --git a/paddle/fluid/operators/dot_op.cu b/paddle/fluid/operators/dot_op.cu deleted file mode 100644 index 362a6a80f9..0000000000 --- a/paddle/fluid/operators/dot_op.cu +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/operators/dot_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - dot, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel>, - ops::DotKernel>); -REGISTER_OP_CUDA_KERNEL(dot_grad, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel>, - ops::DotGradKernel>); diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h deleted file mode 100644 index 0f4c80c4c9..0000000000 --- a/paddle/fluid/operators/dot_op.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -// only can include the headers in paddle/phi/api dirs -#include "paddle/phi/api/lib/utils/tensor_utils.h" -#include "paddle/phi/kernels/dot_grad_kernel.h" -#include "paddle/phi/kernels/dot_kernel.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -// See Note [ Why still keep the original kernel implementation? ] -template -class DotKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - auto& dev_ctx = ctx.device_context(); - out->mutable_data(x->place()); - - // call new kernel - phi::DotKernel< - T, - typename paddle::framework::ConvertToPhiContext::TYPE>( - static_cast::TYPE&>(dev_ctx), - *x, - *y, - out); - } -}; - -template -class DotGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* tensor_x = ctx.Input("X"); - auto* tensor_y = ctx.Input("Y"); - auto* tensor_dout = ctx.Input(framework::GradVarName("Out")); - auto* tensor_dx = ctx.Output(framework::GradVarName("X")); - auto* tensor_dy = ctx.Output(framework::GradVarName("Y")); - - if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); - if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.device_context(); - - // call new kernel - phi::DotGradKernel( - static_cast::TYPE&>(dev_ctx), - *tensor_x, - *tensor_y, - *tensor_dout, - tensor_dx, - tensor_dy); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 36267b9f9a..8e436dd6af 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -21,7 +21,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/complex_functors.h" diff --git a/paddle/phi/kernels/digamma_kernel.h b/paddle/phi/kernels/digamma_kernel.h index 3cf1eae67c..b45b7070d2 100644 --- a/paddle/phi/kernels/digamma_kernel.h +++ b/paddle/phi/kernels/digamma_kernel.h @@ -18,6 +18,13 @@ namespace phi { +/** + * @brief This kernrel is used to perform elementwise digamma for x. + * $$out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) }$$ + * @param ctx device context + * @param x the input tensor of digamma + * @param out the output tensor of digamma + */ template void DigammaKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out); diff --git a/paddle/phi/kernels/dist_kernel.h b/paddle/phi/kernels/dist_kernel.h index 6cb3d6e0e8..8c1f6674aa 100644 --- a/paddle/phi/kernels/dist_kernel.h +++ b/paddle/phi/kernels/dist_kernel.h @@ -18,6 +18,37 @@ namespace phi { +/** + * @brief Given two tensors x and y, compute Lp-norm of (x-y). + * It is not a norm in a strict sense, only as a measure of distance. + * The shapes of x and y must be broadcastable. Where, z = x - y, + * + * When p = 0, defining $0^0 = 0$, the zero-norm of z is simply + * the number of non-zero elements of z. + * $$ + * ||z||_{0} = \lim_{p \rightarrow 0} \sum_{i=1}^{m} |z_i|^p + * $$ + * + * When p = inf, the inf-norm of z is the maximum element of z. + * $$ + * ||z||_\infty=\max_i |z_i| + * $$ + * + * When p = -inf, the negative-inf-norm of z is the minimum element of z. + * $$ + * ||z||_{-\infty}=\min_i |z_i| + * $$ + * + * Otherwise, the p-norm of z follows the formula, + * $$ + * ||z||_{p} = (\sum_{i=i}^{m} |z_i|^p)^{1/p} + * $$ + * @param ctx device context + * @param x the input Tensor of Dist + * @param y the Right-hand-side input Tensor of Dist + * @param p the norm to be computed + * @param out the output of Dist, which is the p-norm of (x - y) + */ template void DistKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/erfinv_kernel.h b/paddle/phi/kernels/erfinv_kernel.h index 8380a62971..3ddb1ecbdf 100644 --- a/paddle/phi/kernels/erfinv_kernel.h +++ b/paddle/phi/kernels/erfinv_kernel.h @@ -18,6 +18,18 @@ namespace phi { +/** + * @brief This kernel is used to compute inverse error function of x. + * + * The equation is: + * $$erfinv(x) = {ndtri({x \over 2} + 0.5)} \over {\sqrt{2}}$$ + * + * The input `x` can carry the LoD (Level of Details) information, + * or not. And the output shares the LoD information with `x` + * @param ctx device context + * @param x the input tensor of erfinv + * @param out the output tensor of erfinv + */ template void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out); diff --git a/paddle/phi/ops/compat/digamma_sig.cc b/paddle/phi/ops/compat/digamma_sig.cc deleted file mode 100644 index 6c14dd9bf1..0000000000 --- a/paddle/phi/ops/compat/digamma_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature DigammaGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(digamma_grad, phi::DigammaGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/dist_sig.cc b/paddle/phi/ops/compat/dist_sig.cc deleted file mode 100644 index cc702fefbc..0000000000 --- a/paddle/phi/ops/compat/dist_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(dist_grad, phi::DistGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/dot_sig.cc b/paddle/phi/ops/compat/dot_sig.cc deleted file mode 100644 index 2187a7eb4f..0000000000 --- a/paddle/phi/ops/compat/dot_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(dot_grad, phi::DotGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py index 536f8fd8d8..ffdc90dd98 100644 --- a/python/paddle/fluid/tests/unittests/test_dot_op.py +++ b/python/paddle/fluid/tests/unittests/test_dot_op.py @@ -27,6 +27,7 @@ class DotOp(OpTest): def setUp(self): self.op_type = "dot" + self.python_api = paddle.dot self.init_dtype() self.init_input_output() @@ -38,34 +39,43 @@ class DotOp(OpTest): self.attrs = {} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): if core.is_compiled_with_rocm(): self.check_grad( ['X', 'Y'], 'Out', - user_defined_grads=[self.inputs['Y'], self.inputs['X']]) + user_defined_grads=[self.inputs['Y'], self.inputs['X']], + check_eager=True) else: - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ingore_x(self): if core.is_compiled_with_rocm(): self.check_grad(['Y'], 'Out', no_grad_set=set("X"), - user_defined_grads=[self.inputs['X']]) + user_defined_grads=[self.inputs['X']], + check_eager=True) else: - self.check_grad(['Y'], 'Out', no_grad_set=set("X")) + self.check_grad(['Y'], + 'Out', + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): if core.is_compiled_with_rocm(): self.check_grad(['X'], 'Out', no_grad_set=set('Y'), - user_defined_grads=[self.inputs['Y']]) + user_defined_grads=[self.inputs['Y']], + check_eager=True) else: - self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + self.check_grad(['X'], + 'Out', + no_grad_set=set('Y'), + check_eager=True) def init_input_output(self): self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype) @@ -137,6 +147,7 @@ class TestComplexDotOp(OpTest): def setUp(self): self.op_type = "dot" + self.python_api = paddle.dot self.init_base_dtype() self.init_input_output() self.init_grad_input_output() @@ -164,27 +175,30 @@ class TestComplexDotOp(OpTest): self.grad_y = self.grad_out * np.conj(self.x) def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): self.check_grad(['X', 'Y'], 'Out', user_defined_grads=[self.grad_x, self.grad_y], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) def test_check_grad_ingore_x(self): self.check_grad(['Y'], 'Out', no_grad_set=set("X"), user_defined_grads=[self.grad_y], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) def test_check_grad_ingore_y(self): self.check_grad(['X'], 'Out', no_grad_set=set('Y'), user_defined_grads=[self.grad_x], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) class TestComplexDotOp2D(OpTest): diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c704a1b52d..95eaee2cc0 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1017,11 +1017,12 @@ def dot(x, y, name=None): print(z) """ + if in_dygraph_mode(): + return _C_ops.final_state_dot(x, y) + if _in_legacy_dygraph(): + return _C_ops.dot(x, y) + op_type = 'dot' - # 
skip var type check in dygraph mode to improve efficiency - if paddle.in_dynamic_mode(): - op = getattr(_C_ops, op_type) - return op(x, y) assert x is not None, 'x cannot be None in {}'.format(op_type) assert y is not None, 'y cannot be None in {}'.format(op_type) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 500ea7b7ad..0f86c93d93 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -52,6 +52,34 @@ func : diagonal backward : diagonal_grad +- api : digamma + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : digamma + backward : digamma_grad + +- api : dist + args : (Tensor x, Tensor y, float p = 2.0) + output : Tensor + infer_meta : + func : DistInferMeta + kernel : + func : dist + backward : dist_grad + +- api : dot + args : (Tensor x, Tensor y) + output : Tensor + infer_meta : + func : DotInferMeta + kernel : + func : dot + data_type : x + backward : dot_grad + - api : erf args : (Tensor x) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 32c6e2c4b6..32906ce382 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -51,6 +51,37 @@ data_type : out_grad no_need_buffer : x +- backward_api : digamma_grad + forward : digamma (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : digamma_grad + +- backward_api : dist_grad + forward : dist (Tensor x, Tensor y, float p) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, float p) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : dist_grad + +- backward_api : dot_grad + forward : dot (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : dot_grad + data_type : out_grad + - backward_api : erf_grad forward : erf (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/python/paddle/utils/code_gen/legacy_api.yaml b/python/paddle/utils/code_gen/legacy_api.yaml index 8d20833c65..c307fc7a19 100644 --- a/python/paddle/utils/code_gen/legacy_api.yaml +++ b/python/paddle/utils/code_gen/legacy_api.yaml @@ -497,24 +497,6 @@ kernel : func : diag -- api : digamma - args : (Tensor x) - output : Tensor - infer_meta : - func : UnchangedInferMeta - kernel : - func : digamma - backward : digamma_grad - -- api : dist - args : (Tensor x, Tensor y, float p) - output : Tensor - infer_meta : - func : DistInferMeta - kernel : - func : dist - backward : dist_grad - - api : divide args : (Tensor x, Tensor y) output : Tensor @@ -524,14 +506,6 @@ func : divide backward : divide_grad -- api : dot - args : (Tensor x, Tensor y) - output : Tensor - infer_meta : - func : DotInferMeta - kernel : - func : dot - - api : dropout args : (Tensor x, Tensor seed_tensor, float p, bool is_test, str mode, int seed, bool fix_seed) output : Tensor(out), Tensor(mask) @@ -629,14 +603,14 @@ kernel : func : equal_all -# erfinv - api : erfinv args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : erfinv + inplace : (x -> out) backward : erfinv_grad # exp diff --git 
a/python/paddle/utils/code_gen/legacy_backward.yaml b/python/paddle/utils/code_gen/legacy_backward.yaml
index 16d58fde77..a4589120cc 100644
--- a/python/paddle/utils/code_gen/legacy_backward.yaml
+++ b/python/paddle/utils/code_gen/legacy_backward.yaml
@@ -498,26 +498,6 @@
   kernel :
     func : determinant_grad
 
-- backward_api : digamma_grad
-  forward : digamma (Tensor x) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : digamma_grad
-
-- backward_api : dist_grad
-  forward : dist (Tensor x, Tensor y, float p) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, float p)
-  output : Tensor(x_grad), Tensor(y_grad)
-  infer_meta :
-    func : GeneralBinaryGradInferMeta
-    param : [x, y]
-  kernel :
-    func : dist_grad
-
 - backward_api : divide_double_grad
   forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
   args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
-- 
GitLab
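
A quick usage sketch of the three migrated APIs, for reviewers. It is illustrative only and assumes a paddle build that already contains this patch; the sample tensors and the comments on expected results are not part of the change itself.

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([4.0, 5.0, 6.0])

# dot: inner product of the two 1-D tensors (1*4 + 2*5 + 3*6 = 32)
print(paddle.dot(x, y))

# dist: p-norm of (x - y); p now defaults to 2.0 in the new api.yaml entry
print(paddle.dist(x, y, p=2))

# digamma: elementwise Psi(x) = Gamma'(x) / Gamma(x)
print(paddle.digamma(x))

In dynamic graph mode these calls now dispatch to the generated final-state ops (e.g. _C_ops.final_state_dot in python/paddle/tensor/linalg.py); the public Python signatures and results are unchanged.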