未验证 提交 44f409cf 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Add variable name prefix for BuildScope (#55536)

* add interface

* add code

* add code

* add code

* add code

* fix bug

* fix bug

* add var prefix
上级 1f79fd47
......@@ -66,8 +66,8 @@ paddle::framework::Variable* CreateVar(
}
paddle::framework::Variable* var = nullptr;
VLOG(6) << "var_name_prefix is: " << var_name_prefix;
std::string name = "inner_var_" + std::to_string(variable_2_var_name->size());
std::string name = var_name_prefix + "_inner_var_" +
std::to_string(variable_2_var_name->size());
if (force_persisable || is_persisable) {
VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
......
......@@ -22,6 +22,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
......@@ -69,14 +70,23 @@ TEST(StandaloneExecutor, run) {
ProgramDesc prog_desc;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
VLOG(0) << "&test_core" << &test_core;
VLOG(0) << "&test_core.impl" << test_core.Impl();
VLOG(0) << "&test_core.impl.cast"
<< reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
test_core.BetaRun({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_2")
->Get<phi::DenseTensor>();
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_2")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
......@@ -107,11 +117,16 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.BetaRun({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_0")
->Get<phi::DenseTensor>();
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_0")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册