未验证 提交 22e5ccb0 编写于 作者: W winter-wang 提交者: GitHub

[IR] add control flow dialect (#56799)

* [IR] add control flow dialect
上级 12301bc5
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h" #include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
namespace paddle { namespace paddle {
...@@ -51,8 +52,8 @@ void PaddleDialect::initialize() { ...@@ -51,8 +52,8 @@ void PaddleDialect::initialize() {
RegisterOps<paddle::dialect::AddNOp, RegisterOps<paddle::dialect::AddNOp,
paddle::dialect::AddN_Op, paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp, paddle::dialect::AddNWithKernelOp,
paddle::dialect::SplitGradOp>(); paddle::dialect::SplitGradOp,
paddle::dialect::IfOp>();
RegisterInterfaces<ParameterConvertInterface>(); RegisterInterfaces<ParameterConvertInterface>();
} }
...@@ -100,6 +101,15 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { ...@@ -100,6 +101,15 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
} }
} }
void PaddleDialect::PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const {
if (auto if_op = op->dyn_cast<IfOp>()) {
if_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
}
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
......
...@@ -25,8 +25,11 @@ class PaddleDialect : public ir::Dialect { ...@@ -25,8 +25,11 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; } static const char* name() { return "pd"; }
void PrintType(ir::Type type, std::ostream& os) const; void PrintType(ir::Type type, std::ostream& os) const override;
void PrintAttribute(ir::Attribute type, std::ostream& os) const; void PrintAttribute(ir::Attribute type, std::ostream& os) const override;
void PrintOperation(ir::Operation* op,
ir::IrPrinter& printer) const override; // NOLINT
private: private:
void initialize(); void initialize();
......
...@@ -629,6 +629,45 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { ...@@ -629,6 +629,45 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta); fn(infer_meta);
} }
void IfOp::Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult cond,
std::vector<ir::Type> &&output_types) {
argument.num_regions = 2;
argument.AddOperand(cond);
argument.output_types.swap(output_types);
}
ir::Block *IfOp::true_block() {
ir::Region &true_region = (*this)->region(0);
if (true_region.empty()) true_region.emplace_back();
return true_region.front();
}
ir::Block *IfOp::false_block() {
ir::Region &false_region = (*this)->region(1);
if (false_region.empty()) false_region.emplace_back();
return false_region.front();
}
void IfOp::Print(ir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = pd.if";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *true_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } else {";
for (auto item : *false_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n }";
}
void IfOp::Verify() {}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -636,3 +675,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) ...@@ -636,3 +675,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#ifdef GET_MANUAL_OP_LIST #ifdef GET_MANUAL_OP_LIST
#undef GET_MANUAL_OP_LIST #undef GET_MANUAL_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::SplitGradOp paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
#else #else
...@@ -28,6 +28,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp ...@@ -28,6 +28,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
...@@ -116,6 +117,23 @@ class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> { ...@@ -116,6 +117,23 @@ class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
static void InferMeta(phi::InferMetaContext *infer_meta); static void InferMeta(phi::InferMetaContext *infer_meta);
}; };
class IfOp : public ir::Op<IfOp> {
public:
using Op::Op;
static const char *name() { return "pd.if"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult cond,
std::vector<ir::Type> &&output_types);
ir::Value cond() { return operand_source(0); }
ir::Block *true_block();
ir::Block *false_block();
void Print(ir::IrPrinter &printer); // NOLINT
void Verify();
};
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -123,5 +141,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) ...@@ -123,5 +141,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
#endif #endif
...@@ -145,7 +145,7 @@ class IR_API Dialect { ...@@ -145,7 +145,7 @@ class IR_API Dialect {
IR_THROW("dialect has no registered attribute printing hook"); IR_THROW("dialect has no registered attribute printing hook");
} }
virtual void PrintOperation(const Operation *op, virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT IrPrinter &printer) const; // NOLINT
private: private:
......
...@@ -125,7 +125,7 @@ void IrPrinter::PrintProgram(const Program* program) { ...@@ -125,7 +125,7 @@ void IrPrinter::PrintProgram(const Program* program) {
} }
} }
void IrPrinter::PrintOperation(const Operation* op) { void IrPrinter::PrintOperation(Operation* op) {
if (auto* dialect = op->dialect()) { if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this); dialect->PrintOperation(op, *this);
return; return;
...@@ -156,7 +156,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) { ...@@ -156,7 +156,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) {
} }
void IrPrinter::PrintFullOperation(const Operation* op) { void IrPrinter::PrintFullOperation(const Operation* op) {
PrintOperation(op); PrintGeneralOperation(op);
if (op->num_regions() > 0) { if (op->num_regions() > 0) {
os << newline; os << newline;
} }
...@@ -290,7 +290,7 @@ void IrPrinter::PrintOpReturnType(const Operation* op) { ...@@ -290,7 +290,7 @@ void IrPrinter::PrintOpReturnType(const Operation* op) {
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
} }
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const { void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op); printer.PrintGeneralOperation(op);
} }
...@@ -299,9 +299,9 @@ void Program::Print(std::ostream& os) const { ...@@ -299,9 +299,9 @@ void Program::Print(std::ostream& os) const {
printer.PrintProgram(this); printer.PrintProgram(this);
} }
void Operation::Print(std::ostream& os) const { void Operation::Print(std::ostream& os) {
IrPrinter printer(os); IrPrinter printer(os);
printer.PrintFullOperation(this); printer.PrintOperation(this);
} }
void Type::Print(std::ostream& os) const { void Type::Print(std::ostream& os) const {
......
...@@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter { ...@@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter {
void PrintProgram(const Program* program); void PrintProgram(const Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation /// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(const Operation* op); void PrintOperation(Operation* op);
/// @brief print operation itself without its regions /// @brief print operation itself without its regions
void PrintGeneralOperation(const Operation* op); void PrintGeneralOperation(const Operation* op);
/// @brief print operation and its regions /// @brief print operation and its regions
......
...@@ -215,6 +215,10 @@ class Op : public OpBase { ...@@ -215,6 +215,10 @@ class Op : public OpBase {
return ConcreteOp(nullptr); return ConcreteOp(nullptr);
} }
static bool classof(const Operation *op) {
return op && op->info().id() == TypeId::get<ConcreteOp>();
}
static std::vector<InterfaceValue> GetInterfaceMap() { static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value; constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num); std::vector<InterfaceValue> interfaces_map(interfaces_num);
......
...@@ -75,7 +75,7 @@ class IR_API alignas(8) Operation final { ...@@ -75,7 +75,7 @@ class IR_API alignas(8) Operation final {
const Region &region(unsigned index) const; const Region &region(unsigned index) const;
uint32_t num_regions() const { return num_regions_; } uint32_t num_regions() const { return num_regions_; }
void Print(std::ostream &os) const; void Print(std::ostream &os);
const AttributeMap &attributes() const { return attributes_; } const AttributeMap &attributes() const { return attributes_; }
...@@ -109,6 +109,11 @@ class IR_API alignas(8) Operation final { ...@@ -109,6 +109,11 @@ class IR_API alignas(8) Operation final {
return CastUtil<T>::call(this); return CastUtil<T>::call(this);
} }
template <typename T>
bool isa() const {
return T::classof(this);
}
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
return info_.HasTrait<Trait>(); return info_.HasTrait<Trait>();
......
add_subdirectory(control_flow)
add_subdirectory(shape) add_subdirectory(shape)
file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc")
ir_library(ir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS ir_core)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
namespace ir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dialect.h"
namespace ir {
class ControlFlowDialect : public Dialect {
public:
explicit ControlFlowDialect(IrContext *context)
: Dialect(name(), context, TypeId::get<ControlFlowDialect>()) {
initialize();
}
static const char *name() { return "cf"; }
private:
void initialize();
};
} // namespace ir
IR_DECLARE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
namespace ir {
void YieldOp::Build(Builder &builder,
OperationArgument &argument,
std::vector<OpResult> &&inputs) {
argument.AddOperands(inputs.begin(), inputs.end());
}
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::YieldOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h"
namespace ir {
class IR_API YieldOp : public Op<YieldOp> {
public:
using Op::Op;
static const char *name() { return "cf.yield"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
std::vector<OpResult> &&inputs);
void Verify() {}
};
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::YieldOp);
...@@ -4,4 +4,5 @@ add_subdirectory(pass) ...@@ -4,4 +4,5 @@ add_subdirectory(pass)
add_subdirectory(pattern_rewrite) add_subdirectory(pattern_rewrite)
add_subdirectory(kernel_dialect) add_subdirectory(kernel_dialect)
add_subdirectory(cinn) add_subdirectory(cinn)
add_subdirectory(control_flow_dialect)
add_subdirectory(shape_dialect) add_subdirectory(shape_dialect)
cc_test_old(
test_if_op
SRCS
if_op_test.cc
DEPS
ir
pd_dialect
gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
TEST(if_op_test, base) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ctx->GetOrRegisterDialect<ir::ControlFlowDialect>();
ir::Program program(ctx);
ir::Block* block = program.block();
ir::Builder builder(ctx, block);
auto full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, true, phi::DataType::BOOL);
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::vector<ir::Type>{builder.bool_type()});
ir::Block* true_block = if_op.true_block();
builder.SetInsertionPointToStart(true_block);
auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<ir::YieldOp>(std::vector<ir::OpResult>{full_op_1.out()});
ir::Block* false_block = if_op.false_block();
builder.SetInsertionPointToStart(false_block);
auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
builder.Build<ir::YieldOp>(std::vector<ir::OpResult>{full_op_2.out()});
std::stringstream ss;
program.Print(ss);
LOG(INFO) << ss.str();
}
...@@ -157,7 +157,7 @@ class TestDialect : public ir::Dialect { ...@@ -157,7 +157,7 @@ class TestDialect : public ir::Dialect {
} }
static const char *name() { return "test"; } static const char *name() { return "test"; }
void PrintOperation(const ir::Operation *op, void PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const override { ir::IrPrinter &printer) const override {
printer.PrintOpResult(op); printer.PrintOpResult(op);
printer.os << " ="; printer.os << " =";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册