未验证 提交 b75c0c24 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] part 1: add pass base, pass manager, adaptor pass, ut (#54023)

* [IR&PASS] part 1: add pass base, pass manager, adaptor pass, ut

* include cstdint
上级 7bf97d2c
...@@ -9,6 +9,7 @@ add_subdirectory(phi) ...@@ -9,6 +9,7 @@ add_subdirectory(phi)
add_subdirectory(fluid) add_subdirectory(fluid)
add_subdirectory(ir) add_subdirectory(ir)
add_subdirectory(pass)
# NOTE(zhiqiu): The changes of cc tests # NOTE(zhiqiu): The changes of cc tests
# Before, (1) the source file of cc tests are distributed in different sub-directories, # Before, (1) the source file of cc tests are distributed in different sub-directories,
......
if(NOT WITH_NEWIR)
return()
endif()
file(GLOB NEW_PASS_SRCS "*.cc")
cc_library(
new_pass
SRCS ${NEW_PASS_SRCS}
DEPS new_ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pass/pass.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/operation.h"
#include "paddle/pass/pass_adaptor.h"
#include "paddle/pass/pass_manager.h"
namespace ir {
void detail::PassAdaptor::Run(ir::Operation* op, uint8_t opt_level) {
RunImpl(op, opt_level);
}
void detail::PassAdaptor::RunImpl(ir::Operation* op, uint8_t opt_level) {
// TODO(liuyuanle): Support block, region, etc.
return;
}
bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
uint8_t opt_level) {
for (auto& pass : pm.GetPasses()) {
if (pass->CanScheduleOn(op)) {
if (!RunPass(pass.get(), op, opt_level)) {
return false;
}
}
}
// Apply pass manager on all nested ir.
if (!RunPass(pm.pass_adaptor_.get(), op, opt_level)) {
return false;
}
return true;
}
bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
uint8_t opt_level) {
if (opt_level < pass->info_.opt_level) return true;
pass->pass_state_ = detail::PassExecutionState(op);
if (auto* adaptor = dynamic_cast<detail::PassAdaptor*>(pass)) {
adaptor->Run(op, opt_level);
} else {
pass->Run(op);
}
bool pass_failed = pass->pass_state_->pass_failed;
return !pass_failed;
}
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}
bool PassManager::Run(ir::Operation* op) {
if (!Initialize(context_)) {
return false;
}
return RunPasses(op);
}
bool PassManager::RunPasses(ir::Operation* op) {
return detail::PassAdaptor::RunPipeline(*this, op, opt_level_);
}
bool PassManager::Initialize(ir::IrContext* context) {
for (auto& pass : GetPasses()) {
if (!pass->Initialize(context)) return false;
}
return true;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
#include "paddle/utils/optional.h"
namespace ir {
class IrContext;
class Operation;
namespace detail {
class PassAdaptor;
}
namespace detail {
struct PassExecutionState {
explicit PassExecutionState(ir::Operation* ir) : ir(ir), pass_failed(false) {}
ir::Operation* ir;
bool pass_failed;
// TODO(liuyuanle): Add implementation of AnalysisManager and
// PreservedAnalyses.
};
struct PassInfo {
PassInfo(const char* name,
uint8_t opt_level,
const std::vector<const char* /* pass name */>& dependents = {})
: name(name), opt_level(opt_level), dependents(dependents) {}
// Pass name.
const char* name;
// opt_level=0: the basic pass which framework need.
// opt_level=1: the fusion logical pass.
// opt_level=2: constant fold, cse, memory optimize, etc.
// opt_level=3: layout.
uint8_t opt_level;
// The list which pass depends on.
// PassManager will check the constraint(TODO).
std::vector<const char*> dependents;
};
} // namespace detail
/// We can access pass only from PassManager.
class Pass {
public:
explicit Pass(const char* name,
uint8_t opt_level,
const std::vector<const char*>& dependents = {})
: info_(name, opt_level, dependents) {}
virtual ~Pass() = default;
const detail::PassInfo& GetPassInfo() const { return info_; }
protected:
virtual void Run(ir::Operation* op) = 0;
// TODO(liuyuanle): Add block/region judgement.
virtual inline bool CanScheduleOn(ir::Operation* op) const { return true; }
virtual bool Initialize(ir::IrContext* context) { return true; }
void SignalPassFailure() { pass_state_->pass_failed = true; }
detail::PassInfo info_;
paddle::optional<detail::PassExecutionState> pass_state_;
friend class PassManager;
friend class detail::PassAdaptor;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pass/pass.h"
namespace ir {
class Operation;
class PassManager;
namespace detail {
// Used to run operation passes over nested operations.
class PassAdaptor final : public Pass {
public:
explicit PassAdaptor(PassManager* pm) : Pass("pass_adaptor", 0), pm_(pm) {}
void Run(ir::Operation*) override {}
void Run(ir::Operation*, uint8_t opt_level);
private:
void RunImpl(ir::Operation* op, uint8_t opt_level);
static bool RunPass(Pass* pass, ir::Operation* op, uint8_t opt_level);
static bool RunPipeline(const PassManager& pm,
ir::Operation* op,
uint8_t opt_level);
// Use for RunImpl later.
PassManager* pm_;
// For accessing RunPipeline.
friend class ir::PassManager;
};
} // namespace detail
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace ir {
class IrContext;
class Operation;
class Pass;
namespace detail {
class PassAdaptor;
}
class PassManager {
public:
explicit PassManager(ir::IrContext *context, uint8_t opt_level = 2);
~PassManager() = default;
const std::vector<std::unique_ptr<Pass>> &GetPasses() const {
return passes_;
}
bool Empty() const { return passes_.empty(); }
ir::IrContext *GetContext() const { return context_; }
bool Run(ir::Operation *op);
void AddPass(std::unique_ptr<Pass> pass) {
passes_.emplace_back(std::move(pass));
}
private:
bool RunPasses(ir::Operation *op);
bool Initialize(ir::IrContext *context);
private:
ir::IrContext *context_;
uint8_t opt_level_;
std::vector<std::unique_ptr<Pass>> passes_;
std::unique_ptr<Pass> pass_adaptor_;
friend class detail::PassAdaptor;
};
} // namespace ir
...@@ -4,6 +4,7 @@ add_subdirectory(new_executor) ...@@ -4,6 +4,7 @@ add_subdirectory(new_executor)
add_subdirectory(prim) add_subdirectory(prim)
add_subdirectory(imperative) add_subdirectory(imperative)
add_subdirectory(ir) add_subdirectory(ir)
add_subdirectory(pass)
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(eager) add_subdirectory(eager)
add_subdirectory(fluid) add_subdirectory(fluid)
if(WITH_NEWIR)
cc_test_old(pass_manager_test SRCS pass_manager_test.cc DEPS new_pass gtest)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstring>
#include "glog/logging.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_base.h"
#include "paddle/ir/operation.h"
#include "paddle/pass/pass.h"
#include "paddle/pass/pass_manager.h"
ir::AttributeMap CreateAttributeMap(ir::IrContext *ctx,
std::string attribute_name,
std::string attribute) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
ir::AttributeMap attr_map;
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_name, attr_value));
return attr_map;
}
class TestOp : public ir::Op<TestOp> {
public:
using Op::Op;
static const char *name() { return "TestDialect.TestOp"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (attributes.count("op1_attr1") == 0 ||
!attributes.at("op1_attr1").isa<ir::StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
}
};
const char *TestOp::attributes_name[attributes_num] = {"op1_attr1"};
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "TestDialect"; }
private:
void initialize() { RegisterOps<TestOp>(); }
};
class TestPass : public ir::Pass {
public:
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
auto test_op = op->dyn_cast<TestOp>();
CHECK_EQ(test_op.operation(), op);
CHECK_EQ(test_op.name(), test_op->op_info().name());
LOG(INFO) << "In " << info_.name << ": " << test_op->op_info().name();
}
bool CanScheduleOn(ir::Operation *op) const override {
return std::strcmp(op->op_info().name(), "TestDialect.TestOp") == 0;
}
};
TEST(pass_manager_test, pass_manager_test) {
// (1) Register Dialect, Operation into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
CHECK_EQ(test_dialect != nullptr, true);
// (2) Get registered operations.
std::string op_name = std::string(TestOp::name());
auto op_info = ctx->GetRegisteredOpInfo(op_name);
CHECK_EQ(op_info != nullptr, true);
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op =
ir::Operation::create(op_inputs,
op_output_types,
CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"),
op_info);
// (4) Test pass manager for op.
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
CHECK_EQ(pm.Run(op), true);
op->destroy();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册