diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index 42b87f3853c58ab336474773f7eeb2501b4fd971..c760e7a98614cb55dd5c39883601cf85df00d7c1 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() { std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } +static bool CanBeErased(ir::MemOptVarInfo *var_info) { + if (var_info->IsSkippedAllMemoryOptimization() || + !var_info->DecreaseRefCnt()) { + return false; + } +#ifdef PADDLE_WITH_CINN + // if parent_holder exists, it should meet deletion condition too. + std::shared_ptr<ir::MemOptVarInfo> parent_holder = var_info->ParentHolder(); + if (parent_holder && !CanBeErased(parent_holder.get())) { + return false; + } +#endif + return true; +} + void EagerDeletionOpHandle::RunImpl() { if (vars_.size() != var_infos_.size() || is_variant_scope_) { vars_.clear(); @@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() { std::deque<std::shared_ptr<memory::Allocation>> garbages; for (size_t i = 0; i < var_infos_.size(); ++i) { auto *var_info = var_infos_[i]; - if (var_info->IsSkippedAllMemoryOptimization() || - !var_info->DecreaseRefCnt()) { + if (!CanBeErased(var_info)) { VLOG(4) << "skip memory optimization with var: " << var_info->Name(); continue; } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index ee63b314adedb1128c16327be667e33d1f32b06e..25b07ddf4141466fd8ef41d7cb564e7999aa7a8f 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -5,13 +5,19 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle 
var_handle) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) -cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle - eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) +SET(EAGER_DELETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) +if (WITH_CINN) + cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) + cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op) + list(APPEND EAGER_DELETION_PASS_DEPS share_varinfo_into_cinn_pass) +endif() -cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) +cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETION_PASS_DEPS}) + +cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper) -cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) +cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) cc_library(inplace_addto_op_pass SRCS 
inplace_addto_op_pass.cc DEPS memory_reuse_pass) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index 7b9b5aa62307443789214b4cca2c6b367dc2a287..af1a65f7a6c3bcf6ad3abec1c18450577329c8f7 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -285,6 +285,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { auto recurrent_op_eager_deletion_pass = ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass"); recurrent_op_eager_deletion_pass->Apply(graph); + +#ifdef PADDLE_WITH_CINN + auto share_varinfo_into_cinn_pass = + ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); + share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos); + share_varinfo_into_cinn_pass->Apply(graph); +#endif } } // namespace ir @@ -300,3 +307,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) USE_PASS(conditional_block_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass); +#ifdef PADDLE_WITH_CINN +USE_PASS(share_varinfo_into_cinn_pass); +#endif diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h index 94842485440bdce17f47d3b2fc7000e57a37c3c8..e89734bacec36e9178d6b315e4df716ffe92f72f 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h @@ -66,6 +66,12 @@ class MemOptVarInfo { return skip_memory_reuse_ || skip_all_memory_optimization_; } + void SetParentHolder(std::shared_ptr<MemOptVarInfo> parent) { + parent_holder_ = parent; + } + + std::shared_ptr<MemOptVarInfo> ParentHolder() const { return parent_holder_; } + const std::string &Name() 
const { return name_; } private: @@ -88,6 +94,9 @@ class MemOptVarInfo { std::atomic<size_t> runtime_ref_cnt_; bool skip_memory_reuse_{false}; bool skip_all_memory_optimization_{false}; + // point to var info of the same variable in the main graph, + // used in external(input/output) variables of a subgraph + std::shared_ptr<MemOptVarInfo> parent_holder_{nullptr}; }; using MemOptVarInfoMapList = std::vector<std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>>; diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b2a62695fb135925d43a3341aaacdf956da8da3 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle::framework::ir { + +using Name2VarInfoMap = + std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>; + +static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp( details::ComputationOpHandle* compute_op) { + for (details::VarHandleBase* var : compute_op->Outputs()) { + if (!var->Node()->IsCtrlVar()) { + continue; + } + for (details::OpHandleBase* op : var->PendingOps()) { + auto* eager_deletion_op = + dynamic_cast<details::EagerDeletionOpHandle*>(op); + if (eager_deletion_op) { + return eager_deletion_op; + } + } + } + return nullptr; +} + +static void ShareVarInfoToCinnLaunch( const MemOptVarInfoMapList& varinfo_maps, details::ComputationOpHandle* cinn_launch_op) { + details::EagerDeletionOpHandle* followed_eager_deletion_op = + FindFollowedEagerDeletionOp(cinn_launch_op); + if (!followed_eager_deletion_op) { + VLOG(4) << "No eager_deletion op found after this cinn_launch op"; + return; + } + + std::vector<std::string> vars_to_delete = + followed_eager_deletion_op->VarsToDelete(); + if (vars_to_delete.empty()) { + VLOG(4) << "No var to be deleted after this cinn_launch op"; + return; + } + VLOG(4) << "Variables would be deleted by the eager_deletion_op" + << " following the cinn_launch:" + << paddle::string::join_strings(vars_to_delete, ','); + + const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey)); + auto& dst_varinfo_map = + 
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph); + const Name2VarInfoMap& src_varinfo_map = + varinfo_maps.at(cinn_launch_op->GetScopeIdx()); + + // collect all MemOptVarInfos of external variables + // that would be eager deleted after the cinn_launch subgraph executed, + // and store them as attribute of the subgraph + for (const auto& var_name : vars_to_delete) { + auto it = src_varinfo_map.find(var_name); + PADDLE_ENFORCE_NE(it, src_varinfo_map.end(), + platform::errors::NotFound( + "MemOptVarInfo of var[%s] not found", var_name)); + dst_varinfo_map.emplace(var_name, it->second); + } +} + +static void TakeVarInfoFromMainGraph( + const Name2VarInfoMap& src_varinfo_map, + const MemOptVarInfoMapList& varinfo_maps, + details::EagerDeletionOpHandle* eager_deletion_op) { + const Name2VarInfoMap& dst_varinfo_map = + varinfo_maps.at(eager_deletion_op->GetScopeIdx()); + for (auto&& var_name : eager_deletion_op->VarsToDelete()) { + auto dst_it = dst_varinfo_map.find(var_name); + PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(), + platform::errors::NotFound( + "MemOptVarInfo of var[%s] not found", var_name)); + auto src_it = src_varinfo_map.find(var_name); + if (src_it != src_varinfo_map.end()) { + VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder"; + dst_it->second->SetParentHolder(src_it->second); + } + } +} + +// This pass will be applied on both the main graph and all cinn subgraphs, +// and it distinguishes them according to whether the graph has the +// kMemOptVarInfoFromMainGraph attribute or not. +// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos +// to their subgraphs. +// On a cinn subgraph, it iterates each variable that will be deleted by an +// eager_deletion op, and takes the MemOptVarInfo from the main graph +// if one is found. 
+class ShareMemOptInfoToSubGraphPass : public ir::Pass { + protected: + void ApplyImpl(ir::Graph* graph) const override { + auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); + const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList); + + // the main graph + if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) { + for (details::OpHandleBase* op : all_ops) { + auto compute_op = dynamic_cast<details::ComputationOpHandle*>(op); + if (compute_op && compute_op->Name() == "cinn_launch") { + ShareVarInfoToCinnLaunch(varinfo_maps, compute_op); + } + } + } else { // a cinn subgraph + const auto& parent_varinfo_map = + graph->Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph); + for (details::OpHandleBase* op : all_ops) { + auto eager_deletion_op = + dynamic_cast<details::EagerDeletionOpHandle*>(op); + if (eager_deletion_op) { + TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps, + eager_deletion_op); + } + } + } + } +}; + +} // namespace paddle::framework::ir + +REGISTER_PASS(share_varinfo_into_cinn_pass, + paddle::framework::ir::ShareMemOptInfoToSubGraphPass) + .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..abed6a5bd4bc48e01d9bcf20abf1bed236ed847a --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/framework/program_desc.h" + +USE_OP(mul); +USE_OP(cinn_launch); +USE_OP(elementwise_add); +namespace paddle::framework { + +using Name2VarInfoMap = + std::unordered_map>; + +static ProgramDesc BuildProgramInsideCinnLaunchOp() { + ProgramDesc program; + auto* block = program.MutableBlock(0); + block->Var("var1"); + block->Var("var2"); + block->Var("var3"); + block->Var("var4"); + block->Var("var5"); + + auto add_op = std::unique_ptr( + new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}}, + {{"Out", {"var3"}}}, {})); + block->AppendAllocatedOp(std::move(add_op)); + auto mul_op = std::unique_ptr(new OpDesc( + "mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {})); + block->AppendAllocatedOp(std::move(mul_op)); + return program; +} + +static ProgramDesc BuildProgramWithCinnLaunchOp( + const std::string& compilation_key) { + // create a cinn_launch op + ProgramDesc program; + auto* block = program.MutableBlock(0); + block->Var("var1"); + block->Var("var2"); + 
block->Var("var4"); + block->Var("var5"); + + auto cinn_launch_op = std::unique_ptr( + new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}}, + {{"Out", {"var5"}}}, {{"compilation_key", compilation_key}})); + block->AppendAllocatedOp(std::move(cinn_launch_op)); + return program; +} + +struct TestPassContext { + explicit TestPassContext(const ProgramDesc& program) { + graph = std::make_unique(program); + details::BuildStrategy build_strategy; + details::ExecutionStrategy exec_strategy; + exec_strategy.use_device_ = paddle::platform::kCUDA; + executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope, + exec_strategy, build_strategy, + graph.get())); + } + + Scope scope; + std::unique_ptr graph; + std::unique_ptr executor; +}; + +TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) { + // add a subgraph to CinnCompiler + auto subgraph = std::make_unique(BuildProgramInsideCinnLaunchOp()); + subgraph->GetOrInit( + paddle2cinn::kMemOptVarInfoFromMainGraph); + std::string compilation_key = + paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); + + // build test data and apply pass + auto context = std::make_unique( + BuildProgramWithCinnLaunchOp(compilation_key)); + + // check result + const ir::Graph& result_subgraph = + paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key); + const auto& dst_varinfo_map = result_subgraph.Get( + paddle2cinn::kMemOptVarInfoFromMainGraph); + ASSERT_EQ(dst_varinfo_map.size(), 4); + EXPECT_EQ(dst_varinfo_map.count("var1"), 1); + EXPECT_EQ(dst_varinfo_map.count("var5"), 1); + EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2); + EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2); +} + +TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) { + // build test data and apply pass + auto context = + std::make_unique(BuildProgramInsideCinnLaunchOp()); + auto& varinfo_map_shared = context->graph->GetOrInit( + paddle2cinn::kMemOptVarInfoFromMainGraph); + 
varinfo_map_shared = { + {"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)}, + {"var2", std::make_shared<ir::MemOptVarInfo>("var2", 2)}, + }; + + ir::MemOptVarInfoMapList varinfo_maps(1); + auto& dst_varinfo_map = varinfo_maps.front(); + dst_varinfo_map = {{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)}, + {"var2", std::make_shared<ir::MemOptVarInfo>("var2", 1)}, + {"var3", std::make_shared<ir::MemOptVarInfo>("var3", 1)}, + {"var4", std::make_shared<ir::MemOptVarInfo>("var4", 1)}, + {"var5", std::make_shared<ir::MemOptVarInfo>("var5", 1)}}; + auto share_pass = + ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); + share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps); + share_pass->Apply(context->graph.get()); + + // check result + ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr); + ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr); +} + +} // namespace paddle::framework diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 4abe3a55b298f1f006a129873b544cf55d252daa..ab259a0fc85abb7600cc49123f242fdfd8dc147b 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -44,6 +44,11 @@ DECLARE_string(deny_cinn_ops); namespace paddle { namespace framework { + +namespace ir { +class MemOptVarInfo; +} // namespace ir + namespace paddle2cinn { using framework::ir::Graph; @@ -369,6 +374,11 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ExtractNoNeedBufferFeeds(cluster, cluster_inputs)); subgraph->Set<std::unordered_set<std::string>>( kNoNeedBufferFeeds, no_need_buffer_feeds.release()); + // initialize empty map for kMemOptVarInfoFromMainGraph attribute, + // it will be filled in the share_varinfo_into_cinn_pass + subgraph->GetOrInit<std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>>( + kMemOptVarInfoFromMainGraph); return subgraph; } diff --git 
a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 10d12f93f8bd83c3768f6951396959d0d9db5634..9bb25b6b52e5466b1665fc080511fbe63d8011df 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -22,6 +22,8 @@ namespace paddle2cinn { constexpr char kCinnLaunchOp[] = "cinn_launch"; constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds"; +constexpr char kMemOptVarInfoFromMainGraph[] = + "mem_opt_var_info_from_main_graph"; // A pass named BuildCinnPass, the function of this pass is: // diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index 024dd26747b8e7db9eec15fd2998cefaeeb931fb..91a7b4e5a11f0054112df9645c4f8b8f3c22501b 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -20,9 +20,6 @@ #include #include #include - -#include "cinn/common/target.h" -#include "cinn/hlir/framework/graph_compiler.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" @@ -30,13 +27,22 @@ #include "paddle/fluid/platform/macros.h" #include "paddle/pten/core/utils/rw_lock.h" -namespace paddle { +namespace cinn { +namespace common { +class Target; +} // namespace common -namespace operators { -namespace details { +namespace hlir::framework { +class GraphCompiler; +class Program; +class Scope; +} // namespace hlir::framework +} // namespace cinn + +namespace paddle { +namespace operators::details { class CinnLaunchContext; -} // namespace details -} // namespace operators +} // namespace operators::details namespace framework { namespace paddle2cinn {