Unverified · Commit f5811a60 authored by Leo Chen, committed by GitHub

fix cinn_instruction_run inplace var not found problem (#51769)

Parent b0337433
...
@@ -310,6 +310,7 @@ inline void RunProgramAPI(
   auto input_names = details::GetTensorsName(x);
   auto output_names = details::GetTensorsName(out);
+  auto param_names = details::GetTensorsName(params);
   auto dout_names = details::GetTensorsName(dout);
   if (VLOG_IS_ON(6)) {
@@ -319,6 +320,11 @@ inline void RunProgramAPI(
       s << name << " ";
     }
     s << std::endl;
+    s << "param_names: ";
+    for (auto name : param_names) {
+      s << name << " ";
+    }
+    s << std::endl;
     s << "output_names: ";
     for (auto name : output_names) {
       s << name << " ";
...
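
For context, the hunk above extends RunProgramAPI's VLOG(6) debug dump with a param_names group alongside the existing input/output/dout groups. Below is a minimal standalone sketch of the same name-dump pattern; the join_names helper and the sample names are illustrative, not from the commit.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Build one labeled line of space-separated names, as the VLOG(6) dump does.
std::string join_names(const std::string& label,
                       const std::vector<std::string>& names) {
  std::ostringstream s;
  s << label << ": ";
  for (const auto& name : names) {
    s << name << " ";
  }
  s << std::endl;
  return s.str();
}

int main() {
  // param_names is the group this commit adds to the existing dump.
  std::cout << join_names("input_names", {"x0", "x1"})
            << join_names("param_names", {"w", "b"})
            << join_names("output_names", {"out"});
  return 0;
}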
...
@@ -243,6 +243,8 @@ std::unordered_set<std::string> CinnLaunchContext::ExtractInternalVarNames(
       input_var_names.begin(), input_var_names.end(), exclude_names_fn);
   std::for_each(
       output_var_names.begin(), output_var_names.end(), exclude_names_fn);
+  VLOG(1) << "Internal var list: "
+          << string::join_strings(remain_var_names, ", ");
   return remain_var_names;
 }
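
string::join_strings is a Paddle-internal helper; assuming it concatenates container elements with the given separator, a minimal analog of the added VLOG(1) line looks like the sketch below.

#include <iostream>
#include <set>
#include <sstream>
#include <string>

// Hypothetical stand-in for Paddle's string::join_strings.
template <typename Container>
std::string join_strings(const Container& items, const std::string& sep) {
  std::ostringstream out;
  bool first = true;
  for (const auto& item : items) {
    if (!first) out << sep;
    out << item;
    first = false;
  }
  return out.str();
}

int main() {
  std::set<std::string> remain_var_names{"tmp_0", "tmp_1", "tmp_2"};
  // Mirrors the added log line: "Internal var list: tmp_0, tmp_1, tmp_2"
  std::cout << "Internal var list: " << join_strings(remain_var_names, ", ")
            << std::endl;
  return 0;
}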
...
@@ -279,19 +281,23 @@ void CinnLaunchContext::CheckTensorEquivalent(
 void CinnLaunchContext::InitializeArguments() {
   for (auto&& arg : cinn_argument_names_) {
     auto cinn_buffer = std::make_unique<cinn_buffer_t>();
-    auto cinn_tensor = GetCinnTensorOfVar(cinn2paddle_varmap_.at(arg));
+    auto paddle_varname = cinn2paddle_varmap_.at(arg);
+    auto cinn_tensor = GetCinnTensorOfVar(paddle_varname);
     // assign dimensions with corresponding compiled tensor
     cinn_buffer->resize(cinn_tensor->shape().data().data(),
                         cinn_tensor->shape().data().size());
     cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type());
     VLOG(4) << string::Sprintf(
-        "Append an argument:name(%s),dims(%s),type(%s)",
+        "Append an argument:name(%s),paddle_name(%s), "
+        "dims(%s),type(%s),tensor(%p)",
         arg,
+        paddle_varname,
         framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
-        cinn_tensor->type());
+        cinn_tensor->type(),
+        cinn_tensor.get());
     name2argument_.emplace(arg, cinn_buffer.get());
-    auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg);
-    paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get());
+    paddle2argument_.emplace(paddle_varname, cinn_buffer.get());
     hold_buffers_.emplace_back(std::move(cinn_buffer));
   }
   VLOG(4) << "Total argument size:" << name2argument_.size();
...
@@ -480,14 +486,37 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
           << "; cached_scope_: " << cached_scope_;
   VLOG(1) << "Internal var list: "
           << string::join_strings(internal_var_names_, ", ");
   for (auto&& var_name : internal_var_names_) {
     auto* var = scope->FindVar(var_name);
     if (var != nullptr) {
       continue;
     }
+    VLOG(4) << "Create Variable " << var_name << " locally";
     framework::InitializeVariable(scope->Var(var_name),
                                   framework::proto::VarType::LOD_TENSOR);
   }
+  // Actually, cinn_instruction will not use the var named
+  // var_name + InplaceOutSuffix in the paddle scope; it uses the var with
+  // the original name. That means vars 'a' and 'a@InplaceOut' in the cinn
+  // scope both link to var 'a' in the paddle scope.
+  // So why create 'a@InplaceOut' here? To let some paddle functions, for
+  // example infer_shape, visit all inputs and outputs of the
+  // cinn_instruction_run op. It should be refined.
+  for (auto&& var_name : inplace_var_names_) {
+    auto name = var_name + InplaceOutSuffix;
+    auto* var = scope->FindVar(name);
+    if (var != nullptr) {
+      continue;
+    }
+    VLOG(4) << "Create Variable " << name << " locally";
+    framework::InitializeVariable(scope->Var(name),
+                                  framework::proto::VarType::LOD_TENSOR);
+  }
   if (!interpreter_core_) {
     framework::interpreter::ExecutionConfig execution_config;
     execution_config.create_local_scope = false;
...
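
The second loop in the hunk above is the heart of the fix: for every inplace variable it guarantees a scope entry named var_name + InplaceOutSuffix, so framework passes such as infer_shape can resolve every output name of the cinn_instruction_run op. Below is a toy sketch of that guarantee, assuming the suffix spells '@InplaceOut' as the commit's own comment suggests; ToyScope is a hypothetical stand-in for framework::Scope.

#include <iostream>
#include <map>
#include <set>
#include <string>

struct Var {};

class ToyScope {
 public:
  // Return the var if present, nullptr otherwise (mirrors Scope::FindVar).
  Var* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : &it->second;
  }
  // Create-or-get by name (mirrors Scope::Var).
  Var* MakeVar(const std::string& name) { return &vars_[name]; }

 private:
  std::map<std::string, Var> vars_;
};

int main() {
  const std::string kInplaceOutSuffix = "@InplaceOut";  // per the commit comment
  std::set<std::string> inplace_var_names{"a", "b"};
  ToyScope scope;
  scope.MakeVar("a");
  scope.MakeVar("b");

  for (const auto& var_name : inplace_var_names) {
    auto name = var_name + kInplaceOutSuffix;
    if (scope.FindVar(name) != nullptr) continue;  // already present
    std::cout << "Create Variable " << name << " locally\n";
    scope.MakeVar(name);  // alias entry so lookups by this name succeed
  }
  return 0;
}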