未验证 提交 ddce609e 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][PASS] Fix static kernel pick pass, if op is not int8, but kernel is...

[LITE][PASS] Fix static kernel pick pass, if op is not int8, but kernel is int8. test=develop (#2526)
上级 40137111
...@@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< instruct.op_type(); << instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) { for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places()); float score = KernelGrade(instruct, *kernel, graph->valid_places());
VLOG(4) << "kernel->summary():" << kernel->summary() VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score; << " score:" << score;
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
...@@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places()); instruct.ResetOp(update_desc, graph->valid_places());
scored.clear(); scored.clear();
for (auto&& kernel : instruct.kernels()) { for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places()); float score = KernelGrade(instruct, *kernel, graph->valid_places());
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
} }
std::sort(scored.begin(), scored.end(), KernelScoreCmp); std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......
...@@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass {
private: private:
// Score the kernel. // Score the kernel.
size_t KernelGrade(const lite::KernelBase& kernel, size_t KernelGrade(const lite::mir::Node::Stmt& instruct,
const lite::KernelBase& kernel,
const std::vector<Place>& places) { const std::vector<Place>& places) {
CHECK_GT(places.size(), 0) << "valid_places is empty."; CHECK_GT(places.size(), 0) << "valid_places is empty.";
float final_score{-1.}; float final_score{-1.};
...@@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass {
// valid_places.size() as default. // valid_places.size() as default.
// where i is the place's index in valid_places array. // where i is the place's index in valid_places array.
// score: score is the weighted sum of target、percision and layout // score: score is the weighted sum of target、percision and layout
for (int i = 0; i < place_size; ++i) { for (size_t i = 0; i < place_size; ++i) {
const auto& place = places[i]; const auto& place = places[i];
float weight = static_cast<float>(place_size - i) / place_size; float weight = static_cast<float>(place_size - i) / place_size;
size_t score{}; size_t score{};
...@@ -83,9 +84,13 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -83,9 +84,13 @@ class StaticKernelPickPass : public mir::StmtPass {
(place.precision == kernel.precision() || (place.precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) || kernel.precision() == PRECISION(kAny) ||
place.precision == PRECISION(kAny))) { place.precision == PRECISION(kAny))) {
// score skipped, if kernel is int8, but op is not int8
if (!(kernel.precision() == PRECISION(kInt8) &&
!instruct.op_info()->HasAttr("enable_int8"))) {
score += kMax / static_cast<int>( score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst); core::KernelPickFactor::Factor::PrecisionFirst);
} }
}
VLOG(4) << "[score s2]:" << score; VLOG(4) << "[score s2]:" << score;
if (kernel_pick_factors_.IsDataLayoutConsidered() && if (kernel_pick_factors_.IsDataLayoutConsidered() &&
(place.layout == kernel.layout() || (place.layout == kernel.layout() ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册