Unverified commit c80d01bf, authored by Wilber, committed by GitHub

fix select fp16 kernel. (#45882)

Parent 67d77846
@@ -677,6 +677,26 @@ void ConvertTensorDtype(
        OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
    VLOG(2) << " support low precision " << support_precision;
    // If the op does not have a float input, we will not choose the low
    // precision kernel.
    {
      bool has_float_input{false};
      for (auto in_node : op_node->inputs) {
        auto* real_node =
            GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
        if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
            real_node->Var()->GetDataType() == proto::VarType::FP32 ||
            real_node->Var()->GetDataType() == proto::VarType::FP64 ||
            real_node->Var()->GetDataType() == proto::VarType::BF16) {
          has_float_input = true;
          break;
        }
      }
      if (!has_float_input) {
        support_precision = false;
        VLOG(2) << " op doesn't have float input, just skip.";
      }
    }
    if (support_precision) {
      HandleSpecialOps(op_node->Op());
      ++num_low_precision;
...
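
For context, below is a minimal, self-contained sketch of the guard this commit adds. The types here are hypothetical stand-ins: the real pass walks Paddle's ir::Graph nodes via GetRealNode, and DataType below stands in for proto::VarType::Type.

#include <iostream>
#include <vector>

// Hypothetical stand-in for proto::VarType::Type.
enum class DataType { FP16, FP32, FP64, BF16, INT32, INT64, BOOL };

// Hypothetical stand-ins for the graph nodes the real pass walks.
struct VarNode {
  DataType dtype;
};

struct OpNode {
  std::vector<VarNode> inputs;
};

// Mirrors the loop added in ConvertTensorDtype: an op qualifies for a low
// precision kernel only if at least one input is a floating-point tensor.
bool HasFloatInput(const OpNode& op) {
  for (const auto& in : op.inputs) {
    if (in.dtype == DataType::FP16 || in.dtype == DataType::FP32 ||
        in.dtype == DataType::FP64 || in.dtype == DataType::BF16) {
      return true;
    }
  }
  return false;
}

int main() {
  OpNode int_only_op{{{DataType::INT64}, {DataType::INT32}}};
  OpNode mixed_op{{{DataType::INT64}, {DataType::FP32}}};
  std::cout << std::boolalpha
            << HasFloatInput(int_only_op) << "\n"  // false: keep original kernel
            << HasFloatInput(mixed_op) << "\n";    // true: may select FP16 kernel
}

With this guard, an op whose inputs are all non-float (for example, integer-only index or comparison ops) keeps its original kernel even when the op type itself is on the low-precision support list.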