未验证 提交 d05ec503 编写于 作者: H hong 提交者: GitHub

[NewIR]fix new ir sgd op bug (#55982)

* fix new ir sgd op bug

* fix bug

* fix bug

* update

* revert code
上级 501a51fc
...@@ -386,6 +386,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( ...@@ -386,6 +386,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
op_desc->SetOutput("out", {name}); op_desc->SetOutput("out", {name});
} }
std::set<std::string> input_param_names;
for (auto &param : params) { for (auto &param : params) {
auto &name = param.name(); auto &name = param.name();
auto place = param.place().GetType(); auto place = param.place().GetType();
...@@ -398,6 +399,8 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( ...@@ -398,6 +399,8 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
op_desc->SetAttr("place", static_cast<int>(place)); op_desc->SetAttr("place", static_cast<int>(place));
op_desc->SetAttr("name", name); op_desc->SetAttr("name", name);
op_desc->SetOutput("out", {name}); op_desc->SetOutput("out", {name});
input_param_names.insert(name);
} }
std::set<std::string> set_parameter_names; std::set<std::string> set_parameter_names;
...@@ -419,6 +422,10 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( ...@@ -419,6 +422,10 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
continue; continue;
} }
if (input_param_names.count(name)) {
continue;
}
auto op_desc = local_program.MutableBlock(0)->AppendOp(); auto op_desc = local_program.MutableBlock(0)->AppendOp();
op_desc->SetType("shadow_output"); op_desc->SetType("shadow_output");
op_desc->SetAttr("name", name); op_desc->SetAttr("name", name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册