Unverified Commit 624ffdf2, authored by Yuanle Liu, committed by GitHub

[Paddle inference] fix mixed precision (#47654)

Parent 0cbdcdda
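This commit makes three adjustments to ConvertToMixedPrecisionPass: the constructor now unconditionally adds assign, fill_constant, assign_value, eye, fill_any_like, and fill_constant_batch_size_like to the black list, so these ops always keep their original precision; the check that gates low-precision kernel selection now only requires a floating-point input (outputs are no longer inspected); and the dtype recorded for a variable shared across blocks is only written while it still holds its default value.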
@@ -110,7 +110,14 @@ class ConvertToMixedPrecisionPass {
        keep_io_types_(keep_io_types),
        black_list_(black_list),
        place_(paddle::CPUPlace()),
-       executor_(place_) {}
+       executor_(place_) {
+    black_list_.insert("assign");
+    black_list_.insert("fill_constant");
+    black_list_.insert("assign_value");
+    black_list_.insert("eye");
+    black_list_.insert("fill_any_like");
+    black_list_.insert("fill_constant_batch_size_like");
+  }
 
   void Run();
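Illustrative sketch (not part of the commit): the constructor change above merges a fixed set of ops into whatever black list the caller supplied, so those ops can never be converted to low precision by accident. A minimal standalone C++ analogue, with all names invented for illustration:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// Merge built-in always-keep-original-precision ops into the caller's list.
std::unordered_set<std::string> BuildBlackList(
    std::unordered_set<std::string> user_black_list) {
  for (const char* op : {"assign", "fill_constant", "assign_value", "eye",
                         "fill_any_like", "fill_constant_batch_size_like"}) {
    user_black_list.insert(op);  // no-op if the user already listed it
  }
  return user_black_list;
}

int main() {
  auto black_list = BuildBlackList({"softmax"});  // user blacklists softmax
  std::cout << black_list.count("fill_constant") << "\n";  // prints 1
}
```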
@@ -587,10 +594,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     bool support_precision =
         OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
-    // If the op has no input and output of float type, we will not choose the
+    // If the op has no input of float type, we will not choose the
     // low precision kernel.
     {
-      bool has_float_input_and_output{false};
+      bool has_float_input{false};
       for (auto* in_node : op_node->inputs) {
         if (!in_node->IsVar()) continue;
         auto* real_node = GetRealVarNode(block_idx, in_node);
@@ -598,22 +605,12 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
             real_node->Var()->GetDataType() == VarType::FP32 ||
             real_node->Var()->GetDataType() == VarType::FP64 ||
             real_node->Var()->GetDataType() == VarType::BF16) {
-          has_float_input_and_output = true;
+          has_float_input = true;
           break;
         }
       }
-      for (auto* out_node : op_node->outputs) {
-        if (!out_node->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out_node);
-        if (real_node->Var()->GetDataType() == VarType::FP16 ||
-            real_node->Var()->GetDataType() == VarType::FP32 ||
-            real_node->Var()->GetDataType() == VarType::FP64 ||
-            real_node->Var()->GetDataType() == VarType::BF16) {
-          has_float_input_and_output = true;
-          break;
-        }
-      }
-      if (!has_float_input_and_output) {
+      if (!has_float_input) {
        support_precision = false;
        VLOG(2) << " op doesn't has float input and output, just skip.";
       }
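Illustrative sketch (not part of the commit): after this hunk, eligibility for a low-precision kernel only asks whether at least one *input* is a floating-point tensor; outputs are no longer inspected. A plain-C++ analogue of that predicate, with a made-up DType enum standing in for VarType:

```cpp
#include <iostream>
#include <vector>

enum class DType { INT32, INT64, FP16, FP32, FP64, BF16 };

bool IsFloat(DType t) {
  return t == DType::FP16 || t == DType::FP32 || t == DType::FP64 ||
         t == DType::BF16;
}

// Mirrors the narrowed check: scan only the op's input dtypes.
bool HasFloatInput(const std::vector<DType>& input_dtypes) {
  for (DType t : input_dtypes) {
    if (IsFloat(t)) return true;  // one float input is enough
  }
  return false;
}

int main() {
  // An op with only integer inputs keeps its original-precision kernel.
  std::cout << HasFloatInput({DType::INT32, DType::INT64}) << "\n";  // 0
  std::cout << HasFloatInput({DType::INT64, DType::FP32}) << "\n";   // 1
}
```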
@@ -727,7 +724,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
           vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-              block_idx) {
+              block_idx &&
+          vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
+              VarType::Type()) {
         vars_in_multi_block_with_pair_.at(real_node->Name()).first =
             real_node->Var()->GetDataType();
       }
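Illustrative sketch (not part of the commit): the added condition compares the recorded dtype against VarType::Type(), a value-initialized enum, which reads as a "not yet recorded" sentinel, so only the first matching block gets to write the dtype and later visits cannot overwrite it. A standalone analogue with invented names:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

enum class DType { UNSET = 0, FP16, FP32 };

// name -> (recorded dtype, index of the block that defines the variable)
std::unordered_map<std::string, std::pair<DType, int>> vars_in_multi_block;

void MaybeRecordDtype(const std::string& name, int block_idx, DType dtype) {
  auto it = vars_in_multi_block.find(name);
  if (it != vars_in_multi_block.end() && it->second.second == block_idx &&
      it->second.first == DType::UNSET) {  // only fill a still-unset slot
    it->second.first = dtype;
  }
}

int main() {
  vars_in_multi_block["x"] = {DType::UNSET, 0};
  MaybeRecordDtype("x", 0, DType::FP16);  // records FP16
  MaybeRecordDtype("x", 0, DType::FP32);  // ignored: already recorded
  std::cout << (vars_in_multi_block["x"].first == DType::FP16) << "\n";  // 1
}
```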