未验证 提交 9877fb88 编写于 作者: H hong 提交者: GitHub

[NewIR]filter new ir inplace var set parameter (#55979)

* filter new ir inplace var set parameter

* polish code

* fix conflict

* fix typo
上级 2d91a9bd
...@@ -1619,7 +1619,7 @@ void NewIRInterpreter::BuildInstruction() { ...@@ -1619,7 +1619,7 @@ void NewIRInterpreter::BuildInstruction() {
if (op_name == "builtin.combine" || op_name == "pd.feed" || if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" || op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.data" || op_name == "pd.shaddow_output") { op_name == "pd.data" || op_name == "pd.shadow_output") {
VLOG(6) << "skip process " << op_name; VLOG(6) << "skip process " << op_name;
continue; continue;
} }
......
...@@ -328,12 +328,19 @@ void HandleForSpecialOp( ...@@ -328,12 +328,19 @@ void HandleForSpecialOp(
// change opreand name to param_name // change opreand name to param_name
auto orig_name = value_2_var_name->at(value); auto orig_name = value_2_var_name->at(value);
PADDLE_ENFORCE_NE(
param_name,
orig_name,
phi::errors::PreconditionNotMet(
"SetParamer param name should not equal with var name"));
if (inner_scope->root()->FindVar(param_name) == nullptr) { if (inner_scope->root()->FindVar(param_name) == nullptr) {
const_cast<paddle::framework::Scope*>(inner_scope->root()) const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, param_name); ->Rename(orig_name, param_name);
VLOG(6) << "set_parameter rename var: " << orig_name << " -> " VLOG(6) << "set_parameter rename var: " << orig_name << " -> "
<< param_name; << param_name;
} }
RenameData(value, RenameData(value,
param_name, param_name,
orig_name, orig_name,
......
...@@ -50,7 +50,7 @@ std::unordered_map<std::string, phi::DataType> Str2PhiDataType = { ...@@ -50,7 +50,7 @@ std::unordered_map<std::string, phi::DataType> Str2PhiDataType = {
}; };
const std::unordered_set<std::string> UnchangeOutputOps = { const std::unordered_set<std::string> UnchangeOutputOps = {
"pd.feed_with_place", "pd.data",
"builtin.combine", "builtin.combine",
"builtin.slice", "builtin.slice",
"pd.feed", "pd.feed",
......
...@@ -164,7 +164,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { ...@@ -164,7 +164,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
auto& op_translator = OpTranslator::instance(); auto& op_translator = OpTranslator::instance();
for (auto op : block.AllOps()) { for (auto op : block.AllOps()) {
OpTranslateFn& fn = op_translator[op->Type()]; OpTranslateFn& fn = op_translator[op->Type()];
if (op->Type() == "shaddow_output") { if (op->Type() == "shadow_output") {
if (!param_map_.count(op->Input("x")[0])) { if (!param_map_.count(op->Input("x")[0])) {
continue; continue;
} }
...@@ -177,6 +177,14 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { ...@@ -177,6 +177,14 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
const auto& ops = block.AllOps(); const auto& ops = block.AllOps();
for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) { for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) {
if ((*op_desc)->Type() == "data") {
continue;
}
const auto& input_var_names = (*op_desc)->InputArgumentNames();
std::unordered_set<std::string> set_input_var_names(input_var_names.begin(),
input_var_names.end());
for (const auto& n : (*op_desc)->Outputs()) { for (const auto& n : (*op_desc)->Outputs()) {
const auto& output_var_names = n.second; const auto& output_var_names = n.second;
for (const auto& var_name : output_var_names) { for (const auto& var_name : output_var_names) {
...@@ -184,6 +192,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ...@@ -184,6 +192,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
parameter_name_mappings_.end()); parameter_name_mappings_.end());
need_set_parameter_op &= (parameter_visited_.count(var_name) == 0); need_set_parameter_op &= (parameter_visited_.count(var_name) == 0);
need_set_parameter_op &= (param_map_.count(var_name) != 0); need_set_parameter_op &= (param_map_.count(var_name) != 0);
need_set_parameter_op &= (!set_input_var_names.count(var_name));
if (need_set_parameter_op) { if (need_set_parameter_op) {
ir::OpResult defining_op_result = param_map_[var_name].value; ir::OpResult defining_op_result = param_map_[var_name].value;
if (!defining_op_result) { if (!defining_op_result) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册