From 4a74f4c5aaca9dc36fe2abb7990fe3bd056d87ec Mon Sep 17 00:00:00 2001 From: RedContritio Date: Tue, 11 Apr 2023 15:39:33 +0800 Subject: [PATCH] support auto generate static for randperm (#52531) * support auto generate static for randperm * remove enforce in randperm infermeta --- paddle/fluid/operators/randperm_op.cc | 98 ------------------- paddle/fluid/operators/unity_build_rule.cmake | 2 - paddle/phi/api/yaml/op_compat.yaml | 6 ++ paddle/phi/api/yaml/static_ops.yaml | 11 +++ paddle/phi/ops/compat/randperm_sig.cc | 25 ----- 5 files changed, 17 insertions(+), 125 deletions(-) delete mode 100644 paddle/fluid/operators/randperm_op.cc delete mode 100644 paddle/phi/ops/compat/randperm_sig.cc diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc deleted file mode 100644 index 187b227f331..00000000000 --- a/paddle/fluid/operators/randperm_op.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/randperm_op.h" - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { - -class RandpermOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), - true, - platform::errors::NotFound( - "The output(Out) of randperm op must not be null.")); - int n = ctx->Attrs().Get("n"); - PADDLE_ENFORCE_GT( - n, - 0, - platform::errors::InvalidArgument( - "The input 'n' of randperm op should be greater than 0. " - "But received %d.", - n)); - - ctx->SetOutputDim("Out", phi::make_ddim({n})); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto data_type = - static_cast(ctx.Attr("dtype")); - return phi::KernelKey(data_type, ctx.GetPlace()); - } -}; - -class RandpermOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddOutput("Out", "The output tensor of randperm op."); - - AddAttr( - "n", "The upper bound (exclusive), and it should be greater than 0."); - AddAttr("dtype", - "The data type of output tensor. " - "Default: 3[int64].") - .SetDefault(framework::proto::VarType::INT64); - AddAttr("seed", - "Random seed used for permute samples. " - "0 means use a seed generated by the system." - "Note that if seed is not 0, this operator will always " - "generate the same random permutation every time. " - "Default: 0.") - .SetDefault(0); - - AddComment(R"DOC( -This operator returns a random permutation of integers from 0 to n-1. -)DOC"); - } -}; - -class RandpermOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto var_data_type = static_cast( - PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); - ctx->SetOutputDataType("Out", var_data_type); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR( - randperm, - paddle::operators::RandpermOp, - paddle::operators::RandpermOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::operators::RandpermOpVarTypeInference); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 7ca431e8ea5..91033e2fa67 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -222,7 +222,6 @@ register_unity_group( mkldnn/quantize_mkldnn_op.cc queue_generator_op.cc random_crop_op.cc - randperm_op.cc range_op.cc rank_attention_op.cc rank_loss_op.cc @@ -500,7 +499,6 @@ register_unity_group( register_unity_group( cu random_crop_op.cu - randperm_op.cu range_op.cu reverse_op.cu partial_concat_op.cu diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 98a00e6f5a9..90c75a8dcc6 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1684,6 +1684,12 @@ tensors_name : ShapeTensorList manual_signature : [randint] +- op : randperm + outputs : + out : Out + extra : + attrs : [int seed = 0] + - op : real backward : real_grad inputs : diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 4e0d4cfc931..f0f26e27c1f 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -260,6 +260,17 @@ param : [low, high, shape, dtype] data_type : dtype +- op : randperm + args : (int n, DataType dtype = DataType::INT64) + output : Tensor(out) + infer_meta : + func : RandpermInferMeta + param : [n, dtype] + kernel : + func : randperm + param : [n, dtype] + data_type : dtype + - op : reduce args : (Tensor x, int ring_id = 0, int root_id = 0, int reduce_type = 0) output : Tensor(out) diff --git a/paddle/phi/ops/compat/randperm_sig.cc b/paddle/phi/ops/compat/randperm_sig.cc deleted file mode 100644 index 14b28512e40..00000000000 --- a/paddle/phi/ops/compat/randperm_sig.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(randperm, phi::RandpermOpArgumentMapping); -- GitLab