未验证 提交 22e5ccb0 编写于 作者: W winter-wang 提交者: GitHub

[IR] add control flow dialect (#56799)

* [IR] add control flow dialect
上级 12301bc5
......@@ -20,6 +20,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/utils.h"
namespace paddle {
......@@ -51,8 +52,8 @@ void PaddleDialect::initialize() {
RegisterOps<paddle::dialect::AddNOp,
paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp,
paddle::dialect::SplitGradOp>();
paddle::dialect::SplitGradOp,
paddle::dialect::IfOp>();
RegisterInterfaces<ParameterConvertInterface>();
}
......@@ -100,6 +101,15 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
}
}
void PaddleDialect::PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const {
if (auto if_op = op->dyn_cast<IfOp>()) {
if_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
}
}
} // namespace dialect
} // namespace paddle
......
......@@ -25,8 +25,11 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; }
void PrintType(ir::Type type, std::ostream& os) const;
void PrintAttribute(ir::Attribute type, std::ostream& os) const;
void PrintType(ir::Type type, std::ostream& os) const override;
void PrintAttribute(ir::Attribute type, std::ostream& os) const override;
void PrintOperation(ir::Operation* op,
ir::IrPrinter& printer) const override; // NOLINT
private:
void initialize();
......
......@@ -629,6 +629,45 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}
void IfOp::Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult cond,
std::vector<ir::Type> &&output_types) {
argument.num_regions = 2;
argument.AddOperand(cond);
argument.output_types.swap(output_types);
}
ir::Block *IfOp::true_block() {
ir::Region &true_region = (*this)->region(0);
if (true_region.empty()) true_region.emplace_back();
return true_region.front();
}
ir::Block *IfOp::false_block() {
ir::Region &false_region = (*this)->region(1);
if (false_region.empty()) false_region.emplace_back();
return false_region.front();
}
void IfOp::Print(ir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = pd.if";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *true_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } else {";
for (auto item : *false_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n }";
}
void IfOp::Verify() {}
} // namespace dialect
} // namespace paddle
......@@ -636,3 +675,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
......@@ -14,7 +14,7 @@
#ifdef GET_MANUAL_OP_LIST
#undef GET_MANUAL_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
#else
......@@ -28,6 +28,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
......@@ -116,6 +117,23 @@ class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
static void InferMeta(phi::InferMetaContext *infer_meta);
};
class IfOp : public ir::Op<IfOp> {
public:
using Op::Op;
static const char *name() { return "pd.if"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult cond,
std::vector<ir::Type> &&output_types);
ir::Value cond() { return operand_source(0); }
ir::Block *true_block();
ir::Block *false_block();
void Print(ir::IrPrinter &printer); // NOLINT
void Verify();
};
} // namespace dialect
} // namespace paddle
......@@ -123,5 +141,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
#endif
......@@ -145,7 +145,7 @@ class IR_API Dialect {
IR_THROW("dialect has no registered attribute printing hook");
}
virtual void PrintOperation(const Operation *op,
virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT
private:
......
......@@ -125,7 +125,7 @@ void IrPrinter::PrintProgram(const Program* program) {
}
}
void IrPrinter::PrintOperation(const Operation* op) {
void IrPrinter::PrintOperation(Operation* op) {
if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this);
return;
......@@ -156,7 +156,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) {
}
void IrPrinter::PrintFullOperation(const Operation* op) {
PrintOperation(op);
PrintGeneralOperation(op);
if (op->num_regions() > 0) {
os << newline;
}
......@@ -290,7 +290,7 @@ void IrPrinter::PrintOpReturnType(const Operation* op) {
[this]() { this->os << ", "; });
}
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op);
}
......@@ -299,9 +299,9 @@ void Program::Print(std::ostream& os) const {
printer.PrintProgram(this);
}
void Operation::Print(std::ostream& os) const {
void Operation::Print(std::ostream& os) {
IrPrinter printer(os);
printer.PrintFullOperation(this);
printer.PrintOperation(this);
}
void Type::Print(std::ostream& os) const {
......
......@@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter {
void PrintProgram(const Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(const Operation* op);
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(const Operation* op);
/// @brief print operation and its regions
......
......@@ -215,6 +215,10 @@ class Op : public OpBase {
return ConcreteOp(nullptr);
}
static bool classof(const Operation *op) {
return op && op->info().id() == TypeId::get<ConcreteOp>();
}
static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num);
......
......@@ -75,7 +75,7 @@ class IR_API alignas(8) Operation final {
const Region &region(unsigned index) const;
uint32_t num_regions() const { return num_regions_; }
void Print(std::ostream &os) const;
void Print(std::ostream &os);
const AttributeMap &attributes() const { return attributes_; }
......@@ -109,6 +109,11 @@ class IR_API alignas(8) Operation final {
return CastUtil<T>::call(this);
}
template <typename T>
bool isa() const {
return T::classof(this);
}
template <typename Trait>
bool HasTrait() const {
return info_.HasTrait<Trait>();
......
add_subdirectory(control_flow)
add_subdirectory(shape)
file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc")
ir_library(ir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS ir_core)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
namespace ir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dialect.h"
namespace ir {
class ControlFlowDialect : public Dialect {
public:
explicit ControlFlowDialect(IrContext *context)
: Dialect(name(), context, TypeId::get<ControlFlowDialect>()) {
initialize();
}
static const char *name() { return "cf"; }
private:
void initialize();
};
} // namespace ir
IR_DECLARE_EXPLICIT_TYPE_ID(ir::ControlFlowDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
namespace ir {
void YieldOp::Build(Builder &builder,
OperationArgument &argument,
std::vector<OpResult> &&inputs) {
argument.AddOperands(inputs.begin(), inputs.end());
}
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::YieldOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h"
namespace ir {
class IR_API YieldOp : public Op<YieldOp> {
public:
using Op::Op;
static const char *name() { return "cf.yield"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
std::vector<OpResult> &&inputs);
void Verify() {}
};
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::YieldOp);
......@@ -4,4 +4,5 @@ add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
add_subdirectory(kernel_dialect)
add_subdirectory(cinn)
add_subdirectory(control_flow_dialect)
add_subdirectory(shape_dialect)
cc_test_old(
test_if_op
SRCS
if_op_test.cc
DEPS
ir
pd_dialect
gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/ir/dialect/control_flow/ir/cf_ops.h"
TEST(if_op_test, base) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ctx->GetOrRegisterDialect<ir::ControlFlowDialect>();
ir::Program program(ctx);
ir::Block* block = program.block();
ir::Builder builder(ctx, block);
auto full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, true, phi::DataType::BOOL);
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::vector<ir::Type>{builder.bool_type()});
ir::Block* true_block = if_op.true_block();
builder.SetInsertionPointToStart(true_block);
auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<ir::YieldOp>(std::vector<ir::OpResult>{full_op_1.out()});
ir::Block* false_block = if_op.false_block();
builder.SetInsertionPointToStart(false_block);
auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
builder.Build<ir::YieldOp>(std::vector<ir::OpResult>{full_op_2.out()});
std::stringstream ss;
program.Print(ss);
LOG(INFO) << ss.str();
}
......@@ -157,7 +157,7 @@ class TestDialect : public ir::Dialect {
}
static const char *name() { return "test"; }
void PrintOperation(const ir::Operation *op,
void PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册