From 75923a3278521aad18872bef49b0cfb5ccd2ea6d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 28 Jan 2022 17:12:22 +0800 Subject: [PATCH] [PTen] Update all forward argument maping fns (#39252) * update forward argument mapping * fix compile failed * fix test failed --- paddle/fluid/framework/infershape_utils.cc | 10 +++ paddle/fluid/framework/operator.h | 12 ++- paddle/fluid/operators/cast_op.cc | 5 -- paddle/fluid/operators/concat_op.cc | 9 --- .../operators/elementwise/elementwise_op.h | 44 ----------- paddle/fluid/operators/empty_op.cc | 14 ---- paddle/fluid/operators/fill_any_like_op.cc | 5 -- paddle/fluid/operators/fill_constant_op.cc | 23 ------ paddle/fluid/operators/flatten_op.cc | 12 --- paddle/fluid/operators/reshape_op.cc | 14 ---- paddle/pten/core/compat/arg_map_context.h | 3 + paddle/pten/ops/compat/cast_sig.cc | 25 ++++++ paddle/pten/ops/compat/concat_sig.cc | 28 +++++++ paddle/pten/ops/compat/elementwise_sig.cc | 76 +++++++++++++++++++ paddle/pten/ops/compat/empty_sig.cc | 31 ++++++++ paddle/pten/ops/compat/fill_any_like_sig.cc | 26 +++++++ paddle/pten/ops/compat/fill_constant_sig.cc | 71 +++++++++++++++++ paddle/pten/ops/compat/flatten_sig.cc | 34 +++++++++ paddle/pten/ops/compat/reduce_sig.cc | 49 ++++++++++++ paddle/pten/ops/compat/reshape_sig.cc | 31 ++++++++ 20 files changed, 394 insertions(+), 128 deletions(-) create mode 100644 paddle/pten/ops/compat/cast_sig.cc create mode 100644 paddle/pten/ops/compat/concat_sig.cc create mode 100644 paddle/pten/ops/compat/elementwise_sig.cc create mode 100644 paddle/pten/ops/compat/empty_sig.cc create mode 100644 paddle/pten/ops/compat/fill_any_like_sig.cc create mode 100644 paddle/pten/ops/compat/fill_constant_sig.cc create mode 100644 paddle/pten/ops/compat/flatten_sig.cc create mode 100644 paddle/pten/ops/compat/reduce_sig.cc create mode 100644 paddle/pten/ops/compat/reshape_sig.cc diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 08b945159ad..52c0aa003e8 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { return var_types[0] == proto::VarType::SELECTED_ROWS; } + bool IsDenseTensorOutput(const std::string& name) const override { + auto var_types = ctx_.GetOutputsVarType(name); + return var_types[0] == proto::VarType::LOD_TENSOR; + } + + bool IsSelectedRowsOutput(const std::string& name) const override { + auto var_types = ctx_.GetOutputsVarType(name); + return var_types[0] == proto::VarType::SELECTED_ROWS; + } + private: const InferShapeContext& ctx_; }; diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 9039efbc7c5..79b15a14d1e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { } size_t InputSize(const std::string& name) const override { - return ctx_.InputSize(name); + return ctx_.MultiInputVar(name).size(); } size_t OutputSize(const std::string& name) const override { - return ctx_.OutputSize(name); + return ctx_.MultiOutputVar(name).size(); } bool IsDenseTensorInput(const std::string& name) const override { @@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.InputVar(name)->IsType(); } + bool IsDenseTensorOutput(const std::string& name) const override { + return 
ctx_.OutputVar(name)->IsType(); + } + + bool IsSelectedRowsOutput(const std::string& name) const override { + return ctx_.OutputVar(name)->IsType(); + } + private: const ExecutionContext& ctx_; }; diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 9e2fe6e2d06..0c49dbe7e0f 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel { #endif return framework::OpKernelType(tensor->type(), tensor_place); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"}); - } }; } // namespace operators diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 68a4d09f3b9..2746f034530 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - if (ctx.HasInput("AxisTensor")) { - return framework::KernelSignature("concat", {"X"}, {"AxisTensor"}, - {"Out"}); - } - return framework::KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); - } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 64beac0804d..e18ff9727b2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - int axis = ctx.Attr("axis"); - if (Type() == "elementwise_add") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"}); - } - return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_sub") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("subtract", {"X", "Y"}, {}, - {"Out"}); - } - return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_div") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"}); - } - return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_mul") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("multiply", {"X", "Y"}, {}, - {"Out"}); - } - return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - return framework::KernelSignature("None", {"X"}, {}, {"Out"}); - } }; class ElementwiseOpInferVarType diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index 71780971560..3d28ca90a5a 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel { 
framework::proto::VarType::Type(context.Attr("dtype")), context.GetPlace()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - std::string shape; - if (ctx.HasInput("ShapeTensor")) { - shape = "ShapeTensor"; - } else if (ctx.MultiInput("ShapeTensorList").size()) { - shape = "ShapeTensorList"; - } else { - shape = "shape"; - } - - return framework::KernelSignature("empty", {}, {shape}, {"Out"}); - } }; class EmptyOpVarTypeInference : public framework::VarTypeInference { diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc index 245a8977c0b..1e908d5ead9 100644 --- a/paddle/fluid/operators/fill_any_like_op.cc +++ b/paddle/fluid/operators/fill_any_like_op.cc @@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { expected_kernel_type.place_, tensor.layout()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("full_like", {}, {"value"}, {"Out"}); - } }; class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index c0e2b4584d0..04c2d027cac 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel { return kt; } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - std::string shape; - if (ctx.HasInput("ShapeTensor")) { - shape = "ShapeTensor"; - } else if (ctx.MultiInput("ShapeTensorList").size()) { - shape = "ShapeTensorList"; - } else { - shape = "shape"; - } - std::string value; - if (ctx.HasInput("ValueTensor")) { - value = "ValueTensor"; - } else { - const auto& str_value = ctx.Attr("str_value"); - value = str_value.empty() ? 
"value" : "str_value"; - } - if (!ctx.OutputVar("Out")->IsType()) { - return framework::KernelSignature("full", {}, {shape, value}, {"Out"}); - } - return framework::KernelSignature("fill_constant.unregistered", {}, {}, {}); - } }; class FillConstantOpVarTypeInference : public framework::VarTypeInference { diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 6b1ee00b55d..110e6f1d025 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { return out_shape; } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - if (ctx.HasOutput("XShape")) { - return framework::KernelSignature("flatten_with_xshape", {"X"}, - {"start_axis", "stop_axis"}, - {"Out", "XShape"}); - } else { - return framework::KernelSignature("flatten", {"X"}, - {"start_axis", "stop_axis"}, {"Out"}); - } - } }; class FlattenContiguousRangeOpMaker : public FlattenOpMaker { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 6c2d5ebcc7d..d54d5234010 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp { ReshapeOp::InferShape(ctx); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - std::string shape; - auto multi_inputs = ctx.MultiInput("ShapeTensor"); - if (multi_inputs.size() > 0) { - shape = "ShapeTensor"; - } else if (ctx.HasInput("Shape")) { - shape = "Shape"; - } else { - shape = "shape"; - } - return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"}); - } }; class Reshape2OpMaker : public ReshapeOpMaker { diff --git a/paddle/pten/core/compat/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h index 835799ec546..6898dd36d63 100644 --- a/paddle/pten/core/compat/arg_map_context.h +++ b/paddle/pten/core/compat/arg_map_context.h @@ -75,6 +75,9 @@ class ArgumentMappingContext { virtual bool IsDenseTensorInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInput(const std::string& name) const = 0; + + virtual bool IsDenseTensorOutput(const std::string& name) const = 0; + virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; }; } // namespace pten diff --git a/paddle/pten/ops/compat/cast_sig.cc b/paddle/pten/ops/compat/cast_sig.cc new file mode 100644 index 00000000000..e05ca88aaf3 --- /dev/null +++ b/paddle/pten/ops/compat/cast_sig.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(cast, pten::CastOpArgumentMapping); diff --git a/paddle/pten/ops/compat/concat_sig.cc b/paddle/pten/ops/compat/concat_sig.cc new file mode 100644 index 00000000000..1352cc7eaca --- /dev/null +++ b/paddle/pten/ops/compat/concat_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"}); + } + return KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(concat, pten::ConcatOpArgumentMapping); diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc new file mode 100644 index 00000000000..77e7625532b --- /dev/null +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature ElementwiseAddOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (axis == -1) {
+      return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
+    }
+    return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+KernelSignature ElementwiseSubOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (axis == -1) {
+      return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
+    }
+    return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+KernelSignature ElementwiseMulOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (axis == -1) {
+      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
+    }
+    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+KernelSignature ElementwiseDivOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (axis == -1) {
+      return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
+    }
+    return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
+                           pten::ElementwiseAddOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_sub,
+                           pten::ElementwiseSubOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
+                           pten::ElementwiseMulOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
+                           pten::ElementwiseDivOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/empty_sig.cc b/paddle/pten/ops/compat/empty_sig.cc
new file mode 100644
index 00000000000..c74f6106981
--- /dev/null
+++ b/paddle/pten/ops/compat/empty_sig.cc
@@ -0,0 +1,31 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("ShapeTensor")) { + return KernelSignature("empty", {}, {"ShapeTensor"}, {"Out"}); + } else if (ctx.InputSize("ShapeTensorList") > 0) { + return KernelSignature("empty", {}, {"ShapeTensorList"}, {"Out"}); + } else { + return KernelSignature("empty", {}, {"shape"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(empty, pten::EmptyOpArgumentMapping); diff --git a/paddle/pten/ops/compat/fill_any_like_sig.cc b/paddle/pten/ops/compat/fill_any_like_sig.cc new file mode 100644 index 00000000000..39e301d6338 --- /dev/null +++ b/paddle/pten/ops/compat/fill_any_like_sig.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature FillAnyLikeOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("full_like", {}, {"value"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping); diff --git a/paddle/pten/ops/compat/fill_constant_sig.cc b/paddle/pten/ops/compat/fill_constant_sig.cc new file mode 100644 index 00000000000..6acf01c7c6f --- /dev/null +++ b/paddle/pten/ops/compat/fill_constant_sig.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +// we have to return every specific KernelSignature for infrt now +KernelSignature FillConstantOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorOutput("Out")) { + if (ctx.HasInput("ShapeTensor")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature( + "full", {}, {"ShapeTensor", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"}); + } else { + return KernelSignature( + "full", {}, {"ShapeTensor", "str_value"}, {"Out"}); + } + } + } else if (ctx.InputSize("ShapeTensorList") > 0) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature( + "full", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature( + "full", {}, {"ShapeTensorList", "value"}, {"Out"}); + } else { + return KernelSignature( + "full", {}, {"ShapeTensorList", "str_value"}, {"Out"}); + } + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("full", {}, {"shape", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature("full", {}, {"shape", "value"}, {"Out"}); + } else { + return KernelSignature("full", {}, {"shape", "str_value"}, {"Out"}); + } + } + } + } + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping); diff --git a/paddle/pten/ops/compat/flatten_sig.cc b/paddle/pten/ops/compat/flatten_sig.cc new file mode 100644 index 00000000000..f1c77440164 --- /dev/null +++ b/paddle/pten/ops/compat/flatten_sig.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasOutput("XShape")) { + return KernelSignature("flatten_with_xshape", + {"X"}, + {"start_axis", "stop_axis"}, + {"Out", "XShape"}); + } else { + return KernelSignature( + "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range, + pten::FlattenOpArgumentMapping); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc new file mode 100644 index 00000000000..7f9171fd581 --- /dev/null +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (!reduce_all) {
+      return KernelSignature(
+          "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+    }
+    return KernelSignature("sum_raw",
+                           {"X"},
+                           {"dim", "keep_dim", "reduce_all", "out_dtype"},
+                           {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
+  if (ctx.IsDenseTensorInput("X")) {
+    if (!reduce_all) {
+      return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
+    }
+    return KernelSignature(
+        "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc
new file mode 100644
index 00000000000..22d39ef4110
--- /dev/null
+++ b/paddle/pten/ops/compat/reshape_sig.cc
@@ -0,0 +1,31 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.InputSize("ShapeTensor") > 0) {
+    return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
+  } else if (ctx.HasInput("Shape")) {
+    return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+  } else {
+    return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+  }
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
--
GitLab
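The pattern introduced by this patch is: remove each operator's GetExpectedPtenKernelArgs override, move the mapping into a standalone function under paddle/pten/ops/compat/, register it with PT_REGISTER_ARG_MAPPING_FN, and let it inspect inputs, outputs, and attributes through ArgumentMappingContext, which now also exposes IsDenseTensorOutput and IsSelectedRowsOutput. As an illustration only, a mapping for a hypothetical fluid op "my_scale" could follow the same shape as the sketch below; the op name and its "ScaleTensor" input and "scale" attribute are invented for the example and only reuse API calls that already appear in the *_sig.cc files above.

    // Hypothetical sketch, not part of this patch; names are illustrative.
    #include "paddle/pten/core/compat/op_utils.h"

    namespace pten {

    // Prefer the runtime tensor input over the static attribute when it is
    // provided, mirroring the concat and reshape mappings in this patch.
    KernelSignature MyScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
      if (ctx.HasInput("ScaleTensor")) {
        return KernelSignature("my_scale", {"X"}, {"ScaleTensor"}, {"Out"});
      }
      return KernelSignature("my_scale", {"X"}, {"scale"}, {"Out"});
    }

    }  // namespace pten

    // Registered under the fluid op type, as done for cast, concat, reshape2, etc.
    PT_REGISTER_ARG_MAPPING_FN(my_scale, pten::MyScaleOpArgumentMapping);

Compared with the removed virtual overrides, free functions keyed by op type can be resolved from both the infer-shape and the execution paths, which is why the patch extends both InferShapeArgumentMappingContext and ExecutionArgumentMappingContext at the top, without touching the operator classes themselves.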