From b601d81fd761735c913f7b3ff4b44866c6ec8b55 Mon Sep 17 00:00:00 2001 From: Yuan Shuai Date: Thu, 5 Mar 2020 10:30:48 +0800 Subject: [PATCH] [LITE][PASS][OPENCL] Fix memory_resuse for opencl (#3077) * Fix memory_resuse for opencl. test=develop * remove useless code. test=develop --- lite/core/mir/memory_optimize_pass.cc | 53 +++++++++++++++++++-------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 06ed5bb94d..efcc7cef99 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -40,26 +40,47 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); }; + // The all of input and output variables of the Ops will not be reused. + std::unordered_set invalid_op_nodes = {"while", + "conditional_block", + "conditional_block_infer", + "merge_lod_tensor_infer", + "merge_lod_tensor", + "equal", + "lod_reset", + "concat", + "yolo_box", + "subgraph", + "feed", + "fetch"}; + + auto insert_invalid_op_nodes_for_specific_target = [&]( + std::unordered_set op_node_set, TargetType specific_target) { + std::unordered_set invalid_op_nodes_opencl = {"layout", "fc"}; + for (auto& op_node : graph->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + TargetType op_target_type = op_node->AsStmt().place().target; + if (op_target_type == specific_target && + specific_target == TARGET(kOpenCL)) { + invalid_op_nodes.insert(invalid_op_nodes_opencl.begin(), + invalid_op_nodes_opencl.end()); + break; + } + // else if // you can add more targets + } + }; + + VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size(); + insert_invalid_op_nodes_for_specific_target(invalid_op_nodes, + TARGET(kOpenCL)); + VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size(); + // Collect the invalid input and output variables that will not be reused. std::unordered_set invalid_var_names; for (auto& op_node : graph->StmtTopologicalOrder()) { if (!op_node->IsStmt()) continue; auto op_info = op_node->AsStmt().op_info(); auto op_type = op_info->Type(); - // The all of input and output variables of the Ops will not be reused. - std::unordered_set invalid_op_nodes = { - "while", - "conditional_block", - "conditional_block_infer", - "merge_lod_tensor_infer", - "merge_lod_tensor", - "equal", - "lod_reset", - "concat", - "yolo_box", - "subgraph", - "feed", - "fetch"}; auto invalid_op_node = invalid_op_nodes.find(op_type); if (invalid_op_node != invalid_op_nodes.end()) { for (auto in_var_node : op_node->inlinks) { @@ -282,5 +303,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) - .BindTargets({TARGET(kARM)}) - .ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU), TARGET(kBM)}); + .BindTargets({TARGET(kARM), TARGET(kOpenCL)}) + .ExcludeTargets({TARGET(kNPU), TARGET(kXPU), TARGET(kBM)}); -- GitLab