未验证 提交 eaf90003 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix mixed precision inference (#49238)

上级 5d8284c3
......@@ -437,6 +437,20 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}
// When op_1 only supports a CPU kernel and op_2's input var is op_1's
// output var, op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;
vars_should_not_low_precision.insert(out_var_node->Var()->Name());
}
}
}
}
};
......@@ -449,6 +463,25 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_low_precision.count(
real_in_var_node->Var()->Name())) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册