diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 06ed5bb94db3dda5d1d79ab3160c955a0bf1892b..efcc7cef992e8c26b746357cdddb90a92f072aa3 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -40,26 +40,47 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); }; + // All input and output variables of these ops will not be reused. + std::unordered_set invalid_op_nodes = {"while", + "conditional_block", + "conditional_block_infer", + "merge_lod_tensor_infer", + "merge_lod_tensor", + "equal", + "lod_reset", + "concat", + "yolo_box", + "subgraph", + "feed", + "fetch"}; + + auto insert_invalid_op_nodes_for_specific_target = [&]( + std::unordered_set op_node_set, TargetType specific_target) { + std::unordered_set invalid_op_nodes_opencl = {"layout", "fc"}; + for (auto& op_node : graph->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + TargetType op_target_type = op_node->AsStmt().place().target; + if (op_target_type == specific_target && + specific_target == TARGET(kOpenCL)) { + invalid_op_nodes.insert(invalid_op_nodes_opencl.begin(), + invalid_op_nodes_opencl.end()); + break; + } + // else if // you can add more targets + } + }; + + VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size(); + insert_invalid_op_nodes_for_specific_target(invalid_op_nodes, + TARGET(kOpenCL)); + VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size(); + // Collect the invalid input and output variables that will not be reused. std::unordered_set invalid_var_names; for (auto& op_node : graph->StmtTopologicalOrder()) { if (!op_node->IsStmt()) continue; auto op_info = op_node->AsStmt().op_info(); auto op_type = op_info->Type(); - // The all of input and output variables of the Ops will not be reused. 
- std::unordered_set invalid_op_nodes = { - "while", - "conditional_block", - "conditional_block_infer", - "merge_lod_tensor_infer", - "merge_lod_tensor", - "equal", - "lod_reset", - "concat", - "yolo_box", - "subgraph", - "feed", - "fetch"}; auto invalid_op_node = invalid_op_nodes.find(op_type); if (invalid_op_node != invalid_op_nodes.end()) { for (auto in_var_node : op_node->inlinks) { @@ -282,5 +303,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) - .BindTargets({TARGET(kARM)}) - .ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU), TARGET(kBM)}); + .BindTargets({TARGET(kARM), TARGET(kOpenCL)}) + .ExcludeTargets({TARGET(kNPU), TARGET(kXPU), TARGET(kBM)});