From b98306348224ba4f328d0695044ecbbcd3aae3c3 Mon Sep 17 00:00:00 2001
From: LoneRanger <836253168@qq.com>
Date: Mon, 17 Apr 2023 15:30:33 +0800
Subject: [PATCH] add autogen code support for uniform_inplace (#52955)

---
 .../operators/uniform_random_inplace_op.cc   | 120 ------------------
 paddle/phi/api/yaml/backward.yaml            |  10 ++
 paddle/phi/api/yaml/legacy_backward.yaml     |  10 --
 paddle/phi/api/yaml/legacy_ops.yaml          |  11 --
 paddle/phi/api/yaml/op_compat.yaml           |   7 +
 paddle/phi/api/yaml/ops.yaml                 |  11 ++
 .../ops/compat/uniform_random_inplace_sig.cc |  44 -------
 7 files changed, 28 insertions(+), 185 deletions(-)
 delete mode 100644 paddle/fluid/operators/uniform_random_inplace_op.cc
 delete mode 100644 paddle/phi/ops/compat/uniform_random_inplace_sig.cc

diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cc b/paddle/fluid/operators/uniform_random_inplace_op.cc
deleted file mode 100644
index d43d1cd1252..00000000000
--- a/paddle/fluid/operators/uniform_random_inplace_op.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/backward.h"
-#include "paddle/phi/infermeta/unary.h"
-
-namespace paddle {
-namespace operators {
-
-class UniformRandomInplaceOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddComment(R"DOC(
-This operator fills self tensor with random values sampled from a
-uniform distribution. The random result is in a range of [min, max).
-)DOC");
-    AddInput("X", "The input tensor.");
-    AddOutput("Out", "The output tensor of uniform random op");
-    AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
-        .SetDefault(-1.0f);
-    AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
-        .SetDefault(1.0f);
-    AddAttr<int>("seed",
-                 "Random seed used for generating samples. "
-                 "If seed is 0, it will use the seed of the global default "
-                 "generator (which can be set by paddle.seed). "
-                 "Note that if seed is not 0, this operator will always "
-                 "generate the same random numbers every time. [default 0].")
-        .SetDefault(0);
-    AddAttr<int>("diag_num",
-                 "The number of diag elements. Note that if "
-                 "diag_num is 0, it means without diag init.[default 0].")
-        .SetDefault(0);
-    AddAttr<int>("diag_step", "The step between two diag element.[default 0].")
-        .SetDefault(0);
-    AddAttr<float>("diag_val", "The value of diag element. [default 1.0].")
-        .SetDefault(1.0f);
-  }
-};
-
-class UniformRandomInplaceOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                          ctx.GetPlace());
-  }
-};
-
-class UniformRandomInplaceGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-};
-
-class UniformRandomInplaceOpVarTypeInference
-    : public framework::VarTypeInference {
- public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {}
-};
-
-template <typename T>
-class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType(this->ForwardOpType() + "_grad");
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceGradInplaceInferer,
-                           {paddle::framework::GradVarName("Out"),
-                            paddle::framework::GradVarName("X")});
-
-DECLARE_INFER_SHAPE_FUNCTOR(uniform_random_inplace,
-                            UniformRandomInplaceInferShapeFunctor,
-                            PD_INFER_META(phi::UniformRandomInplaceInferMeta));
-DECLARE_INFER_SHAPE_FUNCTOR(
-    uniform_random_inplace_grad,
-    UniformRandomInplaceGradInferShapeFunctor,
-    PD_INFER_META(phi::UniformRandomInplaceGradInferMeta));
-
-REGISTER_OPERATOR(uniform_random_inplace,
-                  paddle::operators::UniformRandomInplaceOp,
-                  paddle::operators::UniformRandomInplaceOpMaker,
-                  paddle::operators::UniformRandomInplaceGradOpMaker<
-                      paddle::framework::OpDesc>,
-                  paddle::operators::UniformRandomInplaceGradOpMaker<
-                      paddle::imperative::OpBase>,
-                  paddle::operators::UniformRandomInplaceOpVarTypeInference,
-                  UniformRandomInplaceInferer,
-                  UniformRandomInplaceInferShapeFunctor);
-REGISTER_OPERATOR(uniform_random_inplace_grad,
-                  paddle::operators::UniformRandomInplaceGradOp,
-                  UniformRandomInplaceGradInplaceInferer,
-                  UniformRandomInplaceGradInferShapeFunctor);
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 6c0184d7d0c..c3ad63fd5cf 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1952,6 +1952,16 @@
     data_type : out_grad
   no_need_buffer : x
 
+- backward_op : uniform_inplace_grad
+  forward : uniform_inplace(Tensor x, float min = -1.0, float max = 1.0, int seed = 0, int diag_num = 0, int diag_step = 0, float diag_val = 1.0) -> Tensor(out)
+  args : (Tensor out_grad, float min = -1.0, float max = 1.0, int seed = 0, int diag_num = 0, int diag_step = 0, float diag_val = 1.0)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UniformRandomInplaceGradInferMeta
+  kernel :
+    func : uniform_inplace_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : unsqueeze_double_grad
   forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x)
   args : (Tensor grad_x_grad, IntArray axes)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index c0168544a3b..9d323fe3a06 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1097,16 +1097,6 @@
   kernel :
     func : triu_grad
 
-- backward_op : uniform_inplace_grad
-  forward : uniform_inplace(Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val) -> Tensor(out)
-  args : (Tensor out_grad, float min, float max, int seed, int diag_num, int diag_step, float diag_val)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UniformRandomInplaceGradInferMeta
-  kernel :
-    func : uniform_inplace_grad
-  inplace : (out_grad -> x_grad)
-
 - backward_op : yolo_loss_grad
   forward : yolo_loss(Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask)
   args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index ab84e018775..8600df291f2 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -1312,17 +1312,6 @@
     data_type : dtype
     backend : place
 
-- op : uniform_inplace
-  args: (Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val)
-  output: Tensor(out)
-  infer_meta:
-    func: UniformRandomInplaceInferMeta
-  kernel:
-    func: uniform_inplace
-    data_type: x
-  inplace: (x -> out)
-  backward: uniform_inplace_grad
-
 # The `axis` argument of Python API paddle.unique is not vector
 - op : unique
   args : (Tensor x, bool return_index, bool return_inverse, bool return_counts, int[] axis, DataType dtype=DataType::INT64)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index f5bdfbfb46a..3aba1954f77 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2264,6 +2264,13 @@
     support_tensor : true
   manual_signature : [uniform]
 
+- op : uniform_inplace (uniform_random_inplace)
+  backward : uniform_inplace_grad(uniform_random_inplace_grad)
+  inputs :
+    x : X
+  outputs :
+    out : Out
+
 - op : unique
   inputs :
     {x : X}
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 25bda07141c..dae7793c9a1 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1980,6 +1980,17 @@
     func : unfold
   backward : unfold_grad
 
+- op : uniform_inplace
+  args: (Tensor x, float min = -1.0, float max = 1.0, int seed = 0, int diag_num = 0, int diag_step = 0, float diag_val = 1.0)
+  output: Tensor(out)
+  infer_meta:
+    func: UniformRandomInplaceInferMeta
+  kernel:
+    func: uniform_inplace
+    data_type: x
+  inplace: (x -> out)
+  backward: uniform_inplace_grad
+
 - op : unique_consecutive
   args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, int dtype = 5)
   output : Tensor(out), Tensor(index), Tensor(counts)
diff --git a/paddle/phi/ops/compat/uniform_random_inplace_sig.cc b/paddle/phi/ops/compat/uniform_random_inplace_sig.cc
deleted file mode 100644
index ae955e9ca19..00000000000
--- a/paddle/phi/ops/compat/uniform_random_inplace_sig.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-KernelSignature UniformRandomInplaceOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "uniform_inplace",
-      {"X"},
-      {"min", "max", "seed", "diag_num", "diag_step", "diag_val"},
-      {"Out"});
-}
-
-KernelSignature UniformRandomInplaceGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "uniform_inplace_grad",
-      {"Out@GRAD"},
-      {"min", "max", "seed", "diag_num", "diag_step", "diag_val"},
-      {"X@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_BASE_KERNEL_NAME(uniform_random_inplace, uniform_inplace);
-
-PD_REGISTER_ARG_MAPPING_FN(uniform_random_inplace,
-                           phi::UniformRandomInplaceOpArgumentMapping);
-
-PD_REGISTER_ARG_MAPPING_FN(uniform_random_inplace_grad,
-                           phi::UniformRandomInplaceGradOpArgumentMapping);
--
GitLab