diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
index 4b0a9d9b1c48fcb0d5e44ec1b977c817f3c70b2e..1f4077eec8f970d72aa15f4bc0f1293e6185fe49 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include <algorithm>
 #include <fstream>
+#include <functional>
 #include <limits>
 #include <map>
 #include <set>
@@ -38,6 +39,14 @@ using framework::ir::Node;
 using framework::ir::TopologyVarientSort;
 using space_table_t = MemoryOptimizePass::space_table_t;
 
+typedef struct {
+  std::string name;
+  size_t size;
+  int cluster;
+  std::pair<int, int> lifetime;
+  std::unordered_set<std::string> adj;
+} MemNode;
+
 // Collect the lifecycles of the tensors.
 // Traverse the graph in topological order.
 // The traversal order also affects the lifecycles, so different sort_kind is
@@ -96,6 +105,89 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
   }
 }
 
+void MemoryOptimizePass::CollectVarMemorySize(
+    space_table_t* space_table) const {
+  const int fake_batch_size = 1;
+  // Collect tensors from the graph.
+  for (auto* node : graph_->Nodes()) {
+    if (node->IsVar() &&
+        node->Var()->GetType() ==
+            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      // Parameters will not be reused.
+      if (node->Var()->Persistable()) continue;
+      auto shape = node->Var()->GetShape();
+      for (auto& v : shape) {
+        if (v < 0) v = fake_batch_size;
+      }
+
+      int size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int>());
+      (*space_table)[node->Var()->Name()] =
+          size * DataTypeToSpace(node->Var()->GetDataType());
+    }
+  }
+}
+
+void MakeSimpleReusePlan(
+    const std::unordered_map<std::string, std::pair<int, int>>& lifecycles,
+    const std::unordered_map<std::string, size_t>& space_table,
+    std::unordered_map<std::string, std::string>* node2cluster,
+    std::unordered_map<std::string, int>* cluster_size) {
+  std::vector<MemNode> mem_nodes;
+  for (auto& data : lifecycles) {
+    MemNode temp_node;
+    temp_node.name = data.first;
+    PADDLE_ENFORCE(
+        space_table.count(data.first),
+        "%s variable should be in the space table during memory optimize",
+        data.first);
+    temp_node.size = space_table.at(data.first);
+    temp_node.cluster = -1;
+    temp_node.lifetime = data.second;
+    mem_nodes.push_back(temp_node);
+  }
+  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) -> bool {
+    return b.second >= a.first && a.second >= b.first;
+  };
+  // If the lifetimes of two nodes overlap, mark them as adjacent.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (overlap(mem_nodes[i].lifetime, mem_nodes[j].lifetime)) {
+        mem_nodes[i].adj.insert(mem_nodes[j].name);
+        mem_nodes[j].adj.insert(mem_nodes[i].name);
+      }
+    }
+  }
+
+  // Sort the nodes by memory size, largest first.
+  auto sort_func = [](MemNode a, MemNode b) { return a.size > b.size; };
+  std::sort(mem_nodes.begin(), mem_nodes.end(), sort_func);
+
+  // Generate the memory reuse strategy in a greedy way: each unassigned
+  // node seeds a cluster, then absorbs every later (smaller) node whose
+  // lifetime does not overlap any node already in the cluster.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (mem_nodes[i].cluster >= 0) continue;
+    int cluster_index = cluster_size->size();
+    mem_nodes[i].cluster = cluster_index;
+    (*cluster_size)[mem_nodes[i].name] = mem_nodes[i].size;
+    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+    std::unordered_set<std::string> cluster_adj = mem_nodes[i].adj;
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (mem_nodes[j].cluster < 0 &&
+          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+        mem_nodes[j].cluster = cluster_index;
+        for (auto& n : mem_nodes[j].adj) {
+          cluster_adj.insert(n);
+        }
+      }
+    }
+  }
+  for (auto& cluster : *cluster_size) {
+    LOG(INFO) << "Cluster name : " << cluster.first
+              << "  size: " << cluster.second;
+  }
+}
+
 // Collect the memory size of the tensors.
 void MemoryOptimizePass::CollectVarMemorySize(
     const std::unordered_map<std::string, size_t>& batch_var_ave_dim,
@@ -377,6 +469,17 @@ void UpdateOpDescsByReuse(
         }
       }
 
+      // Rename the input variable nodes in the graph as well.
+      for (auto input_node : node->inputs) {
+        PADDLE_ENFORCE(input_node->IsVar());
+        std::string input_node_name = input_node->Name();
+        if (reuse_table.count(input_node_name) &&
+            reuse_table.at(input_node_name) != input_node_name) {
+          auto name = reuse_table.at(input_node_name);
+          input_node->RenameVar(name);
+        }
+      }
+
       for (auto argument : node->Op()->Outputs()) {
         for (const auto& x : argument.second) {
           auto name = x;
@@ -388,6 +491,17 @@
         }
       }
 
+      // Rename the output variable nodes in the graph as well.
+      for (auto out_node : node->outputs) {
+        PADDLE_ENFORCE(out_node->IsVar());
+        std::string out_node_name = out_node->Name();
+        if (reuse_table.count(out_node_name) &&
+            reuse_table.at(out_node_name) != out_node_name) {
+          auto name = reuse_table.at(out_node_name);
+          out_node->RenameVar(name);
+        }
+      }
+
       // Update arguments.
       for (auto& arg : in_args) {
         node->Op()->SetInput(arg.first, arg.second);
@@ -589,12 +703,24 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   VLOG(3) << "Load memory cache from " << path;
   std::vector<std::map<std::string, std::vector<int>>> batches;
 
-  if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+  if (!(argument->static_memory_optim() && inference::IsFileExists(path))) {
+    string::PrettyLogInfo("--- Performing dynamic memory optimize");
+    // batches = FakeBatchVarShapes(argument->main_program());
+    int sort_kind = 0;
+    std::unordered_map<std::string, lifecycle_t> lifecycles;
+    space_table_t space_table;
+    std::unordered_map<std::string, std::string> node2cluster;
+    std::unordered_map<std::string, int> cluster_size;
+
+    CollectLifeCycle(&lifecycles, sort_kind);
+    CollectVarMemorySize(&space_table);
+    MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
+    UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
+    return;
+
+  } else {
     string::PrettyLogInfo("--- Performing static memory optimize");
     batches = DeseralizeBatchVarShapes(path);
-  } else {
-    string::PrettyLogInfo("--- Performing dynamic memory optimize");
-    batches = FakeBatchVarShapes(argument->main_program());
   }
 
   auto var_batch_ave_size = GetBatchAverageSize(batches);
diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
index 2da565f2ae15a50a207173b10d4c350456086582..5a907303b4d3ba2d1404de7c5b82527b384aa3de 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
@@ -14,6 +14,8 @@
 #pragma once
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/inference/analysis/analysis_pass.h"
@@ -72,6 +74,8 @@ class MemoryOptimizePass : public AnalysisPass {
       std::unordered_map<std::string, lifecycle_t> *lifecycles,
      int sort_kind) const;
 
+  void CollectVarMemorySize(space_table_t *space_table) const;
+
   void CollectVarMemorySize(
       const std::unordered_map<std::string, size_t> &batch_var_ave_dim,
       std::unordered_map<std::string, framework::ir::Node *> *tensor_nodes,
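
For readers who want to experiment with the reuse strategy outside Paddle, here is a minimal standalone sketch of the greedy clustering idea that MakeSimpleReusePlan implements: build a conflict graph from lifetime overlaps, sort tensors by size descending, and let each unassigned tensor seed a cluster that absorbs later non-conflicting tensors. The names MiniNode and BuildReusePlan are invented for this sketch and are not part of the patch above.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

struct MiniNode {
  std::string name;
  size_t size;
  int cluster;                          // -1 means "not assigned yet"
  std::pair<int, int> lifetime;         // [first use, last use] as op indices
  std::unordered_set<std::string> adj;  // lifetime-overlapping neighbors
};

// Greedy plan: sort tensors by size (largest first); each unassigned tensor
// seeds a cluster, then absorbs every later tensor whose lifetime does not
// overlap anything already merged into the cluster.
std::unordered_map<std::string, std::string> BuildReusePlan(
    std::vector<MiniNode> nodes) {
  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) {
    return b.second >= a.first && a.second >= b.first;
  };
  for (size_t i = 0; i < nodes.size(); ++i) {
    for (size_t j = i + 1; j < nodes.size(); ++j) {
      if (overlap(nodes[i].lifetime, nodes[j].lifetime)) {
        nodes[i].adj.insert(nodes[j].name);
        nodes[j].adj.insert(nodes[i].name);
      }
    }
  }
  std::sort(nodes.begin(), nodes.end(),
            [](const MiniNode& a, const MiniNode& b) { return a.size > b.size; });

  std::unordered_map<std::string, std::string> node2cluster;
  int next_cluster = 0;
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (nodes[i].cluster >= 0) continue;          // already absorbed
    nodes[i].cluster = next_cluster++;
    node2cluster[nodes[i].name] = nodes[i].name;  // cluster named by its seed
    std::unordered_set<std::string> cluster_adj = nodes[i].adj;
    for (size_t j = i + 1; j < nodes.size(); ++j) {
      if (nodes[j].cluster < 0 && cluster_adj.count(nodes[j].name) == 0) {
        nodes[j].cluster = nodes[i].cluster;
        node2cluster[nodes[j].name] = nodes[i].name;
        // The cluster now conflicts with everything its new member conflicts with.
        cluster_adj.insert(nodes[j].adj.begin(), nodes[j].adj.end());
      }
    }
  }
  return node2cluster;
}

int main() {
  // "a" ([0,2]) and "b" ([3,5]) never coexist, so "b" can reuse "a"'s buffer;
  // "c" ([1,4]) overlaps both and must keep its own buffer.
  std::vector<MiniNode> nodes = {
      {"a", 4096, -1, {0, 2}, {}},
      {"b", 1024, -1, {3, 5}, {}},
      {"c", 2048, -1, {1, 4}, {}},
  };
  for (const auto& kv : BuildReusePlan(nodes)) {
    std::cout << kv.first << " -> " << kv.second << "\n";  // a->a, b->a, c->c
  }
  return 0;
}

Sorting by size first means the largest tensors become cluster seeds, so a smaller tensor always folds into a buffer at least as big as itself and no cluster ever needs to grow.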
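Likewise, the per-tensor footprint that CollectVarMemorySize records boils down to substituting a fake batch size of 1 for any dynamic (-1) dimension and multiplying the dimensions by the element width. A sketch of that arithmetic follows; EstimateTensorBytes and the 4-byte fp32 width are assumptions for illustration, not Paddle API (the pass does this inline, with DataTypeToSpace supplying the width).

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// EstimateTensorBytes is a name invented for this sketch.
int64_t EstimateTensorBytes(std::vector<int64_t> shape, int64_t elem_bytes) {
  const int64_t fake_batch_size = 1;  // stand-in for dynamic (-1) dims
  for (auto& d : shape) {
    if (d < 0) d = fake_batch_size;
  }
  int64_t elems = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return elems * elem_bytes;
}

int main() {
  // [-1, 3, 224, 224] in fp32: 1 * 3 * 224 * 224 * 4 = 602112 bytes.
  std::cout << EstimateTensorBytes({-1, 3, 224, 224}, 4) << "\n";
  return 0;
}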