From 11c7f57064bc81483fcf5af66dd4ad4e723563bd Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 22 Dec 2022 16:06:38 +0800 Subject: [PATCH] Fix mixed precision bug (#49239) * [Release2.4] Revert python link prs (#48573) * Revert "Fix mac link python (#48017)" This reverts commit 3fa7a736e32508e797616b6344d97814c37d3ff8. * Revert "[Cherry-pick] Fix python link error (#47811)" This reverts commit ff642c68d6681596844b5c1bae695a81d1baf3da. * Update config.go * fix mixed precision inference Co-authored-by: Chen Weihang --- .../framework/ir/auto_mixed_precision_pass.cc | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index bc03430198..44b41a8970 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -437,6 +437,20 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { vars_should_not_low_precision.insert(in_var_node->Var()->Name()); } } + + // When op_1 only supports a cpu kernel, if op_2's input var is op_1's + // output var, then op_2 should not run half. 
+ if (GetOpOriginalType(op_type) != "feed" && + !GpuKernelSupportPrecision(GetOpOriginalType(op_type), + phi::DataType::FLOAT32)) { + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + if (out_var_node->Var()->Persistable()) continue; + if (!VarNodeHasDtype(out_var_node)) continue; + + vars_should_not_low_precision.insert(out_var_node->Var()->Name()); + } + } } } }; @@ -449,6 +463,25 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { for (auto* op_node : nodes) { if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue; + for (auto* in_var_node : op_node->inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + if (!VarNodeHasDtype(in_var_node)) continue; + + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + if (real_in_var_node->Var()->Persistable()) continue; + + if (vars_should_not_low_precision.count( + real_in_var_node->Var()->Name())) { + op_run_low_precision_.erase(op_node->Op()->Type()); + precision_updated = true; + VLOG(4) << op_node->Op()->Type() + << " should not run at low precision."; + break; + } + } + + if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue; + for (auto* out_var_node : op_node->outputs) { CHECK_EQ(out_var_node->IsVar(), true); if (!VarNodeHasDtype(out_var_node)) continue; -- GitLab