diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc index 9d24dcd277884f23d12465a80d65686e7588e2df..4b9dd25d67e00325c069703898a0271a029ee7f7 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" #include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/utils.h" namespace paddle { @@ -51,8 +52,8 @@ void PaddleDialect::initialize() { RegisterOps(); - + paddle::dialect::SplitGradOp, + paddle::dialect::IfOp>(); RegisterInterfaces(); } @@ -100,6 +101,15 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { } } +void PaddleDialect::PrintOperation(ir::Operation *op, + ir::IrPrinter &printer) const { + if (auto if_op = op->dyn_cast()) { + if_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h index 5a3d46afb4394d1730f15a44e70559c2033354d6..b9e9567e7908df7c74c436f4397731e3363e4226 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h @@ -25,8 +25,11 @@ class PaddleDialect : public ir::Dialect { static const char* name() { return "pd"; } - void PrintType(ir::Type type, std::ostream& os) const; - void PrintAttribute(ir::Attribute type, std::ostream& os) const; + void PrintType(ir::Type type, std::ostream& os) const override; + void PrintAttribute(ir::Attribute type, std::ostream& os) const override; + + void PrintOperation(ir::Operation* op, + ir::IrPrinter& printer) const override; // NOLINT private: void initialize(); diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index 8e27526cebefafad6658209ecf7d3332876a20cc..45d29ce80b269e1d97def8c84138a80db2b8f412 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -629,6 +629,45 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +void IfOp::Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult cond, + std::vector &&output_types) { + argument.num_regions = 2; + argument.AddOperand(cond); + argument.output_types.swap(output_types); +} +ir::Block *IfOp::true_block() { + ir::Region &true_region = (*this)->region(0); + if (true_region.empty()) true_region.emplace_back(); + return true_region.front(); +} +ir::Block *IfOp::false_block() { + ir::Region &false_region = (*this)->region(1); + if (false_region.empty()) false_region.emplace_back(); + return false_region.front(); +} +void IfOp::Print(ir::IrPrinter &printer) { + auto &os = printer.os; + auto op = operation(); + printer.PrintOpResult(op); + os << " = pd.if"; + printer.PrintOpOperands(op); + os << " -> "; + printer.PrintOpReturnType(op); + os << "{"; + for (auto item : *true_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n } else {"; + for (auto item : *false_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n }"; +} +void IfOp::Verify() {} } // namespace dialect } // namespace paddle @@ -636,3 +675,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h index 8f0dbd86d1d80c27a8df8b1dd3f3015aa765ddf4..6e120317cb461f2a4a76bf7bb99421cd02757d6f 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h @@ -14,7 +14,7 @@ #ifdef GET_MANUAL_OP_LIST #undef GET_MANUAL_OP_LIST -paddle::dialect::AddNOp, paddle::dialect::SplitGradOp +paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp #else @@ -28,6 +28,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp #include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/ir/core/builder.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/op_base.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/phi/core/infermeta_utils.h" @@ -116,6 +117,23 @@ class SplitGradOp : public ir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); }; +class IfOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "pd.if"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult cond, + std::vector &&output_types); + ir::Value cond() { return operand_source(0); } + ir::Block *true_block(); + ir::Block *false_block(); + void Print(ir::IrPrinter &printer); // NOLINT + void Verify(); +}; + } // namespace dialect } // namespace paddle @@ -123,5 +141,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) - +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) #endif diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index c1cc54a257b76fdb619790a914f6a3978848c28b..be67898dd98f5824831e327a9479ec184440e7af 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -145,7 +145,7 @@ class IR_API Dialect { IR_THROW("dialect has no registered attribute printing hook"); } - virtual void PrintOperation(const Operation *op, + virtual void PrintOperation(Operation *op, IrPrinter &printer) const; // NOLINT private: diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 25f23b31e2854195d3d8d351257a2304c35e3b07..16d6568ecc4c3133c0d5cc72cfc5cadb976f6a24 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -125,7 +125,7 @@ void IrPrinter::PrintProgram(const Program* program) { } } -void IrPrinter::PrintOperation(const Operation* op) { +void IrPrinter::PrintOperation(Operation* op) { if (auto* dialect = op->dialect()) { dialect->PrintOperation(op, *this); return; @@ -156,7 +156,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) { } void IrPrinter::PrintFullOperation(const Operation* op) { - PrintOperation(op); + PrintGeneralOperation(op); if (op->num_regions() > 0) { os << newline; } @@ -290,7 +290,7 @@ void IrPrinter::PrintOpReturnType(const Operation* op) { [this]() { this->os << ", "; }); } -void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const { +void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { printer.PrintGeneralOperation(op); } @@ -299,9 +299,9 @@ void Program::Print(std::ostream& os) const { printer.PrintProgram(this); } -void Operation::Print(std::ostream& os) const { +void Operation::Print(std::ostream& os) { IrPrinter printer(os); - printer.PrintFullOperation(this); + printer.PrintOperation(this); } void Type::Print(std::ostream& os) const { diff --git a/paddle/ir/core/ir_printer.h b/paddle/ir/core/ir_printer.h index d3f868946ddc2e8912d94072b971d0bbb3c8dcf7..c393d2dfbe90ad7a68ed43af19d3831ea976a645 100644 --- a/paddle/ir/core/ir_printer.h +++ b/paddle/ir/core/ir_printer.h @@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter { void PrintProgram(const Program* program); /// @brief dispatch to custom printer function or PrintGeneralOperation - void PrintOperation(const Operation* op); + void PrintOperation(Operation* op); /// @brief print operation itself without its regions void PrintGeneralOperation(const Operation* op); /// @brief print operation and its regions diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 1b4690f9099e93c381a042e739ebb1d4d10272ae..0a491795d4eed5e909208ac44eaf12fc63120534 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -215,6 +215,10 @@ class Op : public OpBase { return ConcreteOp(nullptr); } + static bool classof(const Operation *op) { + return op && op->info().id() == TypeId::get(); + } + static std::vector GetInterfaceMap() { constexpr size_t interfaces_num = std::tuple_size::value; std::vector interfaces_map(interfaces_num); diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index dec0dfa6883ea106ef2ab4e4c69ceb0ccd35709e..961e4a5fccc505780e1ee798feb024bee5ec241a 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -75,7 +75,7 @@ class IR_API alignas(8) Operation final { const Region ®ion(unsigned index) const; uint32_t num_regions() const { return num_regions_; } - void Print(std::ostream &os) const; + void Print(std::ostream &os); const AttributeMap &attributes() const { return attributes_; } @@ -109,6 +109,11 @@ class IR_API alignas(8) Operation final { return CastUtil::call(this); } + template + bool isa() const { + return T::classof(this); + } + template bool HasTrait() const { return info_.HasTrait(); diff --git a/paddle/ir/dialect/CMakeLists.txt b/paddle/ir/dialect/CMakeLists.txt index a87b0abfb2383d1ebf0f867d6f8868437e212665..064d328fc53d6aecca22ce7082f215ad1af130e8 100644 --- a/paddle/ir/dialect/CMakeLists.txt +++ b/paddle/ir/dialect/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(control_flow) add_subdirectory(shape) diff --git a/paddle/ir/dialect/control_flow/CMakeLists.txt b/paddle/ir/dialect/control_flow/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a693ba156ccdf09b58a857bca82c2864c9332f2 --- /dev/null +++ b/paddle/ir/dialect/control_flow/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc") +ir_library(ir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/control_flow/ir/cf_dialect.cc b/paddle/ir/dialect/control_flow/ir/cf_dialect.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d26f862b562b8259fd0ddb3a3a6e616e6bcee99 --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_dialect.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +namespace ir { +void ControlFlowDialect::initialize() { RegisterOps(); } +} // namespace ir +IR_DEFINE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect) diff --git a/paddle/ir/dialect/control_flow/ir/cf_dialect.h b/paddle/ir/dialect/control_flow/ir/cf_dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..867290cdd5babafd6d205acb250e529298418b36 --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_dialect.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/dialect.h" + +namespace ir { +class ControlFlowDialect : public Dialect { + public: + explicit ControlFlowDialect(IrContext *context) + : Dialect(name(), context, TypeId::get()) { + initialize(); + } + static const char *name() { return "cf"; } + + private: + void initialize(); +}; + +} // namespace ir +IR_DECLARE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect) diff --git a/paddle/ir/dialect/control_flow/ir/cf_ops.cc b/paddle/ir/dialect/control_flow/ir/cf_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc5491d1ad5d34a2838850a8a482d35b907abd31 --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_ops.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +namespace ir { + +void YieldOp::Build(Builder &builder, + OperationArgument &argument, + std::vector &&inputs) { + argument.AddOperands(inputs.begin(), inputs.end()); +} +} // namespace ir + +IR_DEFINE_EXPLICIT_TYPE_ID(ir::YieldOp) diff --git a/paddle/ir/dialect/control_flow/ir/cf_ops.h b/paddle/ir/dialect/control_flow/ir/cf_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d58e717136ae28e6f019faf0ff234ab255dda89d --- /dev/null +++ b/paddle/ir/dialect/control_flow/ir/cf_ops.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/op_base.h" + +namespace ir { +class IR_API YieldOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.yield"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + std::vector &&inputs); + void Verify() {} +}; +} // namespace ir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::YieldOp); diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index 4eec7e8ef94c14b5e7e3b0fa00a2e46691920a40..87c538633e6dff90bc19625579e41035c60b5e86 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -4,4 +4,5 @@ add_subdirectory(pass) add_subdirectory(pattern_rewrite) add_subdirectory(kernel_dialect) add_subdirectory(cinn) +add_subdirectory(control_flow_dialect) add_subdirectory(shape_dialect) diff --git a/test/cpp/ir/control_flow_dialect/CMakeLists.txt b/test/cpp/ir/control_flow_dialect/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5f2a864f9942eb5bd1a3c51728853199fd62fa3d --- /dev/null +++ b/test/cpp/ir/control_flow_dialect/CMakeLists.txt @@ -0,0 +1,8 @@ +cc_test_old( + test_if_op + SRCS + if_op_test.cc + DEPS + ir + pd_dialect + gtest) diff --git a/test/cpp/ir/control_flow_dialect/if_op_test.cc b/test/cpp/ir/control_flow_dialect/if_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d0d962b5e79174d44ee70e305f436694dc4090c --- /dev/null +++ b/test/cpp/ir/control_flow_dialect/if_op_test.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h" +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/ir/dialect/control_flow/ir/cf_ops.h" + +TEST(if_op_test, base) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + ir::Program program(ctx); + ir::Block* block = program.block(); + ir::Builder builder(ctx, block); + + auto full_op = builder.Build( + std::vector{1}, true, phi::DataType::BOOL); + + auto if_op = builder.Build( + full_op.out(), std::vector{builder.bool_type()}); + + ir::Block* true_block = if_op.true_block(); + + builder.SetInsertionPointToStart(true_block); + + auto full_op_1 = builder.Build( + std::vector{2}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_1.out()}); + + ir::Block* false_block = if_op.false_block(); + + builder.SetInsertionPointToStart(false_block); + + auto full_op_2 = builder.Build( + std::vector{3}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_2.out()}); + + std::stringstream ss; + program.Print(ss); + + LOG(INFO) << ss.str(); +} diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 39880c4e5bdaa245b3f6d275647cb4ba2f7a2f61..48f54c63230e0d815970b0307f0007904fa63291 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -157,7 +157,7 @@ class TestDialect : public ir::Dialect { } static const char *name() { return "test"; } - void PrintOperation(const ir::Operation *op, + void PrintOperation(ir::Operation *op, ir::IrPrinter &printer) const override { printer.PrintOpResult(op); printer.os << " =";