From ddce609ec0a543d8ea4c97d523fd938f0772ea30 Mon Sep 17 00:00:00 2001 From: Yuan Shuai Date: Fri, 29 Nov 2019 15:58:24 +0800 Subject: [PATCH] [LITE][PASS] Fix static kernel pick pass, if op is not int8, but kernel is int8. test=develop (#2526) --- lite/core/mir/static_kernel_pick_pass.cc | 4 ++-- lite/core/mir/static_kernel_pick_pass.h | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 90aca56aec..c49e449709 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { << instruct.op_type(); VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(*kernel, graph->valid_places()); + float score = KernelGrade(instruct, *kernel, graph->valid_places()); VLOG(4) << "kernel->summary():" << kernel->summary() << " score:" << score; scored.emplace_back(score, std::move(kernel)); @@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { instruct.ResetOp(update_desc, graph->valid_places()); scored.clear(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(*kernel, graph->valid_places()); + float score = KernelGrade(instruct, *kernel, graph->valid_places()); scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h index 90be0ea54e..cd54e2654c 100644 --- a/lite/core/mir/static_kernel_pick_pass.h +++ b/lite/core/mir/static_kernel_pick_pass.h @@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass { private: // Score the kernel. - size_t KernelGrade(const lite::KernelBase& kernel, + size_t KernelGrade(const lite::mir::Node::Stmt& instruct, + const lite::KernelBase& kernel, const std::vector& places) { CHECK_GT(places.size(), 0) << "valid_places is empty."; float final_score{-1.}; @@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass { // valid_places.size() as default. // where i is the place's index in valid_places array. // score: score is the weighted sum of target、percision and layout - for (int i = 0; i < place_size; ++i) { + for (size_t i = 0; i < place_size; ++i) { const auto& place = places[i]; float weight = static_cast(place_size - i) / place_size; size_t score{}; @@ -83,8 +84,12 @@ class StaticKernelPickPass : public mir::StmtPass { (place.precision == kernel.precision() || kernel.precision() == PRECISION(kAny) || place.precision == PRECISION(kAny))) { - score += kMax / static_cast( - core::KernelPickFactor::Factor::PrecisionFirst); + // score skipped, if kernel is int8, but op is not int8 + if (!(kernel.precision() == PRECISION(kInt8) && + !instruct.op_info()->HasAttr("enable_int8"))) { + score += kMax / static_cast( + core::KernelPickFactor::Factor::PrecisionFirst); + } } VLOG(4) << "[score s2]:" << score; if (kernel_pick_factors_.IsDataLayoutConsidered() && -- GitLab