From 9bda10cdc830be52d8f2931551dcb0bcd1a54f3b Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Fri, 11 Nov 2022 11:35:35 +0800 Subject: [PATCH] [Inference] fix mixed precision (#47794) --- .../analysis/passes/convert_to_mixed_precision.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 1f7b83504d0..e9b188d78f1 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -111,7 +111,7 @@ class ConvertToMixedPrecisionPass { black_list_(black_list), place_(paddle::CPUPlace()), executor_(place_) { - black_list_.insert("assign"); + // black_list_.insert("assign"); black_list_.insert("fill_constant"); black_list_.insert("assign_value"); black_list_.insert("eye"); @@ -416,8 +416,7 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() { void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() { std::unordered_set all_var_names_set; - std::vector> block_var_names_set( - program_desc_->Size()); + std::vector> block_var_names_set(program_desc_->Size()); for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) { for (auto* op : program_desc_->Block(idx).AllOps()) { const auto& in_names = op->InputArgumentNames(); @@ -449,7 +448,7 @@ void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() { for (const auto& name : vars_in_multi_block) { vars_in_multi_block_with_pair_.emplace( - name, std::make_pair(VarType::FP32, idx)); + name, std::make_pair(VarType::Type(), idx)); } } } @@ -612,7 +611,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) { if (!has_float_input) { support_precision = false; - VLOG(2) << " op doesn't has float input and output, just skip."; + VLOG(2) << " op doesn't has float input, just skip."; } } VLOG(2) << "op type: " << op_type -- GitLab