Unverified commit 35b03e1c, authored by TeFeng Chen, committed by GitHub

share MemOptVarInfos of external variables into cinn_launch subgraph (#39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compilation failure

* add share_mem_opt_info_to_subgraph_pass test

* make share_mem_opt_info_to_subgraph_pass_test pass

* modify some code for better style and robustness

* update cmake
Parent 29d31606
@@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
+
+static bool CanBeErased(ir::MemOptVarInfo *var_info) {
+  if (var_info->IsSkippedAllMemoryOptimization() ||
+      !var_info->DecreaseRefCnt()) {
+    return false;
+  }
+#ifdef PADDLE_WITH_CINN
+  // If a parent_holder exists, it must meet the deletion condition too.
+  std::shared_ptr<ir::MemOptVarInfo> parent_holder = var_info->ParentHolder();
+  if (parent_holder && !CanBeErased(parent_holder.get())) {
+    return false;
+  }
+#endif
+  return true;
+}
void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size() || is_variant_scope_) {
vars_.clear();
@@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (size_t i = 0; i < var_infos_.size(); ++i) {
auto *var_info = var_infos_[i];
-    if (var_info->IsSkippedAllMemoryOptimization() ||
-        !var_info->DecreaseRefCnt()) {
+    if (!CanBeErased(var_info)) {
VLOG(4) << "skip memory optimization with var: " << var_info->Name();
continue;
}
......
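A minimal sketch of how the chained condition above behaves, assuming only the MemOptVarInfo API visible in this commit (the name/ref-count constructor and SetParentHolder, both used in the test file further below). CanBeErased itself is file-local to the op handle, so its decrement sequence is narrated in comments, and SketchChainedErase is a hypothetical helper name:

#include <memory>
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"

using paddle::framework::ir::MemOptVarInfo;

void SketchChainedErase() {
  // "var1" is read once inside the cinn subgraph and twice in the main graph.
  auto main_info = std::make_shared<MemOptVarInfo>("var1", /*ref_cnt=*/2);
  auto sub_info = std::make_shared<MemOptVarInfo>("var1", /*ref_cnt=*/1);
  sub_info->SetParentHolder(main_info);
  // Inside the subgraph, CanBeErased(sub_info.get()) decrements both counts:
  // sub 1 -> 0, main 2 -> 1. The parent holder has not reached zero, so the
  // call returns false and the buffer of "var1" survives. Only after the main
  // graph also releases the variable does the condition hold and the
  // allocation become collectible.
}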
@@ -5,8 +5,14 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
-cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
-    eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
+set(EAGER_DELETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
+if (WITH_CINN)
+  cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler)
+  cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op)
+  list(APPEND EAGER_DELETION_PASS_DEPS share_varinfo_into_cinn_pass)
+endif()
+cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETION_PASS_DEPS})
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
......
@@ -285,6 +285,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto recurrent_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass");
recurrent_op_eager_deletion_pass->Apply(graph);
+
+#ifdef PADDLE_WITH_CINN
+  auto share_varinfo_into_cinn_pass =
+      ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
+  share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos);
+  share_varinfo_into_cinn_pass->Apply(graph);
+#endif
}
} // namespace ir
@@ -300,3 +307,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
USE_PASS(conditional_block_op_eager_deletion_pass);
USE_PASS(while_op_eager_deletion_pass);
USE_PASS(recurrent_op_eager_deletion_pass);
+#ifdef PADDLE_WITH_CINN
+USE_PASS(share_varinfo_into_cinn_pass);
+#endif
......
@@ -66,6 +66,12 @@ class MemOptVarInfo {
return skip_memory_reuse_ || skip_all_memory_optimization_;
}
+
+  void SetParentHolder(std::shared_ptr<MemOptVarInfo> parent) {
+    parent_holder_ = parent;
+  }
+
+  std::shared_ptr<MemOptVarInfo> ParentHolder() const { return parent_holder_; }
const std::string &Name() const { return name_; }
private:
@@ -88,6 +94,9 @@ class MemOptVarInfo {
std::atomic<size_t> runtime_ref_cnt_;
bool skip_memory_reuse_{false};
bool skip_all_memory_optimization_{false};
+  // Points to the MemOptVarInfo of the same variable in the main graph;
+  // used for the external (input/output) variables of a subgraph.
+  std::shared_ptr<MemOptVarInfo> parent_holder_{nullptr};
};
using MemOptVarInfoMapList = std::vector<
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle::framework::ir {
using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>;
static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp(
details::ComputationOpHandle* compute_op) {
for (details::VarHandleBase* var : compute_op->Outputs()) {
if (!var->Node()->IsCtrlVar()) {
continue;
}
for (details::OpHandleBase* op : var->PendingOps()) {
auto* eager_deletion_op =
dynamic_cast<details::EagerDeletionOpHandle*>(op);
if (eager_deletion_op) {
return eager_deletion_op;
}
}
}
return nullptr;
}
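// Note: an eager_deletion op is linked to its preceding computation op
// through a control-dependency variable, which is why the search above
// walks only the control-var outputs of the given op.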
static void ShareVarInfoToCinnLaunch(
const MemOptVarInfoMapList& varinfo_maps,
details::ComputationOpHandle* cinn_launch_op) {
details::EagerDeletionOpHandle* followed_eager_deletion_op =
FindFollowedEagerDeletionOp(cinn_launch_op);
if (!followed_eager_deletion_op) {
VLOG(4) << "No eager_deletion op found after this cinn_launch op";
return;
}
std::vector<std::string> vars_to_delete =
followed_eager_deletion_op->VarsToDelete();
if (vars_to_delete.empty()) {
VLOG(4) << "No var to be deleted after this cinn_launch op";
return;
}
VLOG(4) << "Variables would be deleted by the eager_deletion_op"
<< " following the cinn_launch:"
<< paddle::string::join_strings(vars_to_delete, ',');
const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey));
auto& dst_varinfo_map =
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
const Name2VarInfoMap& src_varinfo_map =
varinfo_maps.at(cinn_launch_op->GetScopeIdx());
  // Collect the MemOptVarInfos of all external variables that will be
  // eagerly deleted after the cinn_launch subgraph is executed, and
  // store them as an attribute of the subgraph.
for (const auto& var_name : vars_to_delete) {
auto it = src_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
dst_varinfo_map.emplace(var_name, it->second);
}
}
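// Note: the shared_ptrs stored into the subgraph attribute are the same
// instances held by the main graph's var-info map, so each shared
// MemOptVarInfo gains one extra owner (checked via use_count in the test).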
static void TakeVarInfoFromMainGraph(
const Name2VarInfoMap& src_varinfo_map,
const MemOptVarInfoMapList& varinfo_maps,
details::EagerDeletionOpHandle* eager_deletion_op) {
const Name2VarInfoMap& dst_varinfo_map =
varinfo_maps.at(eager_deletion_op->GetScopeIdx());
for (auto&& var_name : eager_deletion_op->VarsToDelete()) {
auto dst_it = dst_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
auto src_it = src_varinfo_map.find(var_name);
if (src_it != src_varinfo_map.end()) {
VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder";
dst_it->second->SetParentHolder(src_it->second);
}
}
}
// This pass is applied to both the main graph and all cinn subgraphs,
// distinguishing them by whether the graph carries the
// kMemOptVarInfoFromMainGraph attribute.
// On the main graph, it finds all cinn_launch ops and shares the
// MemOptVarInfos of their external variables into the subgraphs.
// On a cinn subgraph, it iterates over each variable that will be deleted
// by an eager_deletion op, and takes the corresponding MemOptVarInfo from
// the main graph if one is found.
class ShareMemOptInfoToSubGraphPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
// the main graph
if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) {
for (details::OpHandleBase* op : all_ops) {
auto compute_op = dynamic_cast<details::ComputationOpHandle*>(op);
if (compute_op && compute_op->Name() == "cinn_launch") {
ShareVarInfoToCinnLaunch(varinfo_maps, compute_op);
}
}
} else { // a cinn subgraph
const auto& parent_varinfo_map =
graph->Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
for (details::OpHandleBase* op : all_ops) {
auto eager_deletion_op =
dynamic_cast<details::EagerDeletionOpHandle*>(op);
if (eager_deletion_op) {
TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps,
eager_deletion_op);
}
}
}
}
};
} // namespace paddle::framework::ir
REGISTER_PASS(share_varinfo_into_cinn_pass,
paddle::framework::ir::ShareMemOptInfoToSubGraphPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"
USE_OP(mul);
USE_OP(cinn_launch);
USE_OP(elementwise_add);
namespace paddle::framework {
using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>;
static ProgramDesc BuildProgramInsideCinnLaunchOp() {
ProgramDesc program;
auto* block = program.MutableBlock(0);
block->Var("var1");
block->Var("var2");
block->Var("var3");
block->Var("var4");
block->Var("var5");
auto add_op = std::unique_ptr<OpDesc>(
new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}},
{{"Out", {"var3"}}}, {}));
block->AppendAllocatedOp(std::move(add_op));
auto mul_op = std::unique_ptr<OpDesc>(new OpDesc(
"mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {}));
block->AppendAllocatedOp(std::move(mul_op));
return program;
}
static ProgramDesc BuildProgramWithCinnLaunchOp(
const std::string& compilation_key) {
// create a cinn_launch op
ProgramDesc program;
auto* block = program.MutableBlock(0);
block->Var("var1");
block->Var("var2");
block->Var("var4");
block->Var("var5");
auto cinn_launch_op = std::unique_ptr<OpDesc>(
new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}},
{{"Out", {"var5"}}}, {{"compilation_key", compilation_key}}));
block->AppendAllocatedOp(std::move(cinn_launch_op));
return program;
}
struct TestPassContext {
explicit TestPassContext(const ProgramDesc& program) {
graph = std::make_unique<ir::Graph>(program);
details::BuildStrategy build_strategy;
details::ExecutionStrategy exec_strategy;
exec_strategy.use_device_ = paddle::platform::kCUDA;
executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope,
exec_strategy, build_strategy,
graph.get()));
}
Scope scope;
std::unique_ptr<ir::Graph> graph;
std::unique_ptr<ParallelExecutor> executor;
};
TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
// add a subgraph to CinnCompiler
auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
subgraph->GetOrInit<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
std::string compilation_key =
paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph));
// build test data and apply pass
auto context = std::make_unique<TestPassContext>(
BuildProgramWithCinnLaunchOp(compilation_key));
// check result
const ir::Graph& result_subgraph =
paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key);
const auto& dst_varinfo_map = result_subgraph.Get<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
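  // each shared MemOptVarInfo is now owned by both the main graph's
  // var-info map and the subgraph attribute, hence use_count == 2 below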
ASSERT_EQ(dst_varinfo_map.size(), 4);
EXPECT_EQ(dst_varinfo_map.count("var1"), 1);
EXPECT_EQ(dst_varinfo_map.count("var5"), 1);
EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2);
EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2);
}
TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) {
// build test data and apply pass
auto context =
std::make_unique<TestPassContext>(BuildProgramInsideCinnLaunchOp());
auto& varinfo_map_shared = context->graph->GetOrInit<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
varinfo_map_shared = {
{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)},
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 2)},
};
ir::MemOptVarInfoMapList varinfo_maps(1);
auto& dst_varinfo_map = varinfo_maps.front();
dst_varinfo_map = {{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)},
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 1)},
{"var3", std::make_shared<ir::MemOptVarInfo>("var3", 1)},
{"var4", std::make_shared<ir::MemOptVarInfo>("var4", 1)},
{"var5", std::make_shared<ir::MemOptVarInfo>("var5", 1)}};
auto share_pass =
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps);
share_pass->Apply(context->graph.get());
// check result
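  // only var1 and var2 were present in the map shared from the main graph,
  // so only their local var infos should acquire a parent holder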
ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr);
ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr);
}
} // namespace paddle::framework
......
@@ -44,6 +44,11 @@ DECLARE_string(deny_cinn_ops);
namespace paddle {
namespace framework {
+
+namespace ir {
+class MemOptVarInfo;
+} // namespace ir
+
namespace paddle2cinn {
using framework::ir::Graph;
@@ -369,6 +374,11 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
ExtractNoNeedBufferFeeds(cluster, cluster_inputs));
subgraph->Set<std::unordered_set<std::string>>(
kNoNeedBufferFeeds, no_need_buffer_feeds.release());
+  // Initialize an empty map for the kMemOptVarInfoFromMainGraph attribute;
+  // it will be filled by the share_varinfo_into_cinn_pass.
+  subgraph->GetOrInit<std::unordered_map<
+      std::string, std::shared_ptr<framework::ir::MemOptVarInfo>>>(
+      kMemOptVarInfoFromMainGraph);
return subgraph;
}
......
@@ -22,6 +22,8 @@ namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds";
+constexpr char kMemOptVarInfoFromMainGraph[] =
+    "mem_opt_var_info_from_main_graph";
// A pass named BuildCinnPass, the function of this pass is:
//
......
@@ -20,9 +20,6 @@
#include <memory>
#include <string>
#include <unordered_map>
#include "cinn/common/target.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
@@ -30,13 +27,22 @@
#include "paddle/fluid/platform/macros.h"
#include "paddle/pten/core/utils/rw_lock.h"
+namespace cinn {
+namespace common {
+class Target;
+} // namespace common
+namespace hlir::framework {
+class GraphCompiler;
+class Program;
+class Scope;
+} // namespace hlir::framework
+} // namespace cinn
+
 namespace paddle {
-namespace operators {
-namespace details {
+namespace operators::details {
 class CinnLaunchContext;
-} // namespace details
-} // namespace operators
+} // namespace operators::details
namespace framework {
namespace paddle2cinn {
......