未验证 提交 ddce609e 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][PASS] Fix static kernel pick pass, if op is not int8, but kernel is...

[LITE][PASS] Fix static kernel pick pass, if op is not int8, but kernel is int8. test=develop (#2526)
上级 40137111
......@@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score;
scored.emplace_back(score, std::move(kernel));
......@@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places());
scored.clear();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......
......@@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass {
private:
// Score the kernel.
size_t KernelGrade(const lite::KernelBase& kernel,
size_t KernelGrade(const lite::mir::Node::Stmt& instruct,
const lite::KernelBase& kernel,
const std::vector<Place>& places) {
CHECK_GT(places.size(), 0) << "valid_places is empty.";
float final_score{-1.};
......@@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass {
// valid_places.size() as default.
// where i is the place's index in valid_places array.
// score: score is the weighted sum of target、percision and layout
for (int i = 0; i < place_size; ++i) {
for (size_t i = 0; i < place_size; ++i) {
const auto& place = places[i];
float weight = static_cast<float>(place_size - i) / place_size;
size_t score{};
......@@ -83,8 +84,12 @@ class StaticKernelPickPass : public mir::StmtPass {
(place.precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) ||
place.precision == PRECISION(kAny))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
// score skipped, if kernel is int8, but op is not int8
if (!(kernel.precision() == PRECISION(kInt8) &&
!instruct.op_info()->HasAttr("enable_int8"))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
}
}
VLOG(4) << "[score s2]:" << score;
if (kernel_pick_factors_.IsDataLayoutConsidered() &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册