diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 08b945159ad7ee201514845af2cb8d8f5876664c..52c0aa003e8fa18c15cb6fc9351ac408f07506a5 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { return var_types[0] == proto::VarType::SELECTED_ROWS; } + bool IsDenseTensorOutput(const std::string& name) const override { + auto var_types = ctx_.GetOutputsVarType(name); + return var_types[0] == proto::VarType::LOD_TENSOR; + } + + bool IsSelectedRowsOutput(const std::string& name) const override { + auto var_types = ctx_.GetOutputsVarType(name); + return var_types[0] == proto::VarType::SELECTED_ROWS; + } + private: const InferShapeContext& ctx_; }; diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 9039efbc7c5746d2ffb3e6927b51d213b8bbb073..79b15a14d1eba78d3962823714cf32d869c0c364 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { } size_t InputSize(const std::string& name) const override { - return ctx_.InputSize(name); + return ctx_.MultiInputVar(name).size(); } size_t OutputSize(const std::string& name) const override { - return ctx_.OutputSize(name); + return ctx_.MultiOutputVar(name).size(); } bool IsDenseTensorInput(const std::string& name) const override { @@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.InputVar(name)->IsType(); } + bool IsDenseTensorOutput(const std::string& name) const override { + return ctx_.OutputVar(name)->IsType(); + } + + bool IsSelectedRowsOutput(const std::string& name) const override { + return ctx_.OutputVar(name)->IsType(); + } + private: const ExecutionContext& ctx_; }; diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 9e2fe6e2d066e63424b6701dfffd47665ce3b60c..0c49dbe7e0f4919b869d365ad4fbd310927f87f0 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel { #endif return framework::OpKernelType(tensor->type(), tensor_place); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"}); - } }; } // namespace operators diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 68a4d09f3b92dcaa81390014640ddfa1afeb31dc..2746f0345302ffce55a7a103c46a12e8442e2565 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - if (ctx.HasInput("AxisTensor")) { - return framework::KernelSignature("concat", {"X"}, {"AxisTensor"}, - {"Out"}); - } - return framework::KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); - } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 64beac0804d0f650a65fe218d2a68495da2303f1..e18ff9727b273acac93beb23dc920eb238d4d531 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - int axis = ctx.Attr("axis"); - if (Type() == "elementwise_add") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"}); - } - return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_sub") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("subtract", {"X", "Y"}, {}, - {"Out"}); - } - return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_div") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"}); - } - return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - if (Type() == "elementwise_mul") { - if (ctx.InputVar("X")->IsType()) { - if (axis == -1) { - return framework::KernelSignature("multiply", {"X", "Y"}, {}, - {"Out"}); - } - return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, - {"Out"}); - } - } - return framework::KernelSignature("None", {"X"}, {}, {"Out"}); - } }; class ElementwiseOpInferVarType diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index 717809715601728d66f83006a3bd7b0903d07108..3d28ca90a5a15fd53a57034a4722a21842dc4b1c 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel { framework::proto::VarType::Type(context.Attr("dtype")), context.GetPlace()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - std::string shape; - if (ctx.HasInput("ShapeTensor")) { - shape = "ShapeTensor"; - } else if (ctx.MultiInput("ShapeTensorList").size()) { - shape = "ShapeTensorList"; - } else { - shape = "shape"; - } - - return framework::KernelSignature("empty", {}, {shape}, {"Out"}); - } }; class EmptyOpVarTypeInference : public framework::VarTypeInference { diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc index 245a8977c0bbad4f5e3f53102be48ec3fc0432cb..1e908d5ead9c6f3b8f402c6ac00689a2b603b2ae 100644 --- a/paddle/fluid/operators/fill_any_like_op.cc +++ b/paddle/fluid/operators/fill_any_like_op.cc @@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { expected_kernel_type.place_, tensor.layout()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("full_like", {}, {"value"}, {"Out"}); - } }; class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index c0e2b4584d0260e221b2fc45d3e7e46415a9b7b5..04c2d027cac6a16747935e0254775f3fc50870dd 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel { return kt; } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - std::string shape; - if (ctx.HasInput("ShapeTensor")) { - shape = "ShapeTensor"; - } else if (ctx.MultiInput("ShapeTensorList").size()) { - shape = "ShapeTensorList"; - } else { - shape = "shape"; - } - std::string value; - if (ctx.HasInput("ValueTensor")) { - value = "ValueTensor"; - } else { - const auto& str_value = ctx.Attr("str_value"); - value = str_value.empty() ? "value" : "str_value"; - } - if (!ctx.OutputVar("Out")->IsType()) { - return framework::KernelSignature("full", {}, {shape, value}, {"Out"}); - } - return framework::KernelSignature("fill_constant.unregistered", {}, {}, {}); - } }; class FillConstantOpVarTypeInference : public framework::VarTypeInference { diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 6b1ee00b55d62aa0ab3a5093aec329f7fcb10fd1..110e6f1d025389cf80a8e97d05d4a8934c456471 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { return out_shape; } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - if (ctx.HasOutput("XShape")) { - return framework::KernelSignature("flatten_with_xshape", {"X"}, - {"start_axis", "stop_axis"}, - {"Out", "XShape"}); - } else { - return framework::KernelSignature("flatten", {"X"}, - {"start_axis", "stop_axis"}, {"Out"}); - } - } }; class FlattenContiguousRangeOpMaker : public FlattenOpMaker { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 6c2d5ebcc7d880aa33786df153f270db685f3525..d54d5234010bab999a20279fcf80a8c1cdf5977a 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp { ReshapeOp::InferShape(ctx); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - std::string shape; - auto multi_inputs = ctx.MultiInput("ShapeTensor"); - if (multi_inputs.size() > 0) { - shape = "ShapeTensor"; - } else if (ctx.HasInput("Shape")) { - shape = "Shape"; - } else { - shape = "shape"; - } - return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"}); - } }; class Reshape2OpMaker : public ReshapeOpMaker { diff --git a/paddle/pten/core/compat/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h index 835799ec546aff468972e7adf755129950390f9f..6898dd36d63ad205c0842b3ebc56860161f0ac2e 100644 --- a/paddle/pten/core/compat/arg_map_context.h +++ b/paddle/pten/core/compat/arg_map_context.h @@ -75,6 +75,9 @@ class ArgumentMappingContext { virtual bool IsDenseTensorInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInput(const std::string& name) const = 0; + + virtual bool IsDenseTensorOutput(const std::string& name) const = 0; + virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; }; } // namespace pten diff --git a/paddle/pten/ops/compat/cast_sig.cc b/paddle/pten/ops/compat/cast_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..e05ca88aaf36e59b17d3b13a4b72fdef79416159 --- /dev/null +++ b/paddle/pten/ops/compat/cast_sig.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(cast, pten::CastOpArgumentMapping); diff --git a/paddle/pten/ops/compat/concat_sig.cc b/paddle/pten/ops/compat/concat_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..1352cc7eaca6691a52f491cfa0733c540dc28abe --- /dev/null +++ b/paddle/pten/ops/compat/concat_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"}); + } + return KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(concat, pten::ConcatOpArgumentMapping); diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..77e7625532b18b0856e1c645271f3793cc292db7 --- /dev/null +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ElementwiseAddOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (ctx.IsDenseTensorInput("X")) { + if (axis == -1) { + return KernelSignature("add", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ElementwiseSubOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (ctx.IsDenseTensorInput("X")) { + if (axis == -1) { + return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ElementwiseMulOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (ctx.IsDenseTensorInput("X")) { + if (axis == -1) { + return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ElementwiseDivOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (ctx.IsDenseTensorInput("X")) { + if (axis == -1) { + return KernelSignature("divide", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(elementwise_add, + pten::ElementwiseAddOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(elementwise_sub, + pten::ElementwiseSubOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(elementwise_mul, + pten::ElementwiseMulOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(elementwise_div, + pten::ElementwiseDivOpArgumentMapping); diff --git a/paddle/pten/ops/compat/empty_sig.cc b/paddle/pten/ops/compat/empty_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..c74f6106981a011c263bc38118b74fedf1436b21 --- /dev/null +++ b/paddle/pten/ops/compat/empty_sig.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("ShapeTensor")) { + return KernelSignature("empty", {}, {"ShapeTensor"}, {"Out"}); + } else if (ctx.InputSize("ShapeTensorList") > 0) { + return KernelSignature("empty", {}, {"ShapeTensorList"}, {"Out"}); + } else { + return KernelSignature("empty", {}, {"shape"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(empty, pten::EmptyOpArgumentMapping); diff --git a/paddle/pten/ops/compat/fill_any_like_sig.cc b/paddle/pten/ops/compat/fill_any_like_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..39e301d633863d153de62fa1bc59d5839fabf847 --- /dev/null +++ b/paddle/pten/ops/compat/fill_any_like_sig.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature FillAnyLikeOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("full_like", {}, {"value"}, {"Out"}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping); diff --git a/paddle/pten/ops/compat/fill_constant_sig.cc b/paddle/pten/ops/compat/fill_constant_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..6acf01c7c6f05ca536320797771233879cde782b --- /dev/null +++ b/paddle/pten/ops/compat/fill_constant_sig.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +// we have to return every specific KernelSignature for infrt now +KernelSignature FillConstantOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorOutput("Out")) { + if (ctx.HasInput("ShapeTensor")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature( + "full", {}, {"ShapeTensor", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"}); + } else { + return KernelSignature( + "full", {}, {"ShapeTensor", "str_value"}, {"Out"}); + } + } + } else if (ctx.InputSize("ShapeTensorList") > 0) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature( + "full", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature( + "full", {}, {"ShapeTensorList", "value"}, {"Out"}); + } else { + return KernelSignature( + "full", {}, {"ShapeTensorList", "str_value"}, {"Out"}); + } + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("full", {}, {"shape", "ValueTensor"}, {"Out"}); + } else { + const auto& str_value = + paddle::any_cast(ctx.Attr("str_value")); + if (str_value.empty()) { + return KernelSignature("full", {}, {"shape", "value"}, {"Out"}); + } else { + return KernelSignature("full", {}, {"shape", "str_value"}, {"Out"}); + } + } + } + } + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping); diff --git a/paddle/pten/ops/compat/flatten_sig.cc b/paddle/pten/ops/compat/flatten_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1c774401648e8f6b7daab8bd9ecb0451a6f35cd --- /dev/null +++ b/paddle/pten/ops/compat/flatten_sig.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasOutput("XShape")) { + return KernelSignature("flatten_with_xshape", + {"X"}, + {"start_axis", "stop_axis"}, + {"Out", "XShape"}); + } else { + return KernelSignature( + "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range, + pten::FlattenOpArgumentMapping); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f9171fd5811e070f162fdc72f574c25f7ed5263 --- /dev/null +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + if (ctx.IsDenseTensorInput("X")) { + if (!reduce_all) { + return KernelSignature( + "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"}); + } + return KernelSignature("sum_raw", + {"X"}, + {"dim", "keep_dim", "reduce_all", "out_dtype"}, + {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + if (ctx.IsDenseTensorInput("X")) { + if (!reduce_all) { + return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"}); + } + return KernelSignature( + "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping); diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..22d39ef41109e7026e9403a90e6173bd0816a887 --- /dev/null +++ b/paddle/pten/ops/compat/reshape_sig.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.InputSize("ShapeTensor") > 0) { + return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"}); + } else if (ctx.HasInput("Shape")) { + return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"}); + } else { + return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);