diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index 3b1a03d364786954f13b73519a3491faec8ca45d..f3c32bc8292f15ca87fbd049ef7cc5868e1caee5 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -62,15 +62,6 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type, cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto v_places = graph->valid_places(); - for (auto it = v_places.begin(); it != v_places.end();) { - if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && - it->target != TARGET(kX86)) { - it = v_places.erase(it); - } else { - ++it; - } - } - // create kernels auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; @@ -157,15 +148,6 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type, cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto v_places = graph->valid_places(); - for (auto it = v_places.begin(); it != v_places.end();) { - if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && - it->target != TARGET(kX86)) { - it = v_places.erase(it); - } else { - ++it; - } - } - // create kernels auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc index 5c5dc3204b8728e8b30661fae21b056db6960179..a378e2ed642f3da940636c1bcf8637e8bf0d8f88 100644 --- a/lite/core/mir/subgraph/subgraph_pass.cc +++ b/lite/core/mir/subgraph/subgraph_pass.cc @@ -84,6 +84,20 @@ void RKNPUSubgraphPass::Apply(const std::unique_ptr& graph) { } void MLUSubgraphPass::Apply(const std::unique_ptr& graph) { +#ifdef LITE_WITH_MLU + // remove invalid places, since only support X86, host, MLU + auto v_places = graph->valid_places(); + for (auto it = v_places.begin(); it != v_places.end();) { + if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && + it->target != TARGET(kX86)) { + it = v_places.erase(it); + } else { + ++it; + } + } + graph->SetValidPlaces(v_places); +#endif + std::unordered_set supported_lists; #define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type); #include "lite/kernels/mlu/bridges/paddle_use_bridges.h"