From b75c0c24f4f2d1697d3524f50f34f22f3402cde2 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Fri, 26 May 2023 13:39:53 +0800
Subject: [PATCH] [IR&PASS] part 1: add pass base, pass manager, adaptor pass, ut (#54023)

* [IR&PASS] part 1: add pass base, pass manager, adaptor pass, ut

* include cstdint
---
 paddle/CMakeLists.txt              |   1 +
 paddle/pass/CMakeLists.txt         |  10 +++
 paddle/pass/pass.cc                |  93 ++++++++++++++++++++++++
 paddle/pass/pass.h                 |  94 +++++++++++++++++++++++++
 paddle/pass/pass_adaptor.h         |  52 ++++++++++++++
 paddle/pass/pass_manager.h         |  67 ++++++++++++++++++
 test/cpp/CMakeLists.txt            |   1 +
 test/cpp/pass/CMakeLists.txt       |   3 +
 test/cpp/pass/pass_manager_test.cc | 109 +++++++++++++++++++++++++++++
 9 files changed, 430 insertions(+)
 create mode 100644 paddle/pass/CMakeLists.txt
 create mode 100644 paddle/pass/pass.cc
 create mode 100644 paddle/pass/pass.h
 create mode 100644 paddle/pass/pass_adaptor.h
 create mode 100644 paddle/pass/pass_manager.h
 create mode 100644 test/cpp/pass/CMakeLists.txt
 create mode 100644 test/cpp/pass/pass_manager_test.cc

diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index 924d0c2cb8c..af7c0309935 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_subdirectory(phi)
 add_subdirectory(fluid)
 add_subdirectory(ir)
+add_subdirectory(pass)
 
 # NOTE(zhiqiu): The changes of cc tests
 # Before, (1) the source file of cc tests are distributed in different sub-directories,
diff --git a/paddle/pass/CMakeLists.txt b/paddle/pass/CMakeLists.txt
new file mode 100644
index 00000000000..2a46ae60d2c
--- /dev/null
+++ b/paddle/pass/CMakeLists.txt
@@ -0,0 +1,10 @@
+if(NOT WITH_NEWIR)
+  return()
+endif()
+
+file(GLOB NEW_PASS_SRCS "*.cc")
+
+cc_library(
+  new_pass
+  SRCS ${NEW_PASS_SRCS}
+  DEPS new_ir)
diff --git a/paddle/pass/pass.cc b/paddle/pass/pass.cc
new file mode 100644
index 00000000000..2d74f7a20e2
--- /dev/null
+++ b/paddle/pass/pass.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pass/pass.h"
+#include "paddle/ir/ir_context.h"
+#include "paddle/ir/operation.h"
+#include "paddle/pass/pass_adaptor.h"
+#include "paddle/pass/pass_manager.h"
+
+namespace ir {
+
+void detail::PassAdaptor::Run(ir::Operation* op, uint8_t opt_level) {
+  RunImpl(op, opt_level);
+}
+
+void detail::PassAdaptor::RunImpl(ir::Operation* op, uint8_t opt_level) {
+  // TODO(liuyuanle): Support block, region, etc.
+  return;
+}
+
+bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
+                                      ir::Operation* op,
+                                      uint8_t opt_level) {
+  for (auto& pass : pm.GetPasses()) {
+    if (pass->CanScheduleOn(op)) {
+      if (!RunPass(pass.get(), op, opt_level)) {
+        return false;
+      }
+    }
+  }
+
+  // Apply pass manager on all nested ir.
+  if (!RunPass(pm.pass_adaptor_.get(), op, opt_level)) {
+    return false;
+  }
+
+  return true;
+}
+
+bool detail::PassAdaptor::RunPass(Pass* pass,
+                                  ir::Operation* op,
+                                  uint8_t opt_level) {
+  if (opt_level < pass->info_.opt_level) return true;
+
+  pass->pass_state_ = detail::PassExecutionState(op);
+
+  if (auto* adaptor = dynamic_cast<PassAdaptor*>(pass)) {
+    adaptor->Run(op, opt_level);
+  } else {
+    pass->Run(op);
+  }
+
+  bool pass_failed = pass->pass_state_->pass_failed;
+
+  return !pass_failed;
+}
+
+PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
+    : context_(context), opt_level_(opt_level) {
+  pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
+}
+
+bool PassManager::Run(ir::Operation* op) {
+  if (!Initialize(context_)) {
+    return false;
+  }
+  return RunPasses(op);
+}
+
+bool PassManager::RunPasses(ir::Operation* op) {
+  return detail::PassAdaptor::RunPipeline(*this, op, opt_level_);
+}
+
+bool PassManager::Initialize(ir::IrContext* context) {
+  for (auto& pass : GetPasses()) {
+    if (!pass->Initialize(context)) return false;
+  }
+
+  return true;
+}
+
+}  // namespace ir
diff --git a/paddle/pass/pass.h b/paddle/pass/pass.h
new file mode 100644
index 00000000000..bcbabe6b36e
--- /dev/null
+++ b/paddle/pass/pass.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include "paddle/utils/optional.h"
+
+namespace ir {
+
+class IrContext;
+class Operation;
+
+namespace detail {
+class PassAdaptor;
+}
+
+namespace detail {
+
+struct PassExecutionState {
+  explicit PassExecutionState(ir::Operation* ir) : ir(ir), pass_failed(false) {}
+
+  ir::Operation* ir;
+  bool pass_failed;
+  // TODO(liuyuanle): Add implementation of AnalysisManager and
+  // PreservedAnalyses.
+};
+
+struct PassInfo {
+  PassInfo(const char* name,
+           uint8_t opt_level,
+           const std::vector<const char*>& dependents = {})
+      : name(name), opt_level(opt_level), dependents(dependents) {}
+
+  // Pass name.
+  const char* name;
+
+  // opt_level=0: basic passes that the framework requires.
+  // opt_level=1: fusion-related logical passes.
+  // opt_level=2: constant folding, CSE, memory optimization, etc.
+  // opt_level=3: layout-related passes.
+  uint8_t opt_level;
+
+  // The list of passes that this pass depends on.
+  // PassManager will check this constraint (TODO).
+  std::vector<const char*> dependents;
+};
+
+}  // namespace detail
+
+/// Passes can only be accessed and run through the PassManager.
+class Pass {
+ public:
+  explicit Pass(const char* name,
+                uint8_t opt_level,
+                const std::vector<const char*>& dependents = {})
+      : info_(name, opt_level, dependents) {}
+
+  virtual ~Pass() = default;
+
+  const detail::PassInfo& GetPassInfo() const { return info_; }
+
+ protected:
+  virtual void Run(ir::Operation* op) = 0;
+
+  // TODO(liuyuanle): Add block/region judgement.
+  virtual inline bool CanScheduleOn(ir::Operation* op) const { return true; }
+
+  virtual bool Initialize(ir::IrContext* context) { return true; }
+
+  void SignalPassFailure() { pass_state_->pass_failed = true; }
+
+  detail::PassInfo info_;
+
+  paddle::optional<detail::PassExecutionState> pass_state_;
+
+  friend class PassManager;
+  friend class detail::PassAdaptor;
+};
+
+}  // namespace ir
diff --git a/paddle/pass/pass_adaptor.h b/paddle/pass/pass_adaptor.h
new file mode 100644
index 00000000000..2bc82510617
--- /dev/null
+++ b/paddle/pass/pass_adaptor.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pass/pass.h"
+
+namespace ir {
+
+class Operation;
+
+class PassManager;
+
+namespace detail {
+// Used to run operation passes over nested operations.
+class PassAdaptor final : public Pass {
+ public:
+  explicit PassAdaptor(PassManager* pm) : Pass("pass_adaptor", 0), pm_(pm) {}
+
+  void Run(ir::Operation*) override {}
+
+  void Run(ir::Operation*, uint8_t opt_level);
+
+ private:
+  void RunImpl(ir::Operation* op, uint8_t opt_level);
+
+  static bool RunPass(Pass* pass, ir::Operation* op, uint8_t opt_level);
+
+  static bool RunPipeline(const PassManager& pm,
+                          ir::Operation* op,
+                          uint8_t opt_level);
+
+  // Used by RunImpl later.
+  PassManager* pm_;
+
+  // For accessing RunPipeline.
+  friend class ir::PassManager;
+};
+}  // namespace detail
+
+}  // namespace ir
diff --git a/paddle/pass/pass_manager.h b/paddle/pass/pass_manager.h
new file mode 100644
index 00000000000..3969c65d264
--- /dev/null
+++ b/paddle/pass/pass_manager.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+namespace ir {
+
+class IrContext;
+class Operation;
+class Pass;
+
+namespace detail {
+class PassAdaptor;
+}
+
+class PassManager {
+ public:
+  explicit PassManager(ir::IrContext *context, uint8_t opt_level = 2);
+
+  ~PassManager() = default;
+
+  const std::vector<std::unique_ptr<Pass>> &GetPasses() const {
+    return passes_;
+  }
+
+  bool Empty() const { return passes_.empty(); }
+
+  ir::IrContext *GetContext() const { return context_; }
+
+  bool Run(ir::Operation *op);
+
+  void AddPass(std::unique_ptr<Pass> pass) {
+    passes_.emplace_back(std::move(pass));
+  }
+
+ private:
+  bool RunPasses(ir::Operation *op);
+
+  bool Initialize(ir::IrContext *context);
+
+ private:
+  ir::IrContext *context_;
+
+  uint8_t opt_level_;
+
+  std::vector<std::unique_ptr<Pass>> passes_;
+
+  std::unique_ptr<detail::PassAdaptor> pass_adaptor_;
+
+  friend class detail::PassAdaptor;
+};
+
+}  // namespace ir
diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt
index 60bdac586c4..9ff3f24ebaa 100644
--- a/test/cpp/CMakeLists.txt
+++ b/test/cpp/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(new_executor)
 add_subdirectory(prim)
 add_subdirectory(imperative)
 add_subdirectory(ir)
+add_subdirectory(pass)
 add_subdirectory(inference)
 add_subdirectory(eager)
 add_subdirectory(fluid)
diff --git a/test/cpp/pass/CMakeLists.txt b/test/cpp/pass/CMakeLists.txt
new file mode 100644
index 00000000000..b03d63e3503
--- /dev/null
+++ b/test/cpp/pass/CMakeLists.txt
@@ -0,0 +1,3 @@
+if(WITH_NEWIR)
+  cc_test_old(pass_manager_test SRCS pass_manager_test.cc DEPS new_pass gtest)
+endif()
diff --git a/test/cpp/pass/pass_manager_test.cc b/test/cpp/pass/pass_manager_test.cc
new file mode 100644
index 00000000000..d33c5a04222
--- /dev/null
+++ b/test/cpp/pass/pass_manager_test.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "glog/logging.h"
+
+#include "paddle/ir/builtin_type.h"
+#include "paddle/ir/dialect.h"
+#include "paddle/ir/ir_context.h"
+#include "paddle/ir/op_base.h"
+#include "paddle/ir/operation.h"
+#include "paddle/pass/pass.h"
+#include "paddle/pass/pass_manager.h"
+
+ir::AttributeMap CreateAttributeMap(ir::IrContext *ctx,
+                                    std::string attribute_name,
+                                    std::string attribute) {
+  ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
+  ir::AttributeMap attr_map;
+  attr_map.insert(
+      std::pair<std::string, ir::Attribute>(attribute_name, attr_value));
+  return attr_map;
+}
+
+class TestOp : public ir::Op<TestOp> {
+ public:
+  using Op::Op;
+  static const char *name() { return "TestDialect.TestOp"; }
+  static constexpr uint32_t attributes_num = 1;
+  static const char *attributes_name[attributes_num];
+  static void verify(const std::vector<ir::OpResult> &inputs,
+                     const std::vector<ir::Type> &outputs,
+                     const ir::AttributeMap &attributes) {
+    if (attributes.count("op1_attr1") == 0 ||
+        !attributes.at("op1_attr1").isa<ir::StrAttribute>()) {
+      throw("Type of attribute: parameter_name is not right.");
+    }
+  }
+};
+const char *TestOp::attributes_name[attributes_num] = {"op1_attr1"};
+
+class TestDialect : public ir::Dialect {
+ public:
+  explicit TestDialect(ir::IrContext *context)
+      : ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
+    initialize();
+  }
+  static const char *name() { return "TestDialect"; }
+
+ private:
+  void initialize() { RegisterOps<TestOp>(); }
+};
+
+class TestPass : public ir::Pass {
+ public:
+  TestPass() : ir::Pass("TestPass", 1) {}
+  void Run(ir::Operation *op) override {
+    auto test_op = op->dyn_cast<TestOp>();
+    CHECK_EQ(test_op.operation(), op);
+    CHECK_EQ(test_op.name(), test_op->op_info().name());
+    LOG(INFO) << "In " << info_.name << ": " << test_op->op_info().name();
+  }
+
+  bool CanScheduleOn(ir::Operation *op) const override {
+    return std::strcmp(op->op_info().name(), "TestDialect.TestOp") == 0;
+  }
+};
+
+TEST(pass_manager_test, pass_manager_test) {
+  // (1) Register Dialect, Operation into IrContext.
+  ir::IrContext *ctx = ir::IrContext::Instance();
+  ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
+  CHECK_EQ(test_dialect != nullptr, true);
+
+  // (2) Get registered operations.
+  std::string op_name = std::string(TestOp::name());
+  auto op_info = ctx->GetRegisteredOpInfo(op_name);
+  CHECK_EQ(op_info != nullptr, true);
+
+  // (3) Create an op for testing.
+  std::vector<ir::OpResult> op_inputs = {};
+  std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
+  ir::Operation *op =
+      ir::Operation::create(op_inputs,
+                            op_output_types,
+                            CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"),
+                            op_info);
+
+  // (4) Test pass manager for op.
+  ir::PassManager pm(ctx);
+  pm.AddPass(std::make_unique<TestPass>());
+  CHECK_EQ(pm.Run(op), true);
+
+  op->destroy();
+}
--
GitLab
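
Usage sketch (illustrative only, not part of the diff): the snippet below shows how a downstream pass might be written against the Pass/PassManager API added in this patch. MyLoggingPass and RunMyPipeline are hypothetical names chosen for this example; the headers, the opt_level filtering, and SignalPassFailure are the ones introduced above, and the unit test exercises the same flow.

#include <memory>

#include "glog/logging.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/operation.h"
#include "paddle/pass/pass.h"
#include "paddle/pass/pass_manager.h"

// A hypothetical pass that only logs the operation it is scheduled on.
class MyLoggingPass : public ir::Pass {
 public:
  // opt_level = 1: this pass is skipped unless the PassManager is built
  // with opt_level >= 1.
  MyLoggingPass() : ir::Pass("MyLoggingPass", 1) {}

  void Run(ir::Operation *op) override {
    LOG(INFO) << "visiting op: " << op->op_info().name();
    // A real pass would rewrite `op` here and call SignalPassFailure() on
    // unrecoverable errors so that PassManager::Run() returns false.
  }
};

// Builds a pipeline with the default opt_level (2) and runs it on `op`.
bool RunMyPipeline(ir::Operation *op) {
  ir::IrContext *ctx = ir::IrContext::Instance();
  ir::PassManager pm(ctx, /*opt_level=*/2);
  pm.AddPass(std::make_unique<MyLoggingPass>());
  return pm.Run(op);  // false if initialization or any pass failed.
}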