未验证 提交 624ffdf2 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle inference] fix mixed precision (#47654)

上级 0cbdcdda
......@@ -110,7 +110,14 @@ class ConvertToMixedPrecisionPass {
keep_io_types_(keep_io_types),
black_list_(black_list),
place_(paddle::CPUPlace()),
executor_(place_) {}
executor_(place_) {
black_list_.insert("assign");
black_list_.insert("fill_constant");
black_list_.insert("assign_value");
black_list_.insert("eye");
black_list_.insert("fill_any_like");
black_list_.insert("fill_constant_batch_size_like");
}
void Run();
......@@ -587,10 +594,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
bool support_precision =
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
// If the op has no input and output of float type, we will not choose the
// If the op has no input of float type, we will not choose the
// low precision kernel.
{
bool has_float_input_and_output{false};
bool has_float_input{false};
for (auto* in_node : op_node->inputs) {
if (!in_node->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, in_node);
......@@ -598,22 +605,12 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_input_and_output = true;
break;
}
}
for (auto* out_node : op_node->outputs) {
if (!out_node->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, out_node);
if (real_node->Var()->GetDataType() == VarType::FP16 ||
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_input_and_output = true;
has_float_input = true;
break;
}
}
if (!has_float_input_and_output) {
if (!has_float_input) {
support_precision = false;
VLOG(2) << " op doesn't has float input and output, just skip.";
}
......@@ -727,7 +724,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
block_idx) {
block_idx &&
vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
VarType::Type()) {
vars_in_multi_block_with_pair_.at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册