diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 43b7c51aa9b282a0722335b61a4337004a99d66f..556a9e0af01854ff5c57a14dade72b81ed255964 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -316,11 +316,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
       }
     }
   }
-#ifndef LITE_WITH_MLU
   if (is_quantized_model) {
     inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
   }
-#endif
 
   Program program(desc, scope_, inner_places);
 
diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index c69584b2961c9a63b565536d33e36d8278f2c8ad..191f1543f3d8097ea9103a2df737c1b1ad7f7721 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -60,8 +60,19 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
     CHECK(0) << "Unsupport cast type";
   }
   cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
+
+  auto v_places = graph->valid_places();
+  for (auto it = v_places.begin(); it != v_places.end();) {
+    if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
+        it->target != TARGET(kX86)) {
+      it = v_places.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
   // create kernels
-  auto kernels = cast_op->CreateKernels(graph->valid_places());
+  auto kernels = cast_op->CreateKernels(v_places);
   std::vector<std::unique_ptr<KernelBase>> selected_kernels;
   bool is_found = false;
   for (auto& kernel : kernels) {
@@ -150,8 +161,18 @@
 
   cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
 
+  auto v_places = graph->valid_places();
+  for (auto it = v_places.begin(); it != v_places.end();) {
+    if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
+        it->target != TARGET(kX86)) {
+      it = v_places.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
   // create kernels
-  auto kernels = cast_op->CreateKernels(graph->valid_places());
+  auto kernels = cast_op->CreateKernels(v_places);
   std::vector<std::unique_ptr<KernelBase>> selected_kernels;
   bool is_found = false;
   for (auto& kernel : kernels) {