未验证 提交 75ee1a88 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix bug (#56664)

上级 deee91d8
...@@ -300,10 +300,10 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const { ...@@ -300,10 +300,10 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
VLOG(4) << "SetVarPrecision done"; VLOG(4) << "SetVarPrecision done";
ConvertWeightsData(); ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done"; VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp(); InsertCastOp();
VLOG(4) << "InsertCastOp done"; VLOG(4) << "InsertCastOp done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
RestoreOpOriginType(); RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done"; VLOG(4) << "RestoreOpOriginType done";
LOG(INFO) << "The number of ops run at low precision [" LOG(INFO) << "The number of ops run at low precision ["
...@@ -355,7 +355,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const { ...@@ -355,7 +355,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
if (op_node->Op()->HasAttr("in_dtype")) { if (op_node->Op()->HasAttr("in_dtype")) {
auto* var_node = op_node->inputs[0]; auto* var_node = op_node->inputs[0];
auto* real_var_node = real_vars_[var_node->Var()->Name()]; auto* real_var_node = real_vars_.count(var_node->Var()->Name())
? real_vars_.at(var_node->Var()->Name())
: var_node;
if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) { if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
op_node->Op()->SetAttr( op_node->Op()->SetAttr(
"in_dtype", "in_dtype",
...@@ -455,7 +457,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { ...@@ -455,7 +457,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
// not run at low precision. // not run at low precision.
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
if (real_in_var_node->Var()->Persistable()) continue; if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision = support_low_precision =
...@@ -464,7 +466,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { ...@@ -464,7 +466,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
} }
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
if (real_out_var_node->Var()->Persistable()) continue; if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision = support_low_precision =
...@@ -554,7 +556,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -554,7 +556,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue; if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
if (real_in_var_node->Var()->Persistable()) continue; if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_low_precision.count( if (vars_should_not_low_precision.count(
...@@ -573,7 +575,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -573,7 +575,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue; if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
if (real_out_var_node->Var()->Persistable()) continue; if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_low_precision = false; bool not_run_low_precision = false;
...@@ -742,7 +744,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { ...@@ -742,7 +744,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
auto in_var_name = real_in_var_node->Var()->Name(); auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue; if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
...@@ -761,7 +763,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { ...@@ -761,7 +763,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
auto out_var_name = real_out_var_node->Var()->Name(); auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue; if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
...@@ -877,7 +879,7 @@ void AutoMixedPrecisionPass::InsertCastOp() const { ...@@ -877,7 +879,7 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
if (!VarNodeHasDtype(in_var_node)) continue; if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue; if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
auto in_var_type = real_in_var_node->Var()->GetDataType(); auto in_var_type = real_in_var_node->Var()->GetDataType();
......
...@@ -201,7 +201,8 @@ int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const { ...@@ -201,7 +201,8 @@ int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const {
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const { void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph); Init(name_scope_, graph);
int found_count = CleanUselessOp(graph) + CleanTwoCastOp(graph); int found_count = CleanUselessOp(graph);
found_count += CleanTwoCastOp(graph);
AddStatis(found_count); AddStatis(found_count);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册