diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index e43c0f276881aa1a0b87731cba06a68014ba945d..70355fdf890eb63cd5bedd5bab42a2dd69af0927 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -39,3 +39,4 @@ USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
+USE_MIR_PASS(memory_optimize_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index 6dfc2bd295915f9970cf056267d59a90b2d2ba53..7d967b15c4eaa13d9a98f129addfcf316350b6b5 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -31,6 +31,7 @@ lite_cc_library(mir_passes
        argument_type_display_pass.cc
        demo_pass.cc
        runtime_context_assign_pass.cc
+       memory_optimize_pass.cc
   DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
 
 # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24d00f4b7457d294acfbd0e883984f30ce1b4cd9
--- /dev/null
+++ b/lite/core/mir/memory_optimize_pass.cc
@@ -0,0 +1,264 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/memory_optimize_pass.h"
+#include <algorithm>
+#include <limits>
+#include <utility>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+typedef struct {
+  std::string name;
+  int cluster;
+  std::pair<int, int> lifetime;
+  std::unordered_set<std::string> adj;
+} MemNode;
+
+void MemoryOptimizePass::CollectLifeCycleByDevice(
+    std::unordered_map<std::string, lifecycle_map_t>* lifecycles,
+    SSAGraph* graph) {
+  max_lifecycle_ = 0;
+
+  auto is_host = [](TargetType x) -> bool {
+    return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
+  };
+  // Vars whose producer or consumer is one of the invalid ops listed below
+  // will not be reused.
+  auto valid_var = [&](Node* node) -> bool {
+    std::set<std::string> invalid_op = {"while",
+                                        "conditional_block",
+                                        "conditional_block_infer",
+                                        "merge_lod_tensor_infer",
+                                        "merge_lod_tensor",
+                                        "equal",
+                                        "lod_reset",
+                                        "concat",
+                                        "graph_op"};
+    for (auto* tmp : node->inlinks) {
+      CHECK(tmp->IsStmt());
+      std::string op_type = tmp->AsStmt().op_info()->Type();
+      if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
+          invalid_op.end()) {
+        return false;
+      }
+    }
+    for (auto* tmp : node->outlinks) {
+      CHECK(tmp->IsStmt());
+      std::string op_type = tmp->AsStmt().op_info()->Type();
+      if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
+          invalid_op.end()) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  for (auto& op_node : graph->StmtTopologicalOrder()) {
+    if (op_node->IsStmt()) {
+      auto inputs = op_node->inlinks;
+      auto outputs = op_node->outlinks;
+      std::vector<Node*> requires(inputs.begin(), inputs.end());
+      requires.insert(requires.end(), outputs.begin(), outputs.end());
+      auto& stmt = op_node->AsStmt();
+      // The feed and fetch ops' inputs and outputs will not be reused.
+      if (stmt.op_info()->Type() == "feed" ||
+          stmt.op_info()->Type() == "fetch") {
+        for (auto* node : op_node->outlinks) {
+          CHECK(node->IsArg());
+          std::string var_name = node->AsArg().name;
+          TargetType target_type = node->AsArg().type->target();
+          if (is_host(target_type)) target_type = TARGET(kHost);
+          (*lifecycles)[TargetToStr(target_type)].emplace(
+              var_name, std::make_pair(0, std::numeric_limits<int>::max()));
+        }
+      } else {
+        for (Node* node : requires) {
+          CHECK(node->IsArg());
+          auto& arg = node->AsArg();
+          if (arg.is_weight || arg.is_persist) continue;
+          if (!valid_var(node)) continue;
+          std::string var_name = arg.name;
+          TargetType target_type = node->AsArg().type->target();
+          if (is_host(target_type)) target_type = TARGET(kHost);
+
+          if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
+            (*lifecycles)[TargetToStr(target_type)].emplace(
+                var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
+          } else {
+            int cur_life =
+                (*lifecycles)[TargetToStr(target_type)][var_name].second;
+            (*lifecycles)[TargetToStr(target_type)][var_name].second =
+                std::max(max_lifecycle_, cur_life);
+          }
+        }
+      }
+      ++max_lifecycle_;
+    }
+  }
+  LOG(INFO) << "There are " << (*lifecycles).size() << " types device var.";
+}
+
+void MemoryOptimizePass::MakeReusePlan(
+    const lifecycle_map_t& lifecycles,
+    std::unordered_map<std::string, std::string>* node2cluster) {
+  std::vector<MemNode> mem_nodes;
+  std::vector<std::string> cluster;
+  for (auto& data : lifecycles) {
+    MemNode temp_node;
+    temp_node.name = data.first;
+    temp_node.cluster = -1;
+    temp_node.lifetime = data.second;
+    mem_nodes.push_back(temp_node);
+  }
+  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) -> bool {
+    return b.second >= a.first && a.second >= b.first;
+  };
+  // If the lifetimes of two nodes overlap, mark them as adjacent.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (overlap(mem_nodes[i].lifetime, mem_nodes[j].lifetime)) {
+        mem_nodes[i].adj.insert(mem_nodes[j].name);
+        mem_nodes[j].adj.insert(mem_nodes[i].name);
+      }
+    }
+  }
+
+  // Generate the memory reuse strategy in a greedy way:
+  // vars can be reused only if their lifetimes do not overlap.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (mem_nodes[i].cluster >= 0) continue;
+    int cluster_index = cluster.size();
+    mem_nodes[i].cluster = cluster_index;
+    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+    cluster.push_back(mem_nodes[i].name);
+    std::unordered_set<std::string> cluster_adj = mem_nodes[i].adj;
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (mem_nodes[j].cluster < 0 &&
+          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+        mem_nodes[j].cluster = cluster_index;
+        for (auto& n : mem_nodes[j].adj) {
+          cluster_adj.insert(n);
+        }
+      }
+    }
+  }
+  for (auto& name : cluster) {
+    LOG(INFO) << "cluster: " << name;
+  }
+}
+
+void MemoryOptimizePass::PerformReusePlan(
+    SSAGraph* graph,
+    const std::unordered_map<std::string, std::string>& reuse_table) {
+  for (auto& op_node : graph->StmtTopologicalOrder()) {
+    if (!op_node->IsStmt()) continue;
+    auto& stmt = op_node->AsStmt();
+    auto* op_info = stmt.mutable_op_info();
+    std::unordered_map<std::string, std::vector<std::string>> in_args,
+        out_args;
+    // Replace the op's inputs according to the reuse table.
+    for (auto argument : op_info->inputs()) {
+      for (const auto& x : argument.second) {
+        auto name = x;
+        if (reuse_table.count(x) && reuse_table.at(x) != x) {
+          name = reuse_table.at(x);
+        }
+        in_args[argument.first].push_back(name);
+        VLOG(4) << op_info->Type() << " input " << x << " -> " << name;
+      }
+    }
+
+    // Modify the graph: rename the input var nodes.
+    for (Node* input_node : op_node->inlinks) {
+      CHECK(input_node->IsArg()) << "The op node's inputs should be var node.";
+      std::string name = input_node->AsArg().name;
+      if (reuse_table.count(name) && reuse_table.at(name) != name) {
+        auto replace_name = reuse_table.at(name);
+        input_node->AsArg().name = replace_name;
+      }
+    }
+
+    // Replace the op's outputs according to the reuse table.
+    for (auto argument : op_info->outputs()) {
+      for (const auto& x : argument.second) {
+        auto name = x;
+        if (reuse_table.count(x) && reuse_table.at(x) != x) {
+          name = reuse_table.at(x);
+        }
+        out_args[argument.first].push_back(name);
+        VLOG(4) << op_info->Type() << " output " << x << " -> " << name;
+      }
+    }
+
+    // Modify the graph: rename the output var nodes.
+    for (Node* out_node : op_node->outlinks) {
+      CHECK(out_node->IsArg()) << "The op node's outputs should be var node.";
+      std::string name = out_node->AsArg().name;
+      if (reuse_table.count(name) && reuse_table.at(name) != name) {
+        auto replace_name = reuse_table.at(name);
+        out_node->AsArg().name = replace_name;
+      }
+    }
+
+    for (auto& arg : in_args) {
+      op_info->SetInput(arg.first, arg.second);
+    }
+    for (auto& arg : out_args) {
+      op_info->SetOutput(arg.first, arg.second);
+    }
+
+    auto original_selected_kernel = std::move(stmt.kernels().front());
+    auto updated_op_info = *stmt.mutable_op_info();
+    stmt.ResetOp(updated_op_info, graph->valid_places());
+    stmt.kernels().clear();
+    stmt.kernels().emplace_back(std::move(original_selected_kernel));
+    for (auto& kernel : stmt.kernels()) {
+      VLOG(4) << "kernel info: " << kernel->name();
+      stmt.op()->AttachKernel(kernel.get());
+    }
+    graph->CheckValid();
+  }
+}
+
+void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Memory optimization.
+  // We will perform the following operations:
+  // 1. Collect every var's lifetime, then classify the vars by device.
+  //    Only vars on the same device can be reused.
+  // 2. Make the reuse plan: vars can be reused if there is no overlap between
+  //    their lifetimes.
+  //    The final plan is a mapping table whose key is the original var name
+  //    and whose value is the name the var is mapped to.
+  // 3. Perform the reuse plan: replace every var's name in the model according
+  //    to the mapping table.
+  std::unordered_map<std::string, lifecycle_map_t> lifecycles;
+  CollectLifeCycleByDevice(&lifecycles, graph.get());
+  for (auto& ele : lifecycles) {
+    std::unordered_map<std::string, std::string> node2cluster;
+    MakeReusePlan(ele.second, &node2cluster);
+    PerformReusePlan(graph.get(), node2cluster);
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
+    .SetTargets({TARGET(kARM)});
diff --git a/lite/core/mir/memory_optimize_pass.h b/lite/core/mir/memory_optimize_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..874fb648cd05931175159bad43e7be38a7aee928
--- /dev/null
+++ b/lite/core/mir/memory_optimize_pass.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "lite/core/kernel.h"
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * MemoryOptimizePass will reuse the memory of vars whose lifetimes do not
+ * overlap, which reduces the peak memory usage of the runtime program.
+ */
+class MemoryOptimizePass : public ProgramPass {
+ public:
+  using lifecycle_t = std::pair<int, int>;
+  using lifecycle_map_t = std::unordered_map<std::string, lifecycle_t>;
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+ private:
+  void CollectLifeCycleByDevice(
+      std::unordered_map<std::string, lifecycle_map_t>* lifecycles, SSAGraph*);
+  void MakeReusePlan(
+      const lifecycle_map_t& lifecycles,
+      std::unordered_map<std::string, std::string>* node2cluster);
+  void PerformReusePlan(
+      SSAGraph* graph,
+      const std::unordered_map<std::string, std::string>& reuse_table);
+
+ private:
+  int max_lifecycle_{-1};
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 031ffded4508e972ee67dacb83475143995b60ec..4a0e95e26654dda58bd88828042b05bedddbc684 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -93,7 +93,7 @@ class Optimizer {
           "argument_type_display_pass",  //
 
           "runtime_context_assign_pass",
-          "graph_visualze"}});
+          "memory_optimize_pass"}});
   } else {
     RunPasses(passes);
   }
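
Note on the reuse strategy: MakeReusePlan is a greedy clustering over variable lifetime intervals. Two vars may share a buffer only when their [first_use, last_use] ranges, measured in topological op order, do not overlap; each cluster is named after its first member, and PerformReusePlan then renames every member to that cluster name. The standalone sketch below is not part of the patch: the Interval struct, BuildReusePlan helper, and the example lifetimes are illustrative assumptions, but the overlap test and the greedy merging mirror the logic of MakeReusePlan and can help when reasoning about what the pass will merge.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for the pass's per-var data: the var name plus the
// [first_use, last_use] range of op indices in topological order.
struct Interval {
  std::string name;
  std::pair<int, int> lifetime;
};

// Greedy clustering: seed a cluster with the first unassigned var, then pull
// in every later unassigned var whose lifetime overlaps none of the lifetimes
// already merged into the cluster. Returns var name -> reused (cluster) name.
std::unordered_map<std::string, std::string> BuildReusePlan(
    const std::vector<Interval>& vars) {
  auto overlap = [](const std::pair<int, int>& a,
                    const std::pair<int, int>& b) {
    return b.second >= a.first && a.second >= b.first;
  };
  std::unordered_map<std::string, std::string> plan;
  std::vector<bool> assigned(vars.size(), false);
  for (size_t i = 0; i < vars.size(); ++i) {
    if (assigned[i]) continue;
    assigned[i] = true;
    plan[vars[i].name] = vars[i].name;
    std::vector<std::pair<int, int>> taken = {vars[i].lifetime};
    for (size_t j = i + 1; j < vars.size(); ++j) {
      if (assigned[j]) continue;
      bool clashes = false;
      for (const auto& t : taken) {
        if (overlap(t, vars[j].lifetime)) {
          clashes = true;
          break;
        }
      }
      if (!clashes) {
        assigned[j] = true;
        plan[vars[j].name] = vars[i].name;  // j reuses i's buffer
        taken.push_back(vars[j].lifetime);
      }
    }
  }
  return plan;
}

int main() {
  // Made-up lifetimes: "b" dies before "d" is first written, so the two can
  // share one buffer; "a" and "c" overlap the others and keep their own.
  std::vector<Interval> vars = {
      {"a", {0, 5}}, {"b", {1, 2}}, {"c", {2, 4}}, {"d", {3, 5}}};
  for (const auto& kv : BuildReusePlan(vars)) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}

With the sample lifetimes above the sketch prints d -> b while a, b, and c map to themselves, which matches what the pass's overlap lambda and cluster-adjacency bookkeeping would decide for the same intervals.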