From 5930a75da2efb37ec964da707cc04e20b86cd149 Mon Sep 17 00:00:00 2001 From: jiaopu Date: Fri, 3 Apr 2020 10:05:00 +0800 Subject: [PATCH] remove arm place as valic place in mlu_postprocess --- lite/api/cxx_api.cc | 2 -- lite/core/mir/mlu_postprocess_pass.cc | 25 +++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 43b7c51aa9..556a9e0af0 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -316,11 +316,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc, } } } -#ifndef LITE_WITH_MLU if (is_quantized_model) { inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); } -#endif Program program(desc, scope_, inner_places); diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index c69584b296..191f1543f3 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -60,8 +60,19 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type, CHECK(0) << "Unsupport cast type"; } cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); + + auto v_places = graph->valid_places(); + for (auto it = v_places.begin(); it != v_places.end();) { + if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && + it->target != TARGET(kX86)) { + it = v_places.erase(it); + } else { + ++it; + } + } + // create kernels - auto kernels = cast_op->CreateKernels(graph->valid_places()); + auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; bool is_found = false; for (auto& kernel : kernels) { @@ -150,8 +161,18 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type, cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); + auto v_places = graph->valid_places(); + for (auto it = v_places.begin(); it != v_places.end();) { + if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) && + it->target != TARGET(kX86)) { + it = v_places.erase(it); + } else { + ++it; + } + } + // create kernels - auto kernels = cast_op->CreateKernels(graph->valid_places()); + auto kernels = cast_op->CreateKernels(v_places); std::vector> selected_kernels; bool is_found = false; for (auto& kernel : kernels) { -- GitLab