diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index bb079937329c2524d976f0700cfd44a702f46a40..814e766ffce7a157c5882c865514351264dfc0a1 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -368,11 +368,15 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
   }
 
   // add fetch with place op to program
+  auto *block = local_program.MutableBlock(0);
   for (auto &in_t : x) {
     auto name = in_t.name();
+    if (block->FindVarRecursive(name) == nullptr) {
+      continue;
+    }
     auto place = in_t.place().GetType();
 
-    auto op_desc = local_program.MutableBlock(0)->PrependOp();
+    auto op_desc = block->PrependOp();
     op_desc->SetType("feed_with_place");
     op_desc->SetAttr("index", 0);
     // TODO(phlrain) : using tensor dtype
@@ -429,13 +433,17 @@ std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
   auto local_program =
       paddle::framework::ProgramDesc(*(backward_global_block->Program()));
   // add feed kernel
+  auto *block = local_program.MutableBlock(0);
   for (auto &out_grad_t : out_grad) {
     auto name = out_grad_t.name();
+    if (block->FindVarRecursive(name) == nullptr) {
+      continue;
+    }
     auto place = out_grad_t.place().GetType();
     if (name == "@EMPTY@") {
       continue;
     }
-    auto op_desc = local_program.MutableBlock(0)->PrependOp();
+    auto op_desc = block->PrependOp();
     op_desc->SetType("feed_with_place");
     op_desc->SetAttr("index", 0);
     // TODO(phlrain) : using tensor dtype
diff --git a/paddle/fluid/framework/feed_fetch_method.cc b/paddle/fluid/framework/feed_fetch_method.cc
index c089e4f0c13f95200ce5c25d170076a0557016dd..78e4abbd61e2814ce1810d6ffed3b44e7160208f 100644
--- a/paddle/fluid/framework/feed_fetch_method.cc
+++ b/paddle/fluid/framework/feed_fetch_method.cc
@@ -38,10 +38,9 @@ void SetFeedVariable(Scope* scope,
   VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
   if (FLAGS_enable_new_ir_in_executor) {
     // shared data with input tensor
-    auto inner_var_name = var_name + "_" + std::to_string(index);
-    auto feed_ele = scope->Var(inner_var_name);
+    auto feed_ele = scope->Var(var_name);
     if (!feed_ele->IsType<phi::DenseTensor>()) {
-      VLOG(3) << "Reset " << inner_var_name << " to phi::DenseTensor";
+      VLOG(3) << "Reset " << var_name << " to phi::DenseTensor";
       feed_ele->Clear();
     }
     auto val = feed_ele->GetMutable<phi::DenseTensor>();
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index df94728559a79bbf1d3fd9129eb6b8c7160b1447..af13dba92f0ec04a4af68e56600ea4662307f886 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -1628,9 +1628,18 @@ std::string NewIRInterpreter::DebugValueInfo() {
      << "value -> var_name -> id -> variable*"
      << "\n";
   for (auto kv : value_2_var_name_) {
+    PADDLE_ENFORCE((bool)kv.first,
+                   platform::errors::PreconditionNotMet(
+                       "value(%s) should not be nullptr", kv.second));
+    PADDLE_ENFORCE(var_name_2_id_.count(kv.second) > 0,
+                   platform::errors::PreconditionNotMet(
+                       "var(%s) should exist in var_name_2_id_", kv.second));
+    auto* var = InnerScope()->FindVar(kv.second);
+    PADDLE_ENFORCE(var != nullptr,
+                   platform::errors::PreconditionNotMet(
+                       "var(%s) should exist in the inner scope", kv.second));
     os << kv.first.impl() << " -> " << kv.second << " -> "
-       << var_name_2_id_.at(kv.second) << " -> "
-       << InnerScope()->FindVar(kv.second) << "\n";
+       << var_name_2_id_.at(kv.second) << " -> " << var << "\n";
   }
   return os.str();
 }
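Note: the feed_fetch_method.cc hunk above makes SetFeedVariable store the fed tensor under the feed target name itself (e.g. 'x') instead of the legacy var_name + "_" + index naming when FLAGS_enable_new_ir_in_executor is on. A minimal Python sketch of what that looks like from user code follows; the program, variable names, and the flag toggle are illustrative and not part of this patch.

    import numpy as np
    import paddle

    paddle.enable_static()
    # Opt into the new IR executor path exercised by this patch (assumed toggle).
    paddle.set_flags({'FLAGS_enable_new_ir_in_executor': True})

    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    y = x + 1.0

    exe = paddle.static.Executor(paddle.CPUPlace())
    (out,) = exe.run(
        paddle.static.default_main_program(),
        feed={'x': np.ones((2, 3), dtype='float32')},
        fetch_list=[y],
    )

    # With the change above, the fed tensor is kept in the scope under 'x'
    # rather than 'x_0', so it can be looked up by its original name.
    print(paddle.static.global_scope().find_var('x') is not None)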
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index d78a1291b1543045f1f2ca1f92523039a4be7bed..5dd27a04ad7cf03750935a4d7f923086f484a742 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -200,17 +200,17 @@ void HandleForSpecialOp(
     auto value = op->result(0);
     VLOG(6) << "link feed output to feed in variable" << inner_scope;
 
-    int index =
-        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
     std::string name =
         op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
     paddle::framework::Variable* var = inner_scope->FindVar(name);
+    PADDLE_ENFORCE(var,
+                   paddle::platform::errors::InvalidArgument(
+                       "The variable %s should exist", name));
 
-    auto feed_var_name = "feed_" + std::to_string(index);
-    value_2_var_name->emplace(value, feed_var_name);
-    variable_2_var_name->emplace(var, feed_var_name);
+    value_2_var_name->emplace(value, name);
+    variable_2_var_name->emplace(var, name);
     auto id = var_name_2_id->size();
-    var_name_2_id->emplace(feed_var_name, id);
+    var_name_2_id->emplace(name, id);
     variable_list->push_back(var);
     PADDLE_ENFORCE_EQ(
         variable_list->size(),
@@ -226,6 +226,21 @@ void HandleForSpecialOp(
     auto value = op->result(0);
 
     value_2_var_name->emplace(value, var_name);
+
+    paddle::framework::Variable* var = inner_scope->FindVar(var_name);
+    PADDLE_ENFORCE(var,
+                   paddle::platform::errors::InvalidArgument(
+                       "The variable %s should exist", var_name));
+
+    variable_2_var_name->emplace(var, var_name);
+    auto id = var_name_2_id->size();
+    var_name_2_id->emplace(var_name, id);
+    variable_list->push_back(var);
+    PADDLE_ENFORCE_EQ(
+        variable_list->size(),
+        var_name_2_id->size(),
+        paddle::platform::errors::InvalidArgument(
+            "The size of variable_list and var_name_2_id map should be equal"));
   }
 
   if (op_name == "builtin.combine") {
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 5b03badf25f9774f7baac3545e4b6d2568ee8cb6..f454811b08ff39573acfd56072b4a0073793c159 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -475,6 +475,9 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
     std::string legacy_output_name =
         op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
 
+    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
+             << legacy_output_name;
+
     // return empty type if this arg is optional and not shown in OpDesc
     if (!op_desc.HasOutput(legacy_output_name)) {
       VLOG(10) << "[output translating]"
@@ -491,11 +494,19 @@
     const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
     bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
 
+    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
+             << legacy_output_name << " " << legacy_output_vars.size() << " "
+             << is_vector;
+
     // Specially process TensorArray, this because we cannot distinguish it with
     // Vector by other conditions but we cannot support it like
     // Vector
     if (legacy_output_vars.size() == 1) {
       VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]);
+      IR_ENFORCE(var != nullptr,
+                 "[op:%s] Output %s should not be null",
+                 op_desc.Type(),
+                 legacy_output_vars[0]);
       if (var->GetType() ==
           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
         ir::Type translated_var_type =
@@ -519,6 +530,10 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
 
       auto& var_name = legacy_output_vars[0];
       VarDesc* var = block->FindVarRecursive(var_name);
+      IR_ENFORCE(var != nullptr,
+                 "[op:%s] Output %s should not be null",
+                 op_desc.Type(),
+                 var_name);
       VLOG(10) << "[output translating]"
                << "[" << op_desc.Type() << "]" << info.name << " " << var_name
                << " " << var->GetType();
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index a9afe7f5c8d0d0c3e0dcffd319f754f0301606e1..69376d22f98c9d59ad820ed60adfe51348e0f472 100755
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -26,7 +26,7 @@ from .framework import convert_np_dtype_to_dtype_, _apply_pass
 from . import core
 from . import unique_name
 from . import compiler
-from . import set_flags
+from . import set_flags, get_flags
 from .trainer_factory import TrainerFactory
 from .trainer_factory import FetchHandlerMonitor
 import copy
@@ -1071,7 +1071,13 @@ class Executor:
                         )
                     check_feed_shape_type(var, cur_feed)
                 idx = op.desc.attr('col')
-                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
+                new_ir_flag_name = 'FLAGS_enable_new_ir_in_executor'
+                if get_flags(new_ir_flag_name)[new_ir_flag_name]:
+                    core.set_feed_variable(
+                        scope, cur_feed, feed_target_name, idx
+                    )
+                else:
+                    core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
             else:
                 break
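For reference, a minimal sketch of the flag lookup used in the executor.py hunk above: get_flags returns a dict keyed by flag name, so the feed path can be switched at runtime. The variable names here are illustrative, not part of this patch.

    import paddle

    flag = 'FLAGS_enable_new_ir_in_executor'
    # paddle.get_flags accepts a list of flag names and returns {name: value}.
    use_new_ir = paddle.get_flags([flag])[flag]
    print('new IR executor enabled:', use_new_ir)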