Unverified commit f5811a60, authored by Leo Chen, committed by GitHub

fix cinn_instruction_run inplace var not found problem (#51769)

Parent b0337433
@@ -310,6 +310,7 @@ inline void RunProgramAPI(
auto input_names = details::GetTensorsName(x);
auto output_names = details::GetTensorsName(out);
auto param_names = details::GetTensorsName(params);
auto dout_names = details::GetTensorsName(dout);
if (VLOG_IS_ON(6)) {
@@ -319,6 +320,11 @@ inline void RunProgramAPI(
s << name << " ";
}
s << std::endl;
s << "param_names: ";
for (auto name : param_names) {
s << name << " ";
}
s << std::endl;
s << "output_names: ";
for (auto name : output_names) {
s << name << " ";
......
@@ -243,6 +243,8 @@ std::unordered_set<std::string> CinnLaunchContext::ExtractInternalVarNames(
input_var_names.begin(), input_var_names.end(), exclude_names_fn);
std::for_each(
output_var_names.begin(), output_var_names.end(), exclude_names_fn);
VLOG(1) << "Internal var list: "
<< string::join_strings(remain_var_names, ", ");
return remain_var_names;
}
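
The hunk above only adds logging, but the surrounding ExtractInternalVarNames logic is easy to miss from the truncated context: internal vars are whatever remains after erasing the graph's inputs and outputs from the full variable set. Below is a minimal, self-contained sketch of that set-difference idea with made-up variable names (all names and data here are hypothetical, not the Paddle/CINN API):

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Hypothetical data; in CinnLaunchContext these come from the compiled program.
  std::unordered_set<std::string> remain_var_names = {"x", "y", "tmp_0", "out"};
  std::vector<std::string> input_var_names = {"x", "y"};
  std::vector<std::string> output_var_names = {"out"};

  // Erase inputs and outputs; whatever is left is an internal (temporary) var.
  auto exclude_names_fn = [&remain_var_names](const std::string& name) {
    remain_var_names.erase(name);
  };
  std::for_each(input_var_names.begin(), input_var_names.end(), exclude_names_fn);
  std::for_each(output_var_names.begin(), output_var_names.end(), exclude_names_fn);

  for (const auto& name : remain_var_names) {
    std::cout << "internal var: " << name << std::endl;  // prints "tmp_0"
  }
  return 0;
}
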
@@ -279,19 +281,23 @@ void CinnLaunchContext::CheckTensorEquivalent(
void CinnLaunchContext::InitializeArguments() {
for (auto&& arg : cinn_argument_names_) {
auto cinn_buffer = std::make_unique<cinn_buffer_t>();
auto cinn_tensor = GetCinnTensorOfVar(cinn2paddle_varmap_.at(arg));
auto paddle_varname = cinn2paddle_varmap_.at(arg);
auto cinn_tensor = GetCinnTensorOfVar(paddle_varname);
// assign dimensions with corresponding compiled tensor
cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size());
cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type());
VLOG(4) << string::Sprintf(
"Append an argument:name(%s),dims(%s),type(%s)",
"Append an argument:name(%s),paddle_name(%s), "
"dims(%s),type(%s),tensor(%p)",
arg,
paddle_varname,
framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
cinn_tensor->type());
cinn_tensor->type(),
cinn_tensor.get());
name2argument_.emplace(arg, cinn_buffer.get());
auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg);
paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get());
paddle2argument_.emplace(paddle_varname, cinn_buffer.get());
hold_buffers_.emplace_back(std::move(cinn_buffer));
}
VLOG(4) << "Total argument size:" << name2argument_.size();
@@ -480,14 +486,37 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
<< "; cached_scope_: " << cached_scope_;
VLOG(1) << "Internal var list: "
<< string::join_strings(internal_var_names_, ", ");
for (auto&& var_name : internal_var_names_) {
auto* var = scope->FindVar(var_name);
if (var != nullptr) {
continue;
}
VLOG(4) << "Create Variable " << var_name << " locally";
framework::InitializeVariable(scope->Var(var_name),
framework::proto::VarType::LOD_TENSOR);
}
// Actually, cinn_instruction_run will not use the var named
// var_name + InplaceOutSuffix in the paddle scope; it uses the var named
// var_name instead. That means both var 'a' and 'a@InplaceOut' in the cinn
// scope link to var 'a' in the paddle scope.
// So why create 'a@InplaceOut' here?
// It lets some paddle functions, e.g. the infer_shape function, visit all
// inputs and outputs of the cinn_instruction_run op.
// This should be refined.
for (auto&& var_name : inplace_var_names_) {
auto name = var_name + InplaceOutSuffix;
auto* var = scope->FindVar(name);
if (var != nullptr) {
continue;
}
VLOG(4) << "Create Variable " << name << " locally";
framework::InitializeVariable(scope->Var(name),
framework::proto::VarType::LOD_TENSOR);
}
if (!interpreter_core_) {
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
......
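
As the comment in the hunk explains, 'a' and 'a@InplaceOut' in the CINN scope both resolve to Paddle var 'a'; the '@InplaceOut' variable is created only as a placeholder so that functions such as infer_shape can see every declared output of cinn_instruction_run. A minimal sketch of that placeholder-creation pattern with a toy Scope class (the real code uses framework::Scope and framework::InitializeVariable; kInplaceOutSuffix below is a local stand-in for InplaceOutSuffix):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for framework::Scope, holding variables by name.
struct Scope {
  std::map<std::string, int> vars;  // value type irrelevant for the sketch
  bool Has(const std::string& name) const { return vars.count(name) > 0; }
  void Create(const std::string& name) { vars[name] = 0; }
};

int main() {
  const std::string kInplaceOutSuffix = "@InplaceOut";
  Scope scope;
  scope.Create("a");  // the real inplace var already exists in the scope

  std::vector<std::string> inplace_var_names = {"a"};
  for (const auto& var_name : inplace_var_names) {
    auto name = var_name + kInplaceOutSuffix;
    if (scope.Has(name)) continue;  // already created, skip
    // Create a placeholder var so ops that enumerate inputs/outputs
    // (e.g. infer_shape) can find it; it is never written to directly.
    scope.Create(name);
    std::cout << "Create Variable " << name << " locally" << std::endl;
  }
  return 0;
}
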