From 39f0eb2cf40abe6d452b1f69563598970dc49c32 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Thu, 5 Jan 2023 19:01:05 +0800 Subject: [PATCH] Generate the static graph code of ops (#49413) * generate the static graph code of ops * modify the isclose comment * modify the clip comment in nn.py * reset nn.py --- paddle/fluid/operators/activation_op.cc | 23 --- paddle/fluid/operators/allclose_op.cc | 136 ---------------- paddle/fluid/operators/clip_op.cc | 174 --------------------- paddle/fluid/operators/fill_diagonal_op.cc | 126 --------------- paddle/phi/api/yaml/backward.yaml | 43 +++++ paddle/phi/api/yaml/legacy_backward.yaml | 42 ----- paddle/phi/api/yaml/legacy_ops.yaml | 40 ----- paddle/phi/api/yaml/op_compat.yaml | 40 ++++- paddle/phi/api/yaml/op_version.yaml | 38 +++++ paddle/phi/api/yaml/ops.yaml | 43 +++++ paddle/phi/ops/compat/activation_sig.cc | 3 - paddle/phi/ops/compat/allclose_sig.cc | 49 ------ paddle/phi/ops/compat/clip_sig.cc | 80 ---------- paddle/phi/ops/compat/fill_diagonal_sig.cc | 37 ----- python/paddle/tensor/logic.py | 19 ++- 15 files changed, 176 insertions(+), 717 deletions(-) delete mode 100644 paddle/fluid/operators/allclose_op.cc delete mode 100644 paddle/fluid/operators/clip_op.cc delete mode 100644 paddle/fluid/operators/fill_diagonal_op.cc delete mode 100644 paddle/phi/ops/compat/allclose_sig.cc delete mode 100644 paddle/phi/ops/compat/clip_sig.cc delete mode 100644 paddle/phi/ops/compat/fill_diagonal_sig.cc diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 649382ffc9..9d895edc96 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -140,28 +140,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel { } }; -class BReluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "The input is a multi-dimensional Tensor. The data type is " - "float32, float64."); - AddOutput("Out", - "The output is a multi-dimensional Tensor which has same " - "dimension and data type as the ``X``."); - AddAttr("t_min", "The min marginal value of BRelu") - .SetDefault(static_cast(0)); - AddAttr("t_max", "The max marginal value of BRelu") - .SetDefault(static_cast(24)); - AddComment(R"DOC( -BRelu Activation Operator. - -$$out = \min(\max(x, t_{min}), t_{max})$$ - -)DOC"); - } -}; - class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -595,7 +573,6 @@ namespace plat = paddle::platform; FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); -REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor); REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor); REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor); REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor); diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc deleted file mode 100644 index ab876921f9..0000000000 --- a/paddle/fluid/operators/allclose_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", - "The input tensor, it's data type should be float32, float64."); - AddInput("Other", - "The input tensor, it's data type should be float32, float64."); - AddInput("Rtol", "The relative tolerance.").AsDispensable(); - AddInput("Atol", "The absolute tolerance.").AsDispensable(); - AddOutput("Out", "The output tensor, it's data type is bool."); - AddAttr("rtol", - "The relative tolerance. Default: :math:`1e-5` .") - .SetDefault("1e-5"); - AddAttr("atol", - "The absolute tolerance. Default: :math:`1e-8` .") - .SetDefault("1e-8"); - AddAttr("equal_nan", - "If :math:`True` , then two :math:`NaNs` will be " - "compared as equal. Default: :math:`False` .") - .SetDefault(false); - - AddComment(R"DOC( -This operator checks if all :math:`x` and :math:`y` satisfy the condition: - -.. math:: - \left| x - y \right| \leq atol + rtol \times \left| y \right| - -elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this -operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if -two tensors are elementwise equal within a tolerance. -)DOC"); - } -}; - -class AllcloseOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context().GetPlace()); - } -}; - -class AllcloseOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext* ctx) const override { - ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CPU = phi::CPUContext; - -DECLARE_INFER_SHAPE_FUNCTOR(allclose, - AllcloseInferShapeFunctor, - PD_INFER_META(phi::AllValueCompareInferMeta)); -REGISTER_OPERATOR( - allclose, - ops::AllcloseOp, - ops::AllcloseOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::AllcloseOpVarTypeInference, - AllcloseInferShapeFunctor); - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(allclose) - .AddCheckpoint( - R"ROC(Upgrade allclose, add two new inputs [Rtol] and [Atol].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput("Rtol", - "The added input 'Rtol' is not" - "dispensable.") - .NewInput("Atol", - "The added input 'Atol' is not" - "dispensable.")) - .AddCheckpoint( - R"ROC(Delete two float attributes [rtol] and [atol], - then add 2 string attributes [atol, rtol]. Don't be surprised. - This is because float cannot represent hight-precision - floating-point values, and our framework doesn't support - the use of double attributes. As a result, string instead - of double is used here to represent high-precision - floating-point values. - )ROC", - paddle::framework::compatible::OpVersionDesc() - .DeleteAttr("rtol", - "The attribute 'rtol' is deleted." - "The reason why it is deleted is that" - "attributes do not support a float64 value" - "and it is changed to a tensor.") - .DeleteAttr("atol", - "The attribute 'atol' is deleted." - "The reason why it is deleted is that" - "attributes do not support a float64 value" - "and it is changed to a tensor.") - .NewAttr("rtol", - "(string) The relative tolerance. Default: :math:`1e-5` .", - std::string("1e-5")) - .NewAttr("atol", - "(string) The absolute tolerance. Default: :math:`1e-8` .", - std::string("1e-8"))); diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc deleted file mode 100644 index 1fdc4e9a12..0000000000 --- a/paddle/fluid/operators/clip_op.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class ClipOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = - framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -template -class ClipOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "Tensor, the input of clip op, data type should be float32 or " - "float64."); - AddInput("Min", - "Tensor, the lower bound, data type should be float32 " - "or float64.") - .AsDispensable(); - AddInput("Max", - "Tensor, the upper bound, data type should be float32 " - "or float64.") - .AsDispensable(); - AddOutput( - "Out", - "Tensor, the clipped tensor, with the same shape and data type as " - "input(x)"); - AddAttr("min", "float number, the minimum value to clip by."); - AddAttr("max", "float number, the maximum value to clip by."); - AddComment(R"DOC( -Clip Operator. - -The clip operator limits the value of given input within an interval [min, max], -just as the following equation, - -$$ -Out = \MIN(\MAX(x, min), max) -$$ - -)DOC"); - } -}; - -class ClipOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip_grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "clip_grad"); - auto x_dims = ctx->GetInputDim("X"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -template -class ClipGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("clip_grad"); - op->SetInput("X", this->Input("X")); - if (this->HasInput("Min")) { - op->SetInput("Min", this->Input("Min")); - } - if (this->HasInput("Max")) { - op->SetInput("Max", this->Input("Max")); - } - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_INPLACE_OP_INFERER(ClipInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")}); - -template -class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("clip_grad"); - op->SetInput("X", this->Input("X")); - if (this->HasInput("Min")) { - op->SetInput("Min", this->Input("Min")); - } - if (this->HasInput("Max")) { - op->SetInput("Max", this->Input("Max")); - } - op->SetInput(framework::GradVarName("Out"), - this->OutputGrad(framework::GradVarName("X"))); - op->SetOutput(framework::GradVarName("X"), - this->InputGrad(framework::GradVarName("Out"))); - op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(clip, - ClipInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); -REGISTER_OPERATOR(clip, - ops::ClipOp, - ops::ClipOpMaker, - ops::ClipGradOpMaker, - ops::ClipGradOpMaker, - ops::ClipInplaceInferer, - ClipInferShapeFunctor); -REGISTER_OPERATOR(clip_grad, - ops::ClipOpGrad, - ops::ClipGradInplaceInferer, - ops::ClipDoubleGradOpMaker, - ops::ClipDoubleGradOpMaker); - -REGISTER_OP_VERSION(clip).AddCheckpoint( - R"ROC( - Upgrade clip add a new input [Min])ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput("Min", - "Pass the mix, min value as input, not attribute. Min is " - "dispensable.") - .NewInput("Max", - "Pass the mix, min value as input, not attribute. Max is " - "dispensable.")); diff --git a/paddle/fluid/operators/fill_diagonal_op.cc b/paddle/fluid/operators/fill_diagonal_op.cc deleted file mode 100644 index 373a63b7ff..0000000000 --- a/paddle/fluid/operators/fill_diagonal_op.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class FillIDiagonalOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddComment(R"DOC(Fill replace operator - Fill the diagonal of an tensor with 'value'. - )DOC"); - AddInput("X", "(Tensor) The input tensor."); - AddOutput("Out", - "Tensor, the output tensor, with the same shape and data type " - "as input(x)"); - AddAttr( - "value", - "The float values of tensor, whose dim is one, and no need of grad") - .SetDefault(0); - AddAttr("wrap", - "the diagonal 'wrapped' after N columns for tall matrices") - .SetDefault(false); - AddAttr("offset", - "offset of diagonal, zero means no offset, positive means " - "offset to up-right corner; negtive means offset to " - "bottom-left corner") - .SetDefault(0); - } -}; - -class FillIDiagonalOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); - } -}; - -class FillIDiagonalOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto var_type = ctx->GetInputType("X", 0); - auto data_type = ctx->GetInputDataType("X", 0); - ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS); - ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS); - } -}; - -class FillIDiagonalGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - // Note: don't get data type from ctx.Input("Input"); - auto dtype = framework::TransToProtoVarType( - ctx.Input(framework::GradVarName("Out"))->type()); - return phi::KernelKey(dtype, ctx.GetPlace()); - } -}; - -template -class FillIDiagonalGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("fill_diagonal_grad"); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_INPLACE_OP_INFERER(FillIDiagonalOpInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")}); - -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal, - FillDiagonalShapeFunctor, - PD_INFER_META(phi::FillDiagonalInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal_grad, - FillDiagonalGradShapeFunctor, - PD_INFER_META(phi::FillDiagonalGradInferMeta)); - -REGISTER_OPERATOR(fill_diagonal, - ops::FillIDiagonalOp, - ops::FillIDiagonalGradOpMaker, - ops::FillIDiagonalGradOpMaker, - ops::FillIDiagonalOpMaker, - ops::FillIDiagonalOpInplaceInferer, - ops::FillIDiagonalOpVarTypeInference, - FillDiagonalShapeFunctor); - -REGISTER_OPERATOR(fill_diagonal_grad, - ops::FillIDiagonalGradOp, - ops::FillIDiagonalGradOpInplaceInferer, - FillDiagonalGradShapeFunctor); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 0eee2b6a2b..11cac165d9 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -182,6 +182,29 @@ kernel : func : cholesky_solve_grad +- backward_op : clip_double_grad + forward : clip_grad (Tensor x, Tensor grad_out, Scalar min = 0., Scalar max = 0.) -> Tensor(grad_x) + args : (Tensor x, Tensor grad_x_grad, Scalar min = 0., Scalar max = 0.) + output : Tensor(grad_out_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : clip_grad + data_type : x + +- backward_op : clip_grad + forward : clip (Tensor x, Scalar min, Scalar max) -> Tensor(out) + args : (Tensor x, Tensor out_grad, Scalar min = 0., Scalar max = 0.) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : clip_grad + backward : clip_double_grad + inplace : (out_grad -> x_grad) + - backward_op : complex_grad forward : complex (Tensor real, Tensor imag) -> Tensor(out) args : (Tensor real, Tensor imag, Tensor out_grad) @@ -452,6 +475,15 @@ data_type: out_grad no_need_buffer: x +- backward_op : fill_diagonal_grad + forward : fill_diagonal (Tensor x, float value=0, int offset=0, bool wrap=false) -> Tensor(out) + args : (Tensor out_grad, float value, int offset, bool wrap) + output : Tensor(x_grad) + infer_meta : + func : FillDiagonalGradInferMeta + kernel : + func : fill_diagonal_grad + - backward_op : fill_diagonal_tensor_grad forward : fill_diagonal_tensor (Tensor x, Tensor y, int64_t offset, int dim1, int dim2) -> Tensor(out) args : (Tensor out_grad, int64_t offset, int dim1, int dim2) @@ -563,6 +595,17 @@ func : hard_sigmoid_grad inplace : (out_grad -> x_grad) +- backward_op : hardtanh_grad + forward : hardtanh (Tensor x, float t_min=0, float t_max=24) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float t_min, float t_max) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : hardtanh_grad + inplace : (out_grad -> x_grad) + - backward_op : imag_grad forward : imag (Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 5621b2c7db..10d328915e 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -193,28 +193,6 @@ invoke : cast (out_grad, x.dtype()) no_need_buffer : x -- backward_op : clip_double_grad - forward : clip_grad (Tensor x, Tensor grad_out, Scalar min = 0., Scalar max = 0.) -> Tensor(grad_x) - args : (Tensor x, Tensor grad_x_grad, Scalar min = 0., Scalar max = 0.) - output : Tensor(grad_out_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : clip_grad - -- backward_op : clip_grad - forward : clip (Tensor x, Scalar min, Scalar max) -> Tensor(out) - args : (Tensor x, Tensor out_grad, Scalar min = 0., Scalar max = 0.) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : clip_grad - backward : clip_double_grad - inplace : (out_grad -> x_grad) - - backward_op : concat_double_grad forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x) args : (Tensor[] grad_x_grad, Scalar axis = 0) @@ -485,15 +463,6 @@ func : UnchangedInferMeta invoke : zeros_like(out_grad) -- backward_op : fill_diagonal_grad - forward : fill_diagonal (Tensor x, float value, int offset, bool wrap) -> Tensor(out) - args : (Tensor out_grad, float value, int offset, bool wrap) - output : Tensor(x_grad) - infer_meta : - func : FillDiagonalGradInferMeta - kernel : - func : fill_diagonal_grad - - backward_op : fill_grad forward : fill (Tensor x, Scalar value) -> Tensor(out) args : (Tensor out_grad, Scalar value) @@ -585,17 +554,6 @@ func : hardswish_grad inplace : (out_grad -> x_grad) -- backward_op : hardtanh_grad - forward : hardtanh (Tensor x, float t_min, float t_max) -> Tensor(out) - args : (Tensor x, Tensor out_grad, float t_min, float t_max) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : hardtanh_grad - inplace : (out_grad -> x_grad) - - backward_op : hsigmoid_loss_grad forward : hsigmoid_loss (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool remote_prefetch, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out) args : (Tensor x, Tensor w, Tensor label, Tensor path, Tensor code, Tensor bias, Tensor pre_out, Tensor out_grad, int num_classes, bool remote_prefetch, bool is_sparse) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index b0294c245a..ccc54950c6 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -107,15 +107,6 @@ kernel : func : all -- op : allclose - args : (Tensor x, Tensor y, Scalar rtol, Scalar atol, bool equal_nan) - output : Tensor(out) - infer_meta : - func : AllValueCompareInferMeta - param: [x, y] - kernel : - func : allclose - - op : amax args : (Tensor x, int64_t[] axis={}, bool keepdim=false) output : Tensor(out) @@ -367,17 +358,6 @@ kernel : func : class_center_sample -- op : clip - args : (Tensor x, Scalar(float) min, Scalar(float) max) - output : Tensor(out) - inplace : (x -> out) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : clip - backward : clip_grad - - op : clip_by_norm args : (Tensor x, float max_norm) output : Tensor(out) @@ -692,16 +672,6 @@ inplace : (x -> out) backward: fill_grad -- op : fill_diagonal - args : (Tensor x, float value, int offset, bool wrap) - output : Tensor(out) - infer_meta : - func : FillDiagonalInferMeta - kernel : - func : fill_diagonal - inplace : (x -> out) - backward : fill_diagonal_grad - - op : flatten args : (Tensor x, int start_axis, int stop_axis) output : Tensor(out), Tensor(xshape) @@ -870,16 +840,6 @@ func : hardswish backward : hardswish_grad -- op : hardtanh - args : (Tensor x, float t_min, float t_max) - output : Tensor - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : hardtanh - backward : hardtanh_grad - - op : hsigmoid_loss args : (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool remote_prefetch, bool is_sparse) output : Tensor(out), Tensor(pre_out), Tensor(w_out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index ceee57a771..001a0981e1 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -60,6 +60,19 @@ extra : attrs : [bool use_cudnn = true] +- op : allclose + inputs : + {x : Input, y : Other} + outputs : + out : Out + scalar : + rtol : + data_type : std::string + tensor_name : Rtol + atol : + data_type : std::string + tensor_name : Atol + - op : angle backward : angle_grad inputs : @@ -180,7 +193,18 @@ out : Out - op : clip - backward : clip_grad + backward : clip_grad, clip_double_grad + inputs : + x : X + outputs : + out : Out + scalar : + min : + data_type : float + tensor_name : Min + max : + data_type : float + tensor_name : Max extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"] @@ -503,6 +527,13 @@ inputs: {x: X} outputs: {out: Out} +- op : fill_diagonal + backward : fill_diagonal_grad + inputs : + x : X + outputs : + out : Out + - op : fill_diagonal_tensor inputs : {x : X, y : Y} @@ -634,6 +665,13 @@ outputs : out : Out +- op : hardtanh (brelu) + backward : hardtanh_grad (brelu_grad) + inputs : + x : X + outputs : + out : Out + - op : heaviside (elementwise_heaviside) backward : heaviside_grad (elementwise_heaviside_grad) extra : diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 77e722a297..358aaf4ec1 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -1,3 +1,41 @@ +- op : allclose + version : + - checkpoint : Upgrade allclose, add two new inputs [Rtol] and [Atol]. + action: + - add_input : Rtol + comment : The added input 'Rtol' is not dispensable. + - add_input : Atol + comment : The added input 'Atol' is not dispensable. + - checkpoint : Delete two float attributes [rtol] and [atol], + then add 2 string attributes [atol, rtol]. Don't be surprised. + This is because float cannot represent hight-precision + floating-point values, and our framework doesn't support + the use of double attributes. As a result, string instead + of double is used here to represent high-precision + floating-point values. + action : + - add_attr : rtol + comment : The relative tolerance. Default::math:`1e-5` . + default : std::string("1e-5") + - delete_attr : rtol + comment : The attribute 'rtol' is deleted. The reason why it is deleted is that + attributes do not support a float64 value and it is changed to a tensor. + - add_attr : atol + comment : (string) The absolute tolerance. Default::math:`1e-8` . + default : std::string("1e-5") + - delete_attr : atol + comment : The attribute 'atol' is deleted. The reason why it is deleted is that + attributes do not support a float64 value and it is changed to a tensor. + +- op : clip + version : + - checkpoint : Upgrade clip add a new input [Min] + action : + - add_input : Min + comment : Pass the mix, min value as input, not attribute. Min is dispensable. + - add_input : Max + comment : Pass the mix, min value as input, not attribute. Max is dispensable. + - op : flip version : - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims] diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 127d856e37..8d249d2c6a 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -26,6 +26,16 @@ data_type : x backward : addmm_grad +- op : allclose + args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false) + output : Tensor(out) + infer_meta : + func : AllValueCompareInferMeta + param: [x, y] + kernel : + func : allclose + data_type : x + - op : angle args : (Tensor x) output : Tensor @@ -162,6 +172,18 @@ func : cholesky_solve backward : cholesky_solve_grad +- op : clip + args : (Tensor x, Scalar(float) min, Scalar(float) max) + output : Tensor(out) + inplace : (x -> out) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : clip + data_type : x + backward : clip_grad + - op : complex args : (Tensor real, Tensor imag) output : Tensor @@ -393,6 +415,17 @@ func : fft_r2c backward : fft_r2c_grad +- op : fill_diagonal + args : (Tensor x, float value=0, int offset=0, bool wrap=false) + output : Tensor(out) + infer_meta : + func : FillDiagonalInferMeta + kernel : + func : fill_diagonal + data_type : x + inplace : (x -> out) + backward : fill_diagonal_grad + - op : fill_diagonal_tensor args : (Tensor x, Tensor y, int64_t offset = 0, int dim1 = 0, int dim2 = 1) output : Tensor(out) @@ -509,6 +542,16 @@ func : hard_sigmoid backward : hardsigmoid_grad +- op : hardtanh + args : (Tensor x, float t_min=0, float t_max=24) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : hardtanh + backward : hardtanh_grad + - op : histogram args : (Tensor input, int64_t bins = 100, int min = 0, int max = 0) output : Tensor(out) diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index e40bb6bc3d..acc036d393 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -113,15 +113,12 @@ KernelSignature PowTripleGradOpArgumentMapping( } } // namespace phi -PD_REGISTER_BASE_KERNEL_NAME(brelu, hardtanh); -PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hardtanh_grad); PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish); PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad); PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(relu6, phi::Relu6OpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad, diff --git a/paddle/phi/ops/compat/allclose_sig.cc b/paddle/phi/ops/compat/allclose_sig.cc deleted file mode 100644 index e5c4fc027b..0000000000 --- a/paddle/phi/ops/compat/allclose_sig.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature AllCloseOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasInput("Rtol")) { - if (ctx.HasInput("Atol")) { - return KernelSignature("allclose", - {"Input", "Other"}, - {"Rtol", "Atol", "equal_nan"}, - {"Out"}); - } else { - return KernelSignature("allclose", - {"Input", "Other"}, - {"Rtol", "atol", "equal_nan"}, - {"Out"}); - } - } else { - if (ctx.HasInput("Atol")) { - return KernelSignature("allclose", - {"Input", "Other"}, - {"rtol", "Atol", "equal_nan"}, - {"Out"}); - } else { - return KernelSignature("allclose", - {"Input", "Other"}, - {"rtol", "atol", "equal_nan"}, - {"Out"}); - } - } -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(allclose, phi::AllCloseOpArgumentMapping); diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc deleted file mode 100644 index 889dbf6ef9..0000000000 --- a/paddle/phi/ops/compat/clip_sig.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" -#include "paddle/utils/small_vector.h" - -namespace phi { - -KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { - paddle::small_vector attr_names; - attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); - attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); - if (ctx.IsDenseTensorInput("X")) { - if (ctx.HasInput("Min")) { - if (ctx.HasInput("Max")) { - return KernelSignature("clip", {"X"}, {"Min", "Max"}, {"Out"}); - } else { - return KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"}); - } - } else { - if (ctx.HasInput("Max")) { - return KernelSignature("clip", {"X"}, {"min", "Max"}, {"Out"}); - } else { - return KernelSignature("clip", {"X"}, {"min", "max"}, {"Out"}); - } - } - } else if (ctx.IsSelectedRowsInput("X")) { - if (ctx.HasInput("Min")) { - if (ctx.HasInput("Max")) { - return KernelSignature("clip_sr", {"X"}, {"Min", "Max"}, {"Out"}); - } else { - return KernelSignature("clip_sr", {"X"}, {"Min", "max"}, {"Out"}); - } - } else { - if (ctx.HasInput("Max")) { - return KernelSignature("clip_sr", {"X"}, {"min", "Max"}, {"Out"}); - } else { - return KernelSignature("clip_sr", {"X"}, {"min", "max"}, {"Out"}); - } - } - } - - return KernelSignature("unregistered", {}, {}, {}); -} - -KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasInput("Min")) { - if (ctx.HasInput("Max")) { - return KernelSignature( - "clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"}); - } else { - return KernelSignature( - "clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"}); - } - } else { - if (ctx.HasInput("Max")) { - return KernelSignature( - "clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"}); - } else { - return KernelSignature( - "clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"}); - } - } -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(clip, phi::ClipOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(clip_grad, phi::ClipGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/fill_diagonal_sig.cc b/paddle/phi/ops/compat/fill_diagonal_sig.cc deleted file mode 100644 index 81a0faf645..0000000000 --- a/paddle/phi/ops/compat/fill_diagonal_sig.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature FillDiagonalOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "fill_diagonal", {"X"}, {"value", "offset", "wrap"}, {"Out"}); -} - -KernelSignature FillDiagonalGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("fill_diagonal_grad", - {"Out@GRAD"}, - {"value", "offset", "wrap"}, - {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(fill_diagonal, phi::FillDiagonalOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_grad, - phi::FillDiagonalGradOpArgumentMapping); diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index c53d530e60..36e3de3f53 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -351,12 +351,19 @@ def equal_all(x, y, name=None): @templatedoc() def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - """ - ${comment} + r""" + Check if all :math:`x` and :math:`y` satisfy the condition: + + .. math:: + \left| x - y \right| \leq atol + rtol \times \left| y \right| + + elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this + operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if + two tensors are elementwise equal within a tolerance. Args: - x(Tensor): ${input_comment}. - y(Tensor): ${other_comment}. + x(Tensor): The input tensor, it's data type should be float32, float64.. + y(Tensor): The input tensor, it's data type should be float32, float64.. rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` . atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` . equal_nan(equalnantype, optional): ${equal_nan_comment}. @@ -364,7 +371,7 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): refer to :ref:`api_guide_Name`. Default: None. Returns: - Tensor: ${out_comment}. + Tensor: The output tensor, it's data type is bool. Examples: .. code-block:: python @@ -937,7 +944,7 @@ def bitwise_not(x, out=None, name=None): @templatedoc() def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): r""" - Checks if all :math:`x` and :math:`y` satisfy the condition: + Check if all :math:`x` and :math:`y` satisfy the condition: .. math:: -- GitLab