未验证 提交 92fa8f60 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] fix bug: feed_with_place should consider variable existence (#55756)

* fix bug: feed_with_place should consider variable existence

* fix

* fix build scope

* change method to set feed var name

* fix
上级 9dd85b6b
......@@ -368,11 +368,15 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
}
// add fetch with place op to program
auto *block = local_program.MutableBlock(0);
for (auto &in_t : x) {
auto name = in_t.name();
if (block->FindVarRecursive(name) == nullptr) {
continue;
}
auto place = in_t.place().GetType();
auto op_desc = local_program.MutableBlock(0)->PrependOp();
auto op_desc = block->PrependOp();
op_desc->SetType("feed_with_place");
op_desc->SetAttr("index", 0);
// TODO(phlrain) : using tensor dtype
......@@ -429,13 +433,17 @@ std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
auto local_program =
paddle::framework::ProgramDesc(*(backward_global_block->Program()));
// add feed kernel
auto *block = local_program.MutableBlock(0);
for (auto &out_grad_t : out_grad) {
auto name = out_grad_t.name();
if (block->FindVarRecursive(name) == nullptr) {
continue;
}
auto place = out_grad_t.place().GetType();
if (name == "@EMPTY@") {
continue;
}
auto op_desc = local_program.MutableBlock(0)->PrependOp();
auto op_desc = block->PrependOp();
op_desc->SetType("feed_with_place");
op_desc->SetAttr("index", 0);
// TODO(phlrain) : using tensor dtype
......
......@@ -38,10 +38,9 @@ void SetFeedVariable(Scope* scope,
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
if (FLAGS_enable_new_ir_in_executor) {
// shared data with input tensor
auto inner_var_name = var_name + "_" + std::to_string(index);
auto feed_ele = scope->Var(inner_var_name);
auto feed_ele = scope->Var(var_name);
if (!feed_ele->IsType<phi::DenseTensor>()) {
VLOG(3) << "Reset " << inner_var_name << " to phi::DenseTensor";
VLOG(3) << "Reset " << var_name << " to phi::DenseTensor";
feed_ele->Clear();
}
auto val = feed_ele->GetMutable<phi::DenseTensor>();
......
......@@ -1628,9 +1628,18 @@ std::string NewIRInterpreter::DebugValueInfo() {
<< "value -> var_name -> id -> variable*"
<< "\n";
for (auto kv : value_2_var_name_) {
PADDLE_ENFORCE((bool)kv.first,
platform::errors::PreconditionNotMet(
"vlaue(%s) should not be nullptr", kv.second));
PADDLE_ENFORCE(var_name_2_id_.count(kv.second) > 0,
platform::errors::PreconditionNotMet(
"var(%s) should exist in var_name_2_id_", kv.second));
auto* var = InnerScope()->FindVar(kv.second);
PADDLE_ENFORCE(var != nullptr,
platform::errors::PreconditionNotMet(
"var(%s) should exist in var_name_2_id_", kv.second));
os << kv.first.impl() << " -> " << kv.second << " -> "
<< var_name_2_id_.at(kv.second) << " -> "
<< InnerScope()->FindVar(kv.second) << "\n";
<< var_name_2_id_.at(kv.second) << " -> " << var << "\n";
}
return os.str();
}
......
......@@ -200,17 +200,17 @@ void HandleForSpecialOp(
auto value = op->result(0);
VLOG(6) << "link feed output to feed in variable" << inner_scope;
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
std::string name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
paddle::framework::Variable* var = inner_scope->FindVar(name);
PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", name));
auto feed_var_name = "feed_" + std::to_string(index);
value_2_var_name->emplace(value, feed_var_name);
variable_2_var_name->emplace(var, feed_var_name);
value_2_var_name->emplace(value, name);
variable_2_var_name->emplace(var, name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(feed_var_name, id);
var_name_2_id->emplace(name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
......@@ -226,6 +226,21 @@ void HandleForSpecialOp(
auto value = op->result(0);
value_2_var_name->emplace(value, var_name);
paddle::framework::Variable* var = inner_scope->FindVar(var_name);
PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", var_name));
variable_2_var_name->emplace(var, var_name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(var_name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
}
if (op_name == "builtin.combine") {
......
......@@ -475,6 +475,9 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
std::string legacy_output_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
<< legacy_output_name;
// return empty type if this arg is optional and not shown in OpDesc
if (!op_desc.HasOutput(legacy_output_name)) {
VLOG(10) << "[output translating]"
......@@ -491,11 +494,19 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
<< legacy_output_name << " " << legacy_output_vars.size() << " "
<< is_vector;
// Specially process TensorArray, this because we cannot distinguish it with
// Vector<DenseTensor> by other conditions but we cannot support it like
// Vector<DenseTensor>
if (legacy_output_vars.size() == 1) {
VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
legacy_output_vars[0]);
if (var->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
ir::Type translated_var_type =
......@@ -519,6 +530,10 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
auto& var_name = legacy_output_vars[0];
VarDesc* var = block->FindVarRecursive(var_name);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " " << var_name
<< " " << var->GetType();
......
......@@ -26,7 +26,7 @@ from .framework import convert_np_dtype_to_dtype_, _apply_pass
from . import core
from . import unique_name
from . import compiler
from . import set_flags
from . import set_flags, get_flags
from .trainer_factory import TrainerFactory
from .trainer_factory import FetchHandlerMonitor
import copy
......@@ -1071,6 +1071,12 @@ class Executor:
)
check_feed_shape_type(var, cur_feed)
idx = op.desc.attr('col')
new_ir_flag_name = 'FLAGS_enable_new_ir_in_executor'
if get_flags(new_ir_flag_name)[new_ir_flag_name]:
core.set_feed_variable(
scope, cur_feed, feed_target_name, idx
)
else:
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册