From 22e5ccb0d0a8899d42aadcfd2c6fffc7523e3b2c Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Thu, 7 Sep 2023 10:33:03 +0800 Subject: [PATCH] [IR] add control flow dialect (#56799) * [IR] add control flow dialect --- .../dialect/paddle_dialect/ir/pd_dialect.cc | 14 ++++- .../ir/dialect/paddle_dialect/ir/pd_dialect.h | 7 ++- .../dialect/paddle_dialect/ir/pd_manual_op.cc | 40 ++++++++++++ .../dialect/paddle_dialect/ir/pd_manual_op.h | 22 ++++++- paddle/ir/core/dialect.h | 2 +- paddle/ir/core/ir_printer.cc | 10 +-- paddle/ir/core/ir_printer.h | 2 +- paddle/ir/core/op_base.h | 4 ++ paddle/ir/core/operation.h | 7 ++- paddle/ir/dialect/CMakeLists.txt | 1 + paddle/ir/dialect/control_flow/CMakeLists.txt | 2 + .../ir/dialect/control_flow/ir/cf_dialect.cc | 20 ++++++ .../ir/dialect/control_flow/ir/cf_dialect.h | 33 ++++++++++ paddle/ir/dialect/control_flow/ir/cf_ops.cc | 26 ++++++++ paddle/ir/dialect/control_flow/ir/cf_ops.h | 35 +++++++++++ test/cpp/ir/CMakeLists.txt | 1 + .../ir/control_flow_dialect/CMakeLists.txt | 8 +++ .../cpp/ir/control_flow_dialect/if_op_test.cc | 61 +++++++++++++++++++ test/cpp/ir/core/ir_op_test.cc | 2 +- 19 files changed, 282 insertions(+), 15 deletions(-) create mode 100644 paddle/ir/dialect/control_flow/CMakeLists.txt create mode 100644 paddle/ir/dialect/control_flow/ir/cf_dialect.cc create mode 100644 paddle/ir/dialect/control_flow/ir/cf_dialect.h create mode 100644 paddle/ir/dialect/control_flow/ir/cf_ops.cc create mode 100644 paddle/ir/dialect/control_flow/ir/cf_ops.h create mode 100644 test/cpp/ir/control_flow_dialect/CMakeLists.txt create mode 100644 test/cpp/ir/control_flow_dialect/if_op_test.cc diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc index 9d24dcd2778..4b9dd25d67e 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" #include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/utils.h" namespace paddle { @@ -51,8 +52,8 @@ void PaddleDialect::initialize() { RegisterOps(); - + paddle::dialect::SplitGradOp, + paddle::dialect::IfOp>(); RegisterInterfaces(); } @@ -100,6 +101,15 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { } } +void PaddleDialect::PrintOperation(ir::Operation *op, + ir::IrPrinter &printer) const { + if (auto if_op = op->dyn_cast()) { + if_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h index 5a3d46afb43..b9e9567e790 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h @@ -25,8 +25,11 @@ class PaddleDialect : public ir::Dialect { static const char* name() { return "pd"; } - void PrintType(ir::Type type, std::ostream& os) const; - void PrintAttribute(ir::Attribute type, std::ostream& os) const; + void PrintType(ir::Type type, std::ostream& os) const override; + void PrintAttribute(ir::Attribute type, std::ostream& os) const override; + + void PrintOperation(ir::Operation* op, + ir::IrPrinter& printer) const override; // NOLINT private: void initialize(); diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index 8e27526cebe..45d29ce80b2 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -629,6 +629,45 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +void IfOp::Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult cond, + std::vector &&output_types) { + argument.num_regions = 2; + argument.AddOperand(cond); + argument.output_types.swap(output_types); +} +ir::Block *IfOp::true_block() { + ir::Region &true_region = (*this)->region(0); + if (true_region.empty()) true_region.emplace_back(); + return true_region.front(); +} +ir::Block *IfOp::false_block() { + ir::Region &false_region = (*this)->region(1); + if (false_region.empty()) false_region.emplace_back(); + return false_region.front(); +} +void IfOp::Print(ir::IrPrinter &printer) { + auto &os = printer.os; + auto op = operation(); + printer.PrintOpResult(op); + os << " = pd.if"; + printer.PrintOpOperands(op); + os << " -> "; + printer.PrintOpReturnType(op); + os << "{"; + for (auto item : *true_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n } else {"; + for (auto item : *false_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n }"; +} +void IfOp::Verify() {} } // namespace dialect } // namespace paddle @@ -636,3 +675,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h index 8f0dbd86d1d..6e120317cb4 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h @@ -14,7 +14,7 @@ #ifdef GET_MANUAL_OP_LIST #undef GET_MANUAL_OP_LIST -paddle::dialect::AddNOp, paddle::dialect::SplitGradOp +paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp #else @@ -28,6 +28,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp #include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/ir/core/builder.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/op_base.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/phi/core/infermeta_utils.h" @@ -116,6 +117,23 @@ class SplitGradOp : public ir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); }; +class IfOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "pd.if"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult cond, + std::vector &&output_types); + ir::Value cond() { return operand_source(0); } + ir::Block *true_block(); + ir::Block *false_block(); + void Print(ir::IrPrinter &printer); // NOLINT + void Verify(); +}; + } // namespace dialect } // namespace paddle @@ -123,5 +141,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) - +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) #endif diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index c1cc54a257b..be67898dd98 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -145,7 +145,7 @@ class IR_API Dialect { IR_THROW("dialect has no registered attribute printing hook"); } - virtual void PrintOperation(const Operation *op, + virtual void PrintOperation(Operation *op, IrPrinter &printer) const; // NOLINT private: diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 25f23b31e28..16d6568ecc4 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -125,7 +125,7 @@ void IrPrinter::PrintProgram(const Program* program) { } } -void IrPrinter::PrintOperation(const Operation* op) { +void IrPrinter::PrintOperation(Operation* op) { if (auto* dialect = op->dialect()) { dialect->PrintOperation(op, *this); return; @@ -156,7 +156,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) { } void IrPrinter::PrintFullOperation(const Operation* op) { - PrintOperation(op); + PrintGeneralOperation(op); if (op->num_regions() > 0) { os << newline; } @@ -290,7 +290,7 @@ void IrPrinter::PrintOpReturnType(const Operation* op) { [this]() { this->os << ", "; }); } -void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const { +void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { printer.PrintGeneralOperation(op); } @@ -299,9 +299,9 @@ void Program::Print(std::ostream& os) const { printer.PrintProgram(this); } -void Operation::Print(std::ostream& os) const { +void Operation::Print(std::ostream& os) { IrPrinter printer(os); - printer.PrintFullOperation(this); + printer.PrintOperation(this); } void Type::Print(std::ostream& os) const { diff --git a/paddle/ir/core/ir_printer.h b/paddle/ir/core/ir_printer.h index d3f868946dd..c393d2dfbe9 100644 --- a/paddle/ir/core/ir_printer.h +++ b/paddle/ir/core/ir_printer.h @@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter { void PrintProgram(const Program* program); /// @brief dispatch to custom printer function or PrintGeneralOperation - void PrintOperation(const Operation* op); + void PrintOperation(Operation* op); /// @brief print operation itself without its regions void PrintGeneralOperation(const Operation* op); /// @brief print operation and its regions diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 1b4690f9099..0a491795d4e 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -215,6 +215,10 @@ class Op : public OpBase { return ConcreteOp(nullptr); } + static bool classof(const Operation *op) { + return op && op->info().id() == TypeId::get(); + } + static std::vector GetInterfaceMap() { constexpr size_t interfaces_num = std::tuple_size::value; std::vector interfaces_map(interfaces_num); diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index dec0dfa6883..961e4a5fccc 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -75,7 +75,7 @@ class IR_API alignas(8) Operation final { const Region ®ion(unsigned index) const; uint32_t num_regions() const { return num_regions_; } - void Print(std::ostream &os) const; + void Print(std::ostream &os); const AttributeMap &attributes() const { return attributes_; } @@ -109,6 +109,11 @@ class IR_API alignas(8) Operation final { return CastUtil::call(this); } + template + bool isa() const { + return T::classof(this); + } + template bool HasTrait() const { return info_.HasTrait(); diff --git a/paddle/ir/dialect/CMakeLists.txt b/paddle/ir/dialect/CMakeLists.txt index a87b0abfb23..064d328fc53 100644 --- a/paddle/ir/dialect/CMakeLists.txt +++ b/paddle/ir/dialect/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(control_flow) add_subdirectory(shape) diff --git a/paddle/ir/dialect/control_flow/CMakeLists.txt b/paddle/ir/dialect/control_flow/CMakeLists.txt new file mode 100644 index 00000000000..5a693ba156c --- /dev/null +++ b/paddle/ir/dialect/control_flow/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc") +ir_library(ir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/control_flow/ir/cf_dialect.cc b/paddle/ir/dialect/control_flow/ir/cf_dialect.cc new file mode 100644 index 00000000000..8d26f862b56 --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_dialect.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +namespace ir { +void ControlFlowDialect::initialize() { RegisterOps(); } +} // namespace ir +IR_DEFINE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect) diff --git a/paddle/ir/dialect/control_flow/ir/cf_dialect.h b/paddle/ir/dialect/control_flow/ir/cf_dialect.h new file mode 100644 index 00000000000..867290cdd5b --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_dialect.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/dialect.h" + +namespace ir { +class ControlFlowDialect : public Dialect { + public: + explicit ControlFlowDialect(IrContext *context) + : Dialect(name(), context, TypeId::get()) { + initialize(); + } + static const char *name() { return "cf"; } + + private: + void initialize(); +}; + +} // namespace ir +IR_DECLARE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect) diff --git a/paddle/ir/dialect/control_flow/ir/cf_ops.cc b/paddle/ir/dialect/control_flow/ir/cf_ops.cc new file mode 100644 index 00000000000..dc5491d1ad5 --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_ops.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +namespace ir { + +void YieldOp::Build(Builder &builder, + OperationArgument &argument, + std::vector &&inputs) { + argument.AddOperands(inputs.begin(), inputs.end()); +} +} // namespace ir + +IR_DEFINE_EXPLICIT_TYPE_ID(ir::YieldOp) diff --git a/paddle/ir/dialect/control_flow/ir/cf_ops.h b/paddle/ir/dialect/control_flow/ir/cf_ops.h new file mode 100644 index 00000000000..d58e717136a --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_ops.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/op_base.h" + +namespace ir { +class IR_API YieldOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.yield"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + std::vector &&inputs); + void Verify() {} +}; +} // namespace ir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::YieldOp); diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index 4eec7e8ef94..87c538633e6 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -4,4 +4,5 @@ add_subdirectory(pass) add_subdirectory(pattern_rewrite) add_subdirectory(kernel_dialect) add_subdirectory(cinn) +add_subdirectory(control_flow_dialect) add_subdirectory(shape_dialect) diff --git a/test/cpp/ir/control_flow_dialect/CMakeLists.txt b/test/cpp/ir/control_flow_dialect/CMakeLists.txt new file mode 100644 index 00000000000..5f2a864f994 --- /dev/null +++ b/test/cpp/ir/control_flow_dialect/CMakeLists.txt @@ -0,0 +1,8 @@ +cc_test_old( + test_if_op + SRCS + if_op_test.cc + DEPS + ir + pd_dialect + gtest) diff --git a/test/cpp/ir/control_flow_dialect/if_op_test.cc b/test/cpp/ir/control_flow_dialect/if_op_test.cc new file mode 100644 index 00000000000..8d0d962b5e7 --- /dev/null +++ b/test/cpp/ir/control_flow_dialect/if_op_test.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +TEST(if_op_test, base) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + ir::Program program(ctx); + ir::Block* block = program.block(); + ir::Builder builder(ctx, block); + + auto full_op = builder.Build( + std::vector{1}, true, phi::DataType::BOOL); + + auto if_op = builder.Build( + full_op.out(), std::vector{builder.bool_type()}); + + ir::Block* true_block = if_op.true_block(); + + builder.SetInsertionPointToStart(true_block); + + auto full_op_1 = builder.Build( + std::vector{2}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_1.out()}); + + ir::Block* false_block = if_op.false_block(); + + builder.SetInsertionPointToStart(false_block); + + auto full_op_2 = builder.Build( + std::vector{3}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_2.out()}); + + std::stringstream ss; + program.Print(ss); + + LOG(INFO) << ss.str(); +} diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 39880c4e5bd..48f54c63230 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -157,7 +157,7 @@ class TestDialect : public ir::Dialect { } static const char *name() { return "test"; } - void PrintOperation(const ir::Operation *op, + void PrintOperation(ir::Operation *op, ir::IrPrinter &printer) const override { printer.PrintOpResult(op); printer.os << " ="; -- GitLab