Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1f0512be
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1f0512be
编写于
11月 22, 2021
作者:
L
Leo Chen
提交者:
GitHub
11月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[new feature] add local scope for interpretercore (#37379)
上级
964e20e0
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
217 addition
and
109 deletion
+217
-109
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+50
-13
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+11
-1
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+52
-27
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+4
-3
paddle/fluid/framework/new_executor/new_executor_defs.cc
paddle/fluid/framework/new_executor/new_executor_defs.cc
+14
-6
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+16
-5
paddle/fluid/framework/new_executor/standalone_executor.cc
paddle/fluid/framework/new_executor/standalone_executor.cc
+41
-28
paddle/fluid/framework/new_executor/standalone_executor.h
paddle/fluid/framework/new_executor/standalone_executor.h
+1
-2
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+2
-2
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+1
-1
paddle/fluid/operators/controlflow/fetch_v2_op.cc
paddle/fluid/operators/controlflow/fetch_v2_op.cc
+0
-8
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+23
-11
python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
...d/tests/unittests/interpreter/test_standalone_executor.py
+2
-2
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
1f0512be
...
@@ -22,6 +22,9 @@
...
@@ -22,6 +22,9 @@
PADDLE_DEFINE_EXPORTED_bool
(
new_executor_use_inplace
,
true
,
PADDLE_DEFINE_EXPORTED_bool
(
new_executor_use_inplace
,
true
,
"Use inplace in new executor"
);
"Use inplace in new executor"
);
PADDLE_DEFINE_EXPORTED_bool
(
new_executor_use_local_scope
,
true
,
"Use local_scope in new executor(especially used "
"in UT), can turn off for better performance"
);
DECLARE_bool
(
check_nan_inf
);
DECLARE_bool
(
check_nan_inf
);
DECLARE_bool
(
benchmark
);
DECLARE_bool
(
benchmark
);
...
@@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
...
@@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
create_local_scope_
=
FLAGS_new_executor_use_local_scope
;
if
(
FLAGS_new_executor_use_local_scope
)
{
auto
local_scope
=
&
global_scope
->
GetMutableScope
()
->
NewScope
();
local_scope
->
AddListener
(
global_scope
->
Listener
());
local_scope_
=
local_scope
;
}
VLOG
(
4
)
<<
"create_local_scope_ is "
<<
create_local_scope_
;
// prune
// prune
// optmize graph pass
// optmize graph pass
...
@@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() {
...
@@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() {
async_work_queue_
.
reset
(
nullptr
);
async_work_queue_
.
reset
(
nullptr
);
}
}
void
InterpreterCore
::
SetCopyProgram
(
std
::
shared_ptr
<
ProgramDesc
>
prog
)
{
copy_program_
=
prog
;
}
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
bool
is_build
=
is_build_
;
bool
is_build
=
is_build_
;
global_scope_
->
SetLocalScope
(
local_scope_
);
Prepare
(
feed_names
,
feed_tensors
,
is_build
);
Prepare
(
feed_names
,
feed_tensors
,
is_build
);
if
(
is_build
)
{
if
(
is_build
)
{
...
@@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run(
...
@@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
()
{
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
()
{
if
(
!
is_build_
)
{
if
(
!
is_build_
)
{
paddle
::
framework
::
interpreter
::
build_variable_scope
(
block_
,
global_scope_
);
if
(
create_local_scope_
&&
global_scope_
->
GetMutableLocalScope
()
!=
global_scope_
->
GetMutableScope
()
&&
global_scope_
->
GetMutableLocalScope
())
{
VLOG
(
4
)
<<
"Clear previous local scope before run"
;
VLOG
(
4
)
<<
global_scope_
->
GetMutableScope
()
<<
" "
<<
global_scope_
->
GetMutableLocalScope
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
)
->
Wait
();
// TODO(zhiqiu): clear the tensor holder of all vars in previous local
// scope?
}
global_scope_
->
SetLocalScope
(
local_scope_
);
paddle
::
framework
::
interpreter
::
build_variable_scope
(
block_
,
global_scope_
,
create_local_scope_
);
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
op_func_nodes
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
op_func_nodes
;
paddle
::
framework
::
interpreter
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
block_
,
&
op_func_nodes
,
global_scope_
);
place_
,
block_
,
&
op_func_nodes
,
global_scope_
,
create_local_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
(
&
op_func_nodes
);
Convert
(
&
op_func_nodes
);
}
else
{
}
else
{
ExecuteInstructionList
(
vec_instruction_
);
ExecuteInstructionList
(
vec_instruction_
);
}
}
...
@@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
...
@@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void
InterpreterCore
::
RunInstruction
(
const
Instruction
&
instr_node
)
{
void
InterpreterCore
::
RunInstruction
(
const
Instruction
&
instr_node
)
{
auto
*
op
=
instr_node
.
OpBase
();
auto
*
op
=
instr_node
.
OpBase
();
auto
place
=
instr_node
.
DeviceContext
().
GetPlace
();
auto
place
=
instr_node
.
DeviceContext
().
GetPlace
();
VLOG
(
4
)
<<
"Start run"
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
VLOG
(
4
)
<<
"Start run "
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
Scope
*
local_scope
=
create_local_scope_
?
global_scope_
->
GetMutableLocalScope
()
:
global_scope_
->
GetMutableScope
();
auto
op_with_kernel
=
dynamic_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
);
auto
op_with_kernel
=
dynamic_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
);
{
{
...
@@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
...
@@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
}
{
{
platform
::
RecordEvent
compute_event
(
"Compute"
);
platform
::
RecordEvent
compute_event
(
"Compute"
);
if
(
op_with_kernel
==
nullptr
)
if
(
op_with_kernel
==
nullptr
)
{
instr_node
.
OpBase
()
->
Run
(
*
global_scope_
->
GetScope
()
,
place_
);
instr_node
.
OpBase
()
->
Run
(
*
local_scope
,
place_
);
else
}
else
{
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
}
}
}
VLOG
(
4
)
<<
"End run"
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
VLOG
(
4
)
<<
"End run
"
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
/*For profiling/benchmark only*/
/*For profiling/benchmark only*/
if
(
FLAGS_benchmark
)
{
if
(
FLAGS_benchmark
)
{
...
@@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList(
...
@@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList(
}
}
}
}
auto
event_
id
=
main_thread_blocker_
.
WaitEvent
();
auto
event_
name
=
main_thread_blocker_
.
WaitEvent
();
VLOG
(
3
)
<<
"event_
id "
<<
event_id
;
VLOG
(
3
)
<<
"event_
name: "
<<
event_name
;
if
(
UNLIKELY
(
exception_holder_
.
IsCaught
()))
{
if
(
UNLIKELY
(
exception_holder_
.
IsCaught
()))
{
VLOG
(
4
)
<<
"Exception caught "
<<
exception_holder_
.
Type
();
VLOG
(
4
)
<<
"Exception caught "
<<
exception_holder_
.
Type
();
...
@@ -526,8 +560,9 @@ void InterpreterCore::Prepare(
...
@@ -526,8 +560,9 @@ void InterpreterCore::Prepare(
VLOG
(
4
)
<<
"Feed inputs"
;
VLOG
(
4
)
<<
"Feed inputs"
;
for
(
size_t
i
=
0
;
i
<
feed_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
feed_names
.
size
();
++
i
)
{
auto
*
feed_var
=
global_scope_
->
FindVar
(
feed_names
[
i
]);
auto
*
feed_var
=
global_scope_
->
FindVar
(
feed_names
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
feed_var
,
platform
::
errors
::
NotFound
(
PADDLE_ENFORCE_NOT_NULL
(
"feed_var shall not be nullptr."
));
feed_var
,
platform
::
errors
::
NotFound
(
"Variable %s should not be nullptr."
,
feed_names
[
i
]));
auto
feed_tensor
=
feed_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
feed_tensor
=
feed_var
->
GetMutable
<
framework
::
LoDTensor
>
();
feed_tensor
->
ShareDataWith
(
feed_tensors
[
i
]);
feed_tensor
->
ShareDataWith
(
feed_tensors
[
i
]);
...
@@ -536,11 +571,12 @@ void InterpreterCore::Prepare(
...
@@ -536,11 +571,12 @@ void InterpreterCore::Prepare(
};
};
if
(
!
is_build_
)
{
if
(
!
is_build_
)
{
paddle
::
framework
::
interpreter
::
build_variable_scope
(
block_
,
global_scope_
);
paddle
::
framework
::
interpreter
::
build_variable_scope
(
block_
,
global_scope_
,
create_local_scope_
);
FeedInput
();
FeedInput
();
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
op_func_nodes
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
op_func_nodes
;
paddle
::
framework
::
interpreter
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
block_
,
&
op_func_nodes
,
global_scope_
);
place_
,
block_
,
&
op_func_nodes
,
global_scope_
,
create_local_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
(
&
op_func_nodes
);
Convert
(
&
op_func_nodes
);
...
@@ -556,6 +592,7 @@ void InterpreterCore::Prepare(
...
@@ -556,6 +592,7 @@ void InterpreterCore::Prepare(
interpreter
::
CostInfo
InterpreterCore
::
DryRun
(
interpreter
::
CostInfo
InterpreterCore
::
DryRun
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
global_scope_
->
SetLocalScope
(
local_scope_
);
Prepare
(
feed_names
,
feed_tensors
,
true
);
Prepare
(
feed_names
,
feed_tensors
,
true
);
interpreter
::
CostInfo
cost_info
;
interpreter
::
CostInfo
cost_info
;
{
{
...
...
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
1f0512be
...
@@ -55,6 +55,8 @@ class InterpreterCore {
...
@@ -55,6 +55,8 @@ class InterpreterCore {
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
void
SetCopyProgram
(
std
::
shared_ptr
<
ProgramDesc
>
prog
);
private:
private:
void
Convert
(
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>*
op_func_nodes
);
void
Convert
(
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>*
op_func_nodes
);
...
@@ -85,7 +87,13 @@ class InterpreterCore {
...
@@ -85,7 +87,13 @@ class InterpreterCore {
bool
is_build_
;
bool
is_build_
;
const
platform
::
Place
&
place_
;
const
platform
::
Place
&
place_
;
const
BlockDesc
&
block_
;
// not owned
const
BlockDesc
&
block_
;
// not owned
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std
::
shared_ptr
<
ProgramDesc
>
copy_program_
{
nullptr
};
VariableScope
*
global_scope_
;
// not owned
VariableScope
*
global_scope_
;
// not owned
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
...
@@ -102,6 +110,8 @@ class InterpreterCore {
...
@@ -102,6 +110,8 @@ class InterpreterCore {
std
::
unique_ptr
<
InterpreterCoreGarbageCollector
>
gc_
;
std
::
unique_ptr
<
InterpreterCoreGarbageCollector
>
gc_
;
std
::
vector
<
paddle
::
platform
::
DeviceEvent
>
gc_event_
;
std
::
vector
<
paddle
::
platform
::
DeviceEvent
>
gc_event_
;
bool
create_local_scope_
{
true
};
Scope
*
local_scope_
{
nullptr
};
// not owned
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
1f0512be
...
@@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place,
...
@@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
}
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
,
bool
use_local_scope
)
{
VLOG
(
3
)
<<
"Creating Variables"
;
auto
inner_scope
=
var_scope
->
GetMutableScope
();
// NOTE(zhiqiu): if create_local_scope_ is true, the persistable is
// created in var_scope.scope_ , and other scope is created in local scope.
Scope
*
local_scope
=
use_local_scope
?
var_scope
->
GetMutableLocalScope
()
:
var_scope
->
GetMutableScope
();
for
(
auto
&
var_desc
:
block
.
AllVars
())
{
for
(
auto
&
var_desc
:
block
.
AllVars
())
{
auto
var_name
=
var_desc
->
Name
();
auto
var_name
=
var_desc
->
Name
();
if
(
var_name
==
framework
::
kEmptyVarName
)
{
if
(
var_name
==
framework
::
kEmptyVarName
)
{
continue
;
continue
;
}
}
if
(
var_desc
->
Persistable
())
{
auto
*
ptr
=
inner_scope
->
Var
(
var_name
);
if
(
nullptr
==
var_scope
->
FindVar
(
var_name
))
{
VLOG
(
3
)
<<
"Initialize Variable "
<<
var_name
;
var_scope
->
AddVar
(
var_desc
->
Name
(),
var_desc
);
InitializeVariable
(
ptr
,
var_desc
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var_name
<<
" global, which pointer is "
<<
ptr
<<
" type is "
<<
static_cast
<
int
>
(
var_desc
->
GetType
());
}
else
{
}
else
{
auto
*
var_desc_tmp
=
var_scope
->
VarDesc
(
var_name
);
auto
*
ptr
=
local_scope
->
Var
(
var_name
);
if
(
nullptr
==
var_desc_tmp
)
{
InitializeVariable
(
ptr
,
var_desc
->
GetType
());
VLOG
(
3
)
<<
"update var:"
<<
var_name
<<
" desc from nullptr into "
VLOG
(
3
)
<<
"Create Variable "
<<
var_name
<<
" locally, which pointer is "
<<
var_desc
;
<<
ptr
<<
"Variable Type "
var_scope
->
SetVarDesc
(
var_name
,
var_desc
);
<<
static_cast
<
int
>
(
var_desc
->
GetType
());
}
}
}
var_scope
->
SetVarDesc
(
var_name
,
var_desc
);
}
}
}
}
...
@@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base,
...
@@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base,
void
deal_operator_base
(
const
platform
::
Place
&
place
,
void
deal_operator_base
(
const
platform
::
Place
&
place
,
const
VariableScope
*
var_scope
,
const
VariableScope
*
var_scope
,
std
::
shared_ptr
<
OperatorBase
>
op_base
,
std
::
shared_ptr
<
OperatorBase
>
op_base
,
OpFuncNode
*
op_func_node
)
{
OpFuncNode
*
op_func_node
,
Scope
*
local_scope
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
place
);
auto
*
dev_ctx
=
pool
.
Get
(
place
);
// input, output is prepared. set the other attributes.
// input, output is prepared. set the other attributes.
op_func_node
->
operator_base_
=
op_base
;
op_func_node
->
operator_base_
=
op_base
;
op_func_node
->
type_
=
OpFuncType
::
kQueueSync
;
// alway Sync
op_func_node
->
type_
=
OpFuncType
::
kQueueSync
;
// alway Sync
op_func_node
->
kernel_func_
=
nullptr
;
op_func_node
->
kernel_func_
=
nullptr
;
op_base
->
Run
(
*
var_scope
->
GetScope
()
,
place
);
// Run without data transformer.
op_base
->
Run
(
*
local_scope
,
place
);
// Run without data transformer.
std
::
unordered_set
<
int
>
no_data_transform_index
;
std
::
unordered_set
<
int
>
no_data_transform_index
;
for
(
auto
&
it
:
op_func_node
->
input_index
)
{
for
(
auto
&
it
:
op_func_node
->
input_index
)
{
...
@@ -288,12 +300,21 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
...
@@ -288,12 +300,21 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
const
OpKernelType
&
kernel_type_for_var
,
const
OpKernelType
&
kernel_type_for_var
,
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
const
std
::
string
&
var_name
,
const
std
::
string
&
outer_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
outer_name
,
const
OpFuncNode
&
op_func_node
,
Variable
*
var
,
VariableScope
*
var_scope
)
{
const
OpFuncNode
&
op_func_node
,
Variable
*
var
,
VariableScope
*
var_scope
,
bool
use_local_scope
=
true
)
{
Scope
*
local_scope
=
use_local_scope
?
var_scope
->
GetMutableLocalScope
()
:
var_scope
->
GetMutableScope
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
std
::
string
new_var_name
=
std
::
string
new_var_name
=
var_name
+
"_copy_"
+
std
::
to_string
(
var_scope
->
VarSize
()
+
1
);
var_name
+
"_copy_"
+
std
::
to_string
(
var_scope
->
VarSize
()
+
1
);
var_scope
->
AddVar
(
new_var_name
,
nullptr
);
auto
*
ptr
=
local_scope
->
Var
(
new_var_name
);
InitializeVariable
(
ptr
,
static_cast
<
proto
::
VarType
::
Type
>
(
var
->
Type
()));
VLOG
(
3
)
<<
"Create Variable "
<<
var_name
<<
" locally, which pointer is "
<<
ptr
<<
"Variable Type "
<<
var
->
Type
();
var_scope
->
SetVarDesc
(
var_name
,
nullptr
);
VariableNameMap
copy_in_map
;
VariableNameMap
copy_in_map
;
copy_in_map
[
"X"
]
=
{
var_name
};
copy_in_map
[
"X"
]
=
{
var_name
};
...
@@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
...
@@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
VariableValueMap
*
ins_map_temp
,
VariableValueMap
*
ins_map_temp
,
VariableScope
*
var_scope
,
OpFuncNode
*
op_func_node
,
VariableScope
*
var_scope
,
OpFuncNode
*
op_func_node
,
std
::
vector
<
OpFuncNode
>*
copy_func_nodes
)
{
std
::
vector
<
OpFuncNode
>*
copy_func_nodes
,
bool
use_local_scope
=
true
)
{
auto
op_base
=
op_func_node
->
operator_base_
.
get
();
auto
op_base
=
op_func_node
->
operator_base_
.
get
();
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
"op_base is null, please pass a valid "
"op_base is null, please pass a valid "
...
@@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
...
@@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
std
::
string
new_var_name
;
std
::
string
new_var_name
;
OpFuncNode
copy_op_func_node
;
OpFuncNode
copy_op_func_node
;
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
apply_place_transform_for_var
(
apply_place_transform_for_var
(
kernel_type_for_var
,
kernel_type_for_var
,
expected_kernel_key
,
place
,
var_name
,
expected_kernel_key
,
place
,
var_name
,
var_name_item
.
first
,
*
op_func_node
,
var
,
var_scope
);
var_name_item
.
first
,
*
op_func_node
,
var
,
var_scope
,
use_local_scope
);
op_func_node
->
input_index
[
var_name_item
.
first
][
i
]
=
op_func_node
->
input_index
[
var_name_item
.
first
][
i
]
=
var_scope
->
VarId
(
new_var_name
);
var_scope
->
VarId
(
new_var_name
);
copy_func_nodes
->
emplace_back
(
copy_op_func_node
);
copy_func_nodes
->
emplace_back
(
copy_op_func_node
);
...
@@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
...
@@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
BlockDesc
&
block
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
,
bool
use_local_scope
)
{
Scope
*
local_scope
=
use_local_scope
?
var_scope
->
GetMutableLocalScope
()
:
var_scope
->
GetMutableScope
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops
;
// its elements will be moved to vec_func_list
ops
;
// its elements will be moved to vec_func_list
...
@@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place,
if
(
dynamic_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
)
==
nullptr
)
{
if
(
dynamic_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
)
==
nullptr
)
{
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base
(
place
,
var_scope
,
ops
[
i
],
&
op_func_node
);
deal_operator_base
(
place
,
var_scope
,
ops
[
i
],
&
op_func_node
,
local_scope
);
}
else
{
}
else
{
// construct RuntimeContext and analysis KernelType
// construct RuntimeContext and analysis KernelType
RuntimeContext
runtime_context
({},
{});
RuntimeContext
runtime_context
({},
{});
...
@@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
// apply_data_transform.
op_func_node
.
operator_base_
=
ops
[
i
];
op_func_node
.
operator_base_
=
ops
[
i
];
apply_data_transform
(
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
apply_data_transform
(
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
&
op_func_node
,
&
copy_op_to_insert
);
&
op_func_node
,
&
copy_op_to_insert
,
use_local_scope
);
for
(
auto
&
item
:
copy_op_to_insert
)
{
for
(
auto
&
item
:
copy_op_to_insert
)
{
vec_func_list
->
push_back
(
item
);
vec_func_list
->
push_back
(
item
);
}
}
...
@@ -631,16 +656,16 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
...
@@ -631,16 +656,16 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
}
}
void
update_var_min_rw_op
(
const
std
::
map
<
int
,
std
::
set
<
int
>>&
op2dependences
,
void
update_var_min_rw_op
(
const
std
::
map
<
int
,
std
::
set
<
int
>>&
op2dependences
,
std
::
map
<
int
,
std
::
list
<
int
>>
&
var2min_rw_op
,
std
::
map
<
int
,
std
::
list
<
int
>>
*
var2min_rw_op
,
int
cur_op
,
int
rw_var
)
{
int
cur_op
,
int
rw_var
)
{
// rw_var is inputs or outputs of cur_op
// rw_var is inputs or outputs of cur_op
// this function update the var2min_rw_op set .
// this function update the var2min_rw_op set .
if
(
var2min_rw_op
.
find
(
rw_var
)
==
var2min_rw_op
.
end
())
if
(
var2min_rw_op
->
find
(
rw_var
)
==
var2min_rw_op
->
end
())
var2min_rw_op
[
rw_var
]
=
std
::
list
<
int
>
();
(
*
var2min_rw_op
)
[
rw_var
]
=
std
::
list
<
int
>
();
for
(
auto
dep_op
:
op2dependences
.
at
(
cur_op
))
{
for
(
auto
dep_op
:
op2dependences
.
at
(
cur_op
))
{
var2min_rw_op
[
rw_var
].
remove
(
dep_op
);
(
*
var2min_rw_op
)
[
rw_var
].
remove
(
dep_op
);
}
}
var2min_rw_op
[
rw_var
].
push_back
(
cur_op
);
(
*
var2min_rw_op
)
[
rw_var
].
push_back
(
cur_op
);
}
}
std
::
map
<
int
,
std
::
list
<
int
>>
get_downstream_map
(
std
::
map
<
int
,
std
::
list
<
int
>>
get_downstream_map
(
...
@@ -702,7 +727,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
...
@@ -702,7 +727,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
for
(
auto
&
item
:
for
(
auto
&
item
:
vec_instruction
[
op_idx
].
Inputs
())
{
// for all inputs(read only)
vec_instruction
[
op_idx
].
Inputs
())
{
// for all inputs(read only)
for
(
auto
var
:
item
.
second
)
{
for
(
auto
var
:
item
.
second
)
{
update_var_min_rw_op
(
op2dependences
,
var2min_rw_op
,
op_idx
,
var
);
update_var_min_rw_op
(
op2dependences
,
&
var2min_rw_op
,
op_idx
,
var
);
remove_duplicate
.
insert
(
var
);
remove_duplicate
.
insert
(
var
);
}
}
}
}
...
@@ -713,7 +738,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
...
@@ -713,7 +738,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
var2recent_write_op
[
var
]
=
op_idx
;
var2recent_write_op
[
var
]
=
op_idx
;
if
(
remove_duplicate
.
count
(
var
)
==
if
(
remove_duplicate
.
count
(
var
)
==
0
)
{
// var in input list and in output list, so remove it.
0
)
{
// var in input list and in output list, so remove it.
update_var_min_rw_op
(
op2dependences
,
var2min_rw_op
,
op_idx
,
var
);
update_var_min_rw_op
(
op2dependences
,
&
var2min_rw_op
,
op_idx
,
var
);
}
}
}
}
}
}
...
...
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
1f0512be
...
@@ -51,7 +51,7 @@ namespace framework {
...
@@ -51,7 +51,7 @@ namespace framework {
namespace
interpreter
{
namespace
interpreter
{
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
static
constexpr
char
kFetchVarName
[]
=
"fetch
_vars
"
;
static
constexpr
char
kFetchVarName
[]
=
"fetch"
;
class
AsyncWorkQueue
{
class
AsyncWorkQueue
{
public:
public:
...
@@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place,
...
@@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place,
const
platform
::
Place
&
dst_place
);
const
platform
::
Place
&
dst_place
);
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
,
bool
use_local_scope
=
true
);
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
BlockDesc
&
block
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
,
bool
use_local_scope
=
true
);
std
::
map
<
int
,
std
::
list
<
int
>>
build_op_downstream_map
(
std
::
map
<
int
,
std
::
list
<
int
>>
build_op_downstream_map
(
const
std
::
vector
<
Instruction
>&
vec_instruction
);
const
std
::
vector
<
Instruction
>&
vec_instruction
);
...
...
paddle/fluid/framework/new_executor/new_executor_defs.cc
浏览文件 @
1f0512be
...
@@ -497,7 +497,14 @@ VariableScope::~VariableScope() {
...
@@ -497,7 +497,14 @@ VariableScope::~VariableScope() {
}
}
}
}
const
Scope
*
VariableScope
::
GetScope
()
const
{
return
scope_
;
}
Scope
*
VariableScope
::
GetMutableScope
()
const
{
return
scope_
;
}
Scope
*
VariableScope
::
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
void
VariableScope
::
SetLocalScope
(
Scope
*
local_scope
)
{
VLOG
(
4
)
<<
"Set local scope: "
<<
local_scope
;
local_scope_
=
local_scope
;
}
Variable
*
VariableScope
::
FindVar
(
const
std
::
string
&
name
)
const
{
Variable
*
VariableScope
::
FindVar
(
const
std
::
string
&
name
)
const
{
auto
it
=
name2id_
.
find
(
name
);
auto
it
=
name2id_
.
find
(
name
);
...
@@ -554,8 +561,9 @@ Variable* VariableScope::Var(const std::string& name) const {
...
@@ -554,8 +561,9 @@ Variable* VariableScope::Var(const std::string& name) const {
size_t
VariableScope
::
VarSize
()
const
{
return
var_list_
.
size
();
}
size_t
VariableScope
::
VarSize
()
const
{
return
var_list_
.
size
();
}
void
VariableScope
::
AddVar
(
const
std
::
string
&
name
,
void
VariableScope
::
AddVar
(
const
std
::
string
&
name
,
framework
::
VarDesc
*
var_desc
)
{
// NOLINT
framework
::
VarDesc
*
var_desc
,
auto
v
=
scope_
->
Var
(
name
);
bool
local_scope
)
{
// NOLINT
auto
v
=
local_scope
?
local_scope_
->
Var
(
name
)
:
scope_
->
Var
(
name
);
if
(
nullptr
==
var_desc
)
{
if
(
nullptr
==
var_desc
)
{
v
->
GetMutable
<
LoDTensor
>
();
v
->
GetMutable
<
LoDTensor
>
();
}
else
{
}
else
{
...
@@ -606,9 +614,9 @@ VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
...
@@ -606,9 +614,9 @@ VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_
=
var_scope
;
var_scope_
=
var_scope
;
}
}
void
VariableScopeListener
::
onCreateVariable
(
const
std
::
string
&
name
)
{
void
VariableScopeListener
::
onCreateVariable
(
const
std
::
string
&
name
,
auto
v
=
var_scope_
->
scope_
->
GetVar
(
name
);
// must exsit in outer_scope_
Variable
*
v
)
{
if
(
!
var_scope_
->
HasVar
(
name
))
{
// may exist in variable scope.
if
(
!
var_scope_
->
HasVar
(
name
))
{
// may exist in variable scope.
VLOG
(
4
)
<<
"Calling VariableScope::onCreateVariable with var_name: "
VLOG
(
4
)
<<
"Calling VariableScope::onCreateVariable with var_name: "
<<
name
;
<<
name
;
var_scope_
->
name2id_
[
name
]
=
var_scope_
->
VarSize
();
var_scope_
->
name2id_
[
name
]
=
var_scope_
->
VarSize
();
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
1f0512be
...
@@ -155,7 +155,7 @@ class VariableScope;
...
@@ -155,7 +155,7 @@ class VariableScope;
class
VariableScopeListener
:
public
ScopeListener
{
class
VariableScopeListener
:
public
ScopeListener
{
public:
public:
explicit
VariableScopeListener
(
VariableScope
*
var_scope_
);
explicit
VariableScopeListener
(
VariableScope
*
var_scope_
);
void
onCreateVariable
(
const
std
::
string
&
name
)
override
;
void
onCreateVariable
(
const
std
::
string
&
name
,
Variable
*
v
)
override
;
void
onDeleteVariable
(
const
std
::
string
&
name
)
override
;
void
onDeleteVariable
(
const
std
::
string
&
name
)
override
;
void
onRenameVariable
(
const
std
::
string
&
old_name
,
void
onRenameVariable
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
override
;
const
std
::
string
&
new_name
)
override
;
...
@@ -177,7 +177,11 @@ class VariableScope : public ScopeBase {
...
@@ -177,7 +177,11 @@ class VariableScope : public ScopeBase {
public:
public:
explicit
VariableScope
(
Scope
*
scope
);
explicit
VariableScope
(
Scope
*
scope
);
const
Scope
*
GetScope
()
const
;
Scope
*
GetMutableScope
()
const
;
Scope
*
GetMutableLocalScope
()
const
;
void
SetLocalScope
(
Scope
*
local_scope
);
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
...
@@ -199,7 +203,8 @@ class VariableScope : public ScopeBase {
...
@@ -199,7 +203,8 @@ class VariableScope : public ScopeBase {
size_t
VarSize
()
const
;
size_t
VarSize
()
const
;
void
AddVar
(
const
std
::
string
&
name
,
VarDesc
*
var_desc
);
void
AddVar
(
const
std
::
string
&
name
,
VarDesc
*
var_desc
,
bool
local_scope
=
false
);
void
AddVar
(
const
std
::
string
&
name
,
const
Variable
&
var
);
void
AddVar
(
const
std
::
string
&
name
,
const
Variable
&
var
);
...
@@ -219,15 +224,21 @@ class VariableScope : public ScopeBase {
...
@@ -219,15 +224,21 @@ class VariableScope : public ScopeBase {
return
vec_meta_info_
;
return
vec_meta_info_
;
}
}
const
std
::
shared_ptr
<
VariableScopeListener
>&
Listener
()
const
{
return
listener_
;
}
friend
class
VariableScopeListener
;
friend
class
VariableScopeListener
;
private:
private:
std
::
vector
<
Variable
*>
var_list_
;
std
::
vector
<
Variable
*>
var_list_
;
std
::
map
<
std
::
string
,
int
>
name2id_
;
std
::
map
<
std
::
string
,
int
>
name2id_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
Scope
*
scope_
=
nullptr
;
Scope
*
scope_
{
nullptr
};
// TODO(zhiqiu): find a better way to support local scope.
Scope
*
local_scope_
{
nullptr
};
// mutable RWLock vars_lock_;
// mutable RWLock vars_lock_;
std
::
shared_ptr
<
VariableScopeListener
>
listener_
;
std
::
shared_ptr
<
VariableScopeListener
>
listener_
{
nullptr
}
;
};
};
class
NextInstruction
{
class
NextInstruction
{
...
...
paddle/fluid/framework/new_executor/standalone_executor.cc
浏览文件 @
1f0512be
...
@@ -24,30 +24,36 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
...
@@ -24,30 +24,36 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
startup_prog_
(
startup_prog
),
startup_prog_
(
startup_prog
),
main_prog_
(
main_prog
),
main_prog_
(
main_prog
),
global_scope_
(
VariableScope
(
scope
))
{
global_scope_
(
VariableScope
(
scope
))
{
// init scope
// NOTE(zhiqiu): for startup_program, initialize scope and run once
BuildVariableScope
(
startup_prog
,
&
global_scope_
);
// if startup_program is empty, the scope is initialized during the first run
if
(
startup_prog
.
Block
(
0
).
AllOps
().
size
()
>
0
)
{
if
(
scope
!=
nullptr
)
{
VLOG
(
4
)
<<
"Run startup program"
;
auto
name_list
=
scope
->
LocalVarNames
();
// init scope
for
(
auto
name
:
name_list
)
{
BuildVariableScope
(
startup_prog
,
&
global_scope_
);
auto
v
=
scope
->
Var
(
name
);
if
(
!
global_scope_
.
HasVar
(
name
))
{
if
(
scope
!=
nullptr
)
{
global_scope_
.
AddVar
(
name
,
*
v
);
auto
name_list
=
scope
->
LocalVarNames
();
for
(
auto
name
:
name_list
)
{
auto
v
=
scope
->
Var
(
name
);
if
(
!
global_scope_
.
HasVar
(
name
))
{
global_scope_
.
AddVar
(
name
,
*
v
);
}
}
}
}
}
}
// run startup program
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
// No need to use_local_scope for startup_program, its variables are
paddle
::
framework
::
interpreter
::
build_op_func_list
(
// persistable
place_
,
startup_prog
.
Block
(
0
),
&
vec_func_list
,
&
global_scope_
);
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
startup_prog
.
Block
(
0
),
&
vec_func_list
,
&
global_scope_
,
false
);
}
}
}
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
,
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
auto
core
=
GetInterpreterCore
(
feed_names
,
fetch_names
);
auto
core
=
GetInterpreterCore
(
feed_names
,
fetch_names
,
true
);
return
core
->
Run
(
feed_names
,
feed_tensors
);
return
core
->
Run
(
feed_names
,
feed_tensors
);
}
}
...
@@ -55,15 +61,15 @@ paddle::framework::FetchList StandaloneExecutor::Run(
...
@@ -55,15 +61,15 @@ paddle::framework::FetchList StandaloneExecutor::Run(
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
auto
core
=
GetInterpreterCore
(
feed_names
,
fetch_names
);
auto
core
=
GetInterpreterCore
(
feed_names
,
fetch_names
,
false
);
VLOG
(
4
)
<<
"StandaloneExecutor: "
<<
this
<<
", InterpreterCore: "
<<
core
;
return
core
->
Run
();
return
core
->
Run
();
}
}
framework
::
interpreter
::
CostInfo
StandaloneExecutor
::
DryRun
(
framework
::
interpreter
::
CostInfo
StandaloneExecutor
::
DryRun
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
auto
core
=
GetInterpreterCore
(
feed_names
,
{});
auto
core
=
GetInterpreterCore
(
feed_names
,
{}
,
true
);
return
core
->
DryRun
(
feed_names
,
feed_tensors
);
return
core
->
DryRun
(
feed_names
,
feed_tensors
);
}
}
...
@@ -85,7 +91,7 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
...
@@ -85,7 +91,7 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
std
::
shared_ptr
<
InterpreterCore
>
StandaloneExecutor
::
GetInterpreterCore
(
std
::
shared_ptr
<
InterpreterCore
>
StandaloneExecutor
::
GetInterpreterCore
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
const
std
::
vector
<
std
::
string
>&
fetch_names
,
bool
add_fetch_op
)
{
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
oss
<<
"feed:"
;
oss
<<
"feed:"
;
for
(
auto
&
feedname
:
feed_names
)
{
for
(
auto
&
feedname
:
feed_names
)
{
...
@@ -100,15 +106,22 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
...
@@ -100,15 +106,22 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if
(
iter
==
interpretercores_
.
end
())
{
if
(
iter
==
interpretercores_
.
end
())
{
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
VLOG
(
3
)
<<
"add fetch op: "
<<
add_fetch_op
;
// new program.
std
::
shared_ptr
<
InterpreterCore
>
core
=
nullptr
;
auto
new_prog
=
std
::
make_shared
<
framework
::
ProgramDesc
>
(
main_prog_
);
if
(
add_fetch_op
)
{
auto
*
block
=
new_prog
->
MutableBlock
(
0
);
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
interpreter
::
add_fetch
(
fetch_names
,
block
);
// a
// new program.
auto
core
=
auto
new_prog
=
std
::
make_shared
<
framework
::
ProgramDesc
>
(
main_prog_
);
std
::
make_shared
<
InterpreterCore
>
(
place_
,
*
block
,
&
global_scope_
);
auto
*
block
=
new_prog
->
MutableBlock
(
0
);
programs_
.
emplace
(
oss
.
str
(),
new_prog
);
interpreter
::
add_fetch
(
fetch_names
,
block
);
core
=
std
::
make_shared
<
InterpreterCore
>
(
place_
,
*
block
,
&
global_scope_
);
core
->
SetCopyProgram
(
new_prog
);
}
else
{
core
=
std
::
make_shared
<
InterpreterCore
>
(
place_
,
main_prog_
.
Block
(
0
),
&
global_scope_
);
}
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
return
core
;
return
core
;
}
else
{
}
else
{
...
...
paddle/fluid/framework/new_executor/standalone_executor.h
浏览文件 @
1f0512be
...
@@ -61,14 +61,13 @@ class StandaloneExecutor : public ExecutorBase {
...
@@ -61,14 +61,13 @@ class StandaloneExecutor : public ExecutorBase {
std
::
shared_ptr
<
InterpreterCore
>
GetInterpreterCore
(
std
::
shared_ptr
<
InterpreterCore
>
GetInterpreterCore
(
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
fetch_names
);
const
std
::
vector
<
std
::
string
>&
fetch_names
,
bool
add_fetch_op
);
const
platform
::
Place
&
place_
;
const
platform
::
Place
&
place_
;
const
ProgramDesc
&
startup_prog_
;
const
ProgramDesc
&
startup_prog_
;
const
ProgramDesc
&
main_prog_
;
const
ProgramDesc
&
main_prog_
;
VariableScope
global_scope_
;
VariableScope
global_scope_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ProgramDesc
>>
programs_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
interpretercores_
;
interpretercores_
;
};
};
...
...
paddle/fluid/framework/scope.cc
浏览文件 @
1f0512be
...
@@ -67,7 +67,7 @@ Variable* Scope::Var(const std::string& name) {
...
@@ -67,7 +67,7 @@ Variable* Scope::Var(const std::string& name) {
ret
=
VarInternal
(
name
);
ret
=
VarInternal
(
name
);
}
}
for
(
auto
l
:
listeners_
)
{
for
(
auto
l
:
listeners_
)
{
l
->
onCreateVariable
(
name
);
l
->
onCreateVariable
(
name
,
ret
);
}
}
return
ret
;
return
ret
;
}
}
...
@@ -85,7 +85,7 @@ Variable* Scope::Var(std::string* name) {
...
@@ -85,7 +85,7 @@ Variable* Scope::Var(std::string* name) {
ret
=
VarInternal
(
new_name
);
ret
=
VarInternal
(
new_name
);
}
}
for
(
auto
l
:
listeners_
)
{
for
(
auto
l
:
listeners_
)
{
l
->
onCreateVariable
(
new_name
);
l
->
onCreateVariable
(
new_name
,
ret
);
}
}
return
ret
;
return
ret
;
}
}
...
...
paddle/fluid/framework/scope.h
浏览文件 @
1f0512be
...
@@ -58,7 +58,7 @@ class ScopeListener {
...
@@ -58,7 +58,7 @@ class ScopeListener {
// in original Scope.
// in original Scope.
public:
public:
virtual
~
ScopeListener
()
{}
virtual
~
ScopeListener
()
{}
virtual
void
onCreateVariable
(
const
std
::
string
&
name
)
{}
virtual
void
onCreateVariable
(
const
std
::
string
&
name
,
Variable
*
v
)
{}
virtual
void
onDeleteVariable
(
const
std
::
string
&
name
)
{}
virtual
void
onDeleteVariable
(
const
std
::
string
&
name
)
{}
virtual
void
onRenameVariable
(
const
std
::
string
&
old_name
,
virtual
void
onRenameVariable
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{}
const
std
::
string
&
new_name
)
{}
...
...
paddle/fluid/operators/controlflow/fetch_v2_op.cc
浏览文件 @
1f0512be
...
@@ -112,13 +112,6 @@ class FetchV2Op : public framework::OperatorWithKernel {
...
@@ -112,13 +112,6 @@ class FetchV2Op : public framework::OperatorWithKernel {
}
}
};
};
class
FetchV2InferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
ctx
->
SyncTypeAndDataType
(
"X"
,
"Out"
);
}
};
class
FetchV2Kernel
{
class
FetchV2Kernel
{
public:
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
@@ -211,7 +204,6 @@ namespace ops = paddle::operators;
...
@@ -211,7 +204,6 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
fetch_v2
,
ops
::
FetchV2Op
,
ops
::
FetchV2OpProtoMaker
,
fetch_v2
,
ops
::
FetchV2Op
,
ops
::
FetchV2OpProtoMaker
,
ops
::
FetchV2InferVarType
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
...
...
python/paddle/fluid/executor.py
浏览文件 @
1f0512be
...
@@ -288,7 +288,10 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
...
@@ -288,7 +288,10 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
return
feed_count
>
0
return
feed_count
>
0
def
has_fetch_operators
(
block
,
fetch_targets
,
fetch_holder_name
):
def
has_fetch_operators
(
block
,
fetch_targets
,
fetch_holder_name
,
fetch_op
=
'fetch'
):
""" Check whether the block already has fetch operators.
""" Check whether the block already has fetch operators.
Return false if the block does not have any fetch operators.
Return false if the block does not have any fetch operators.
...
@@ -303,6 +306,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
...
@@ -303,6 +306,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
fetch_holder_name: the name of the variable that holds the data of
fetch_holder_name: the name of the variable that holds the data of
all fetch targets. The type of this fetch_holder variable is
all fetch targets. The type of this fetch_holder variable is
FETCH_LIST, which is essentially vector<LoDTensor>.
FETCH_LIST, which is essentially vector<LoDTensor>.
fetch_op: the operator name of fetch
Return:
Return:
A boolean value that indicates whether a block has fetch operators
A boolean value that indicates whether a block has fetch operators
...
@@ -311,7 +315,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
...
@@ -311,7 +315,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
fetch_count
=
0
fetch_count
=
0
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
if
op
.
desc
.
type
()
==
'fetch'
:
if
op
.
desc
.
type
()
==
fetch_op
:
fetch_count
+=
1
fetch_count
+=
1
assert
op
.
desc
.
output
(
'Out'
)[
0
]
==
fetch_holder_name
assert
op
.
desc
.
output
(
'Out'
)[
0
]
==
fetch_holder_name
fetch_target_name
=
op
.
desc
.
input
(
'X'
)[
0
]
fetch_target_name
=
op
.
desc
.
input
(
'X'
)[
0
]
...
@@ -740,7 +744,7 @@ class Executor(object):
...
@@ -740,7 +744,7 @@ class Executor(object):
fetch_list
,
fetch_list
,
feed_var_name
,
feed_var_name
,
fetch_var_name
,
fetch_var_name
,
skip_fetch
=
False
):
use_fetch_v2
=
False
):
tmp_program
=
program
.
clone
()
tmp_program
=
program
.
clone
()
global_block
=
tmp_program
.
global_block
()
global_block
=
tmp_program
.
global_block
()
...
@@ -775,17 +779,21 @@ class Executor(object):
...
@@ -775,17 +779,21 @@ class Executor(object):
warnings
.
warn
(
warnings
.
warn
(
"The variable %s is not found in program. It is not declared or is pruned."
"The variable %s is not found in program. It is not declared or is pruned."
%
name
)
%
name
)
if
skip_fetch
:
return
tmp_program
if
use_fetch_v2
:
fetch_op
=
'fetch_v2'
else
:
fetch_op
=
'fetch'
# append fetch_operators
# append fetch_operators
if
not
has_fetch_operators
(
global_block
,
fetch_list
,
fetch_var_name
):
if
not
has_fetch_operators
(
global_block
,
fetch_list
,
fetch_var_name
,
fetch_op
):
for
i
,
var
in
enumerate
(
fetch_list
):
for
i
,
var
in
enumerate
(
fetch_list
):
assert
isinstance
(
var
,
Variable
)
or
isinstance
(
assert
isinstance
(
var
,
Variable
)
or
isinstance
(
var
,
six
.
string_types
),
(
var
,
six
.
string_types
),
(
"Wrong type for fetch_list[%s]: %s"
%
(
i
,
type
(
var
)))
"Wrong type for fetch_list[%s]: %s"
%
(
i
,
type
(
var
)))
global_block
.
append_op
(
global_block
.
append_op
(
type
=
'fetch'
,
type
=
fetch_op
,
inputs
=
{
'X'
:
[
var
]},
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
fetch_var
]},
outputs
=
{
'Out'
:
[
fetch_var
]},
attrs
=
{
'col'
:
i
})
attrs
=
{
'col'
:
i
})
...
@@ -1345,7 +1353,13 @@ class Executor(object):
...
@@ -1345,7 +1353,13 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
feed_var_name
=
feed_var_name
,
feed_var_name
=
feed_var_name
,
fetch_var_name
=
fetch_var_name
,
fetch_var_name
=
fetch_var_name
,
skip_fetch
=
True
)
use_fetch_v2
=
True
)
# NOTE(zhiqiu): Construct standalone_executor first, so
# the scope is bound to the variable_scope of standalone_executor
new_exe
=
self
.
_executor_cache
.
_get_exe_from_cache
(
program
,
scope
)
self
.
_feed_data
(
program
,
feed
,
feed_var_name
,
scope
)
self
.
_feed_data
(
program
,
feed
,
feed_var_name
,
scope
)
if
hasattr
(
program
,
'lr_sheduler'
):
if
hasattr
(
program
,
'lr_sheduler'
):
from
paddle.optimizer.lr
import
LRScheduler
from
paddle.optimizer.lr
import
LRScheduler
...
@@ -1360,9 +1374,7 @@ class Executor(object):
...
@@ -1360,9 +1374,7 @@ class Executor(object):
lr_sheduler
.
_var_name
)
lr_sheduler
.
_var_name
)
tensor
.
set
(
data
,
self
.
place
)
tensor
.
set
(
data
,
self
.
place
)
return
self
.
_executor_cache
.
run
(
program
,
scope
,
return
new_exe
.
run
(
list
(
feed
.
keys
()),
fetch_list
,
return_numpy
)
list
(
feed
.
keys
()),
fetch_list
,
return_numpy
)
# use_prune can be overridden by putting optimize_ops in fetch_list
# use_prune can be overridden by putting optimize_ops in fetch_list
_origin_fetch_list
=
fetch_list
_origin_fetch_list
=
fetch_list
...
...
python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
浏览文件 @
1f0512be
...
@@ -309,7 +309,7 @@ class TestException(unittest.TestCase):
...
@@ -309,7 +309,7 @@ class TestException(unittest.TestCase):
feed
[
1
][
'data'
][
0
]
=
np
.
nan
feed
[
1
][
'data'
][
0
]
=
np
.
nan
self
.
assertRaises
(
RuntimeError
,
self
.
run_new_executor
,
feed
)
self
.
assertRaises
(
RuntimeError
,
self
.
run_new_executor
,
feed
)
def
test_scope
(
self
):
def
test_scope
_find_temp_var
(
self
):
feed
=
[{
feed
=
[{
'id'
:
np
.
array
([
1
,
2
,
3
,
4
,
5
]).
astype
(
np
.
int64
),
'id'
:
np
.
array
([
1
,
2
,
3
,
4
,
5
]).
astype
(
np
.
int64
),
'data'
:
np
.
array
([
1
,
2
,
3
]).
astype
(
np
.
float32
),
'data'
:
np
.
array
([
1
,
2
,
3
]).
astype
(
np
.
float32
),
...
@@ -318,7 +318,7 @@ class TestException(unittest.TestCase):
...
@@ -318,7 +318,7 @@ class TestException(unittest.TestCase):
'data'
:
np
.
array
([
2
,
2
,
2
]).
astype
(
np
.
float32
),
'data'
:
np
.
array
([
2
,
2
,
2
]).
astype
(
np
.
float32
),
}]
}]
self
.
run_new_executor
(
feed
)
self
.
run_new_executor
(
feed
)
self
.
assertIsNo
tNo
ne
(
paddle
.
static
.
global_scope
().
find_var
(
self
.
assertIsNone
(
paddle
.
static
.
global_scope
().
find_var
(
self
.
fetch_vars
.
name
))
self
.
fetch_vars
.
name
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录