未验证 提交 9bda10cd 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Inference] fix mixed precision (#47794)

上级 78c8c7de
...@@ -111,7 +111,7 @@ class ConvertToMixedPrecisionPass { ...@@ -111,7 +111,7 @@ class ConvertToMixedPrecisionPass {
black_list_(black_list), black_list_(black_list),
place_(paddle::CPUPlace()), place_(paddle::CPUPlace()),
executor_(place_) { executor_(place_) {
black_list_.insert("assign"); // black_list_.insert("assign");
black_list_.insert("fill_constant"); black_list_.insert("fill_constant");
black_list_.insert("assign_value"); black_list_.insert("assign_value");
black_list_.insert("eye"); black_list_.insert("eye");
...@@ -416,8 +416,7 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() { ...@@ -416,8 +416,7 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() { void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
std::unordered_set<std::string> all_var_names_set; std::unordered_set<std::string> all_var_names_set;
std::vector<std::unordered_set<std::string>> block_var_names_set( std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
program_desc_->Size());
for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) { for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
for (auto* op : program_desc_->Block(idx).AllOps()) { for (auto* op : program_desc_->Block(idx).AllOps()) {
const auto& in_names = op->InputArgumentNames(); const auto& in_names = op->InputArgumentNames();
...@@ -449,7 +448,7 @@ void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() { ...@@ -449,7 +448,7 @@ void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
for (const auto& name : vars_in_multi_block) { for (const auto& name : vars_in_multi_block) {
vars_in_multi_block_with_pair_.emplace( vars_in_multi_block_with_pair_.emplace(
name, std::make_pair(VarType::FP32, idx)); name, std::make_pair(VarType::Type(), idx));
} }
} }
} }
...@@ -612,7 +611,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) { ...@@ -612,7 +611,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
if (!has_float_input) { if (!has_float_input) {
support_precision = false; support_precision = false;
VLOG(2) << " op doesn't has float input and output, just skip."; VLOG(2) << " op doesn't has float input, just skip.";
} }
} }
VLOG(2) << "op type: " << op_type VLOG(2) << "op type: " << op_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册