未验证 提交 75ee1a88 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix bug (#56664)

上级 deee91d8
......@@ -300,10 +300,10 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
LOG(INFO) << "The number of ops run at low precision ["
......@@ -355,7 +355,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
if (op_node->Op()->HasAttr("in_dtype")) {
auto* var_node = op_node->inputs[0];
auto* real_var_node = real_vars_[var_node->Var()->Name()];
auto* real_var_node = real_vars_.count(var_node->Var()->Name())
? real_vars_.at(var_node->Var()->Name())
: var_node;
if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
op_node->Op()->SetAttr(
"in_dtype",
......@@ -455,7 +457,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision =
......@@ -464,7 +466,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
......@@ -554,7 +556,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_low_precision.count(
......@@ -573,7 +575,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_low_precision = false;
......@@ -742,7 +744,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
......@@ -761,7 +763,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
......@@ -877,7 +879,7 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
auto in_var_type = real_in_var_node->Var()->GetDataType();
......
......@@ -201,7 +201,8 @@ int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const {
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph);
int found_count = CleanUselessOp(graph) + CleanTwoCastOp(graph);
int found_count = CleanUselessOp(graph);
found_count += CleanTwoCastOp(graph);
AddStatis(found_count);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册