From 238a26996477f1cb47eeb012fd183711a1a78008 Mon Sep 17 00:00:00 2001 From: dingminghui Date: Thu, 23 Apr 2020 16:20:59 +0800 Subject: [PATCH] fix(place): remove invalid place in subgraph_pass --- lite/core/mir/mlu_postprocess_pass.cc | 18 ------------------ lite/core/mir/subgraph/subgraph_pass.cc | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index 3b1a03d364..f3c32bc829 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -62,15 +62,6 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type, cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto v_places = graph->valid_places(); - for (auto it = v_places.begin(); it != v_places.end();) { - if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && - it->target != TARGET(kX86)) { - it = v_places.erase(it); - } else { - ++it; - } - } - // create kernels auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; @@ -157,15 +148,6 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type, cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto v_places = graph->valid_places(); - for (auto it = v_places.begin(); it != v_places.end();) { - if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && - it->target != TARGET(kX86)) { - it = v_places.erase(it); - } else { - ++it; - } - } - // create kernels auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc index 5c5dc3204b..a378e2ed64 100644 --- a/lite/core/mir/subgraph/subgraph_pass.cc +++ b/lite/core/mir/subgraph/subgraph_pass.cc @@ -84,6 +84,20 @@ void RKNPUSubgraphPass::Apply(const std::unique_ptr& graph) { } void MLUSubgraphPass::Apply(const std::unique_ptr& graph) { +#ifdef LITE_WITH_MLU + // remove invalid places, since only support X86, host, MLU + auto v_places = graph->valid_places(); + for (auto it = v_places.begin(); it != v_places.end();) { + if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && + it->target != TARGET(kX86)) { + it = v_places.erase(it); + } else { + ++it; + } + } + graph->SetValidPlaces(v_places); +#endif + std::unordered_set supported_lists; #define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type); #include "lite/kernels/mlu/bridges/paddle_use_bridges.h" -- GitLab