提交 5930a75d 编写于 作者: J jiaopu 提交者: jackzhang235

remove arm place as valic place in mlu_postprocess

上级 a3a27beb
...@@ -316,11 +316,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -316,11 +316,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
} }
} }
} }
#ifndef LITE_WITH_MLU
if (is_quantized_model) { if (is_quantized_model) {
inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
} }
#endif
Program program(desc, scope_, inner_places); Program program(desc, scope_, inner_places);
......
...@@ -60,8 +60,19 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type, ...@@ -60,8 +60,19 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
CHECK(0) << "Unsupport cast type"; CHECK(0) << "Unsupport cast type";
} }
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto v_places = graph->valid_places();
for (auto it = v_places.begin(); it != v_places.end();) {
if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
it->target != TARGET(kX86)) {
it = v_places.erase(it);
} else {
++it;
}
}
// create kernels // create kernels
auto kernels = cast_op->CreateKernels(graph->valid_places()); auto kernels = cast_op->CreateKernels(v_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels; std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false; bool is_found = false;
for (auto& kernel : kernels) { for (auto& kernel : kernels) {
...@@ -150,8 +161,18 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type, ...@@ -150,8 +161,18 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto v_places = graph->valid_places();
for (auto it = v_places.begin(); it != v_places.end();) {
if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
it->target != TARGET(kX86)) {
it = v_places.erase(it);
} else {
++it;
}
}
// create kernels // create kernels
auto kernels = cast_op->CreateKernels(graph->valid_places()); auto kernels = cast_op->CreateKernels(v_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels; std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false; bool is_found = false;
for (auto& kernel : kernels) { for (auto& kernel : kernels) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册