未验证 提交 43fcd01b 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix pd.feed bug in ir BuildScope (#55720)

* fix bug

* fix bug
上级 cbbd940e
......@@ -89,7 +89,7 @@ const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
}
const std::map<std::string, int>& OpYamlInfoParser::OutputName2Id() const {
return input_name2id_;
return output_name2id_;
}
bool OpYamlInfoParser::HasInplace(const std::string& out_name) const {
......
......@@ -199,9 +199,21 @@ void HandleForSpecialOp(
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
std::string name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
paddle::framework::Variable* var = inner_scope->FindVar(name);
auto feed_var_name = "feed_" + std::to_string(index);
value_2_var_name->emplace(value, feed_var_name);
variable_2_var_name->emplace(var, feed_var_name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(feed_var_name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
}
if (op_name == "pd.feed_with_place") {
......@@ -339,6 +351,9 @@ void HandleForInplaceOp(
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value value = op->result(i);
if (value.type().storage() == nullptr) {
continue;
}
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册