提交 7f4a353f 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][PASS][OPENCL] Fix memory_resuse for opencl (#3077)

* Fix memory_resuse for opencl. test=develop

* remove useless code. test=develop
上级 03f8fc19
......@@ -40,15 +40,8 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
};
// Collect the invalid input and output variables that will not be reused.
std::unordered_set<std::string> invalid_var_names;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
// The all of input and output variables of the Ops will not be reused.
std::unordered_set<std::string> invalid_op_nodes = {
"while",
std::unordered_set<std::string> invalid_op_nodes = {"while",
"conditional_block",
"conditional_block_infer",
"merge_lod_tensor_infer",
......@@ -60,6 +53,34 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"subgraph",
"feed",
"fetch"};
auto insert_invalid_op_nodes_for_specific_target = [&](
std::unordered_set<std::string> op_node_set, TargetType specific_target) {
std::unordered_set<std::string> invalid_op_nodes_opencl = {"layout", "fc"};
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
TargetType op_target_type = op_node->AsStmt().place().target;
if (op_target_type == specific_target &&
specific_target == TARGET(kOpenCL)) {
invalid_op_nodes.insert(invalid_op_nodes_opencl.begin(),
invalid_op_nodes_opencl.end());
break;
}
// else if // you can add more targets
}
};
VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size();
insert_invalid_op_nodes_for_specific_target(invalid_op_nodes,
TARGET(kOpenCL));
VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size();
// Collect the invalid input and output variables that will not be reused.
std::unordered_set<std::string> invalid_var_names;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
auto invalid_op_node = invalid_op_nodes.find(op_type);
if (invalid_op_node != invalid_op_nodes.end()) {
for (auto in_var_node : op_node->inlinks) {
......@@ -282,5 +303,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
.BindTargets({TARGET(kARM)})
.ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU), TARGET(kBM)});
.BindTargets({TARGET(kARM), TARGET(kOpenCL)})
.ExcludeTargets({TARGET(kNPU), TARGET(kXPU), TARGET(kBM)});
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册