Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a6e99dc7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a6e99dc7
编写于
11月 09, 2021
作者:
A
Aurelius84
提交者:
GitHub
11月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor InterpretorCore and Modify into BlockDesc (#37056)
上级
993ec76a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
82 addition
and
81 deletion
+82
-81
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+15
-39
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+5
-8
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+36
-20
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+8
-4
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+2
-2
paddle/fluid/framework/new_executor/standalone_executor.cc
paddle/fluid/framework/new_executor/standalone_executor.cc
+11
-4
paddle/fluid/framework/new_executor/standalone_executor.h
paddle/fluid/framework/new_executor/standalone_executor.h
+1
-0
paddle/fluid/framework/new_executor/stream_analyzer.cc
paddle/fluid/framework/new_executor/stream_analyzer.cc
+4
-4
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
a6e99dc7
...
@@ -33,13 +33,11 @@ namespace framework {
...
@@ -33,13 +33,11 @@ namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it.
// NOTE(Aurelius84): Need a better strategy to determine it.
static
constexpr
size_t
kHostNumThreads
=
4
;
static
constexpr
size_t
kHostNumThreads
=
4
;
InterpreterCore
::
InterpreterCore
(
const
platform
::
Place
&
place
,
InterpreterCore
::
InterpreterCore
(
const
platform
::
Place
&
place
,
BlockDesc
*
block
,
const
ProgramDesc
&
main_prog
,
VariableScope
*
global_scope
,
VariableScope
*
global_scope
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
)
const
std
::
vector
<
std
::
string
>&
fetch_names
)
:
place_
(
place
),
:
place_
(
place
),
main_program_
(
main_prog
),
block_
(
block
),
global_scope_
(
global_scope
),
global_scope_
(
global_scope
),
stream_analyzer_
(
place
),
stream_analyzer_
(
place
),
async_work_queue_
(
kHostNumThreads
,
&
main_thread_blocker_
)
{
async_work_queue_
(
kHostNumThreads
,
&
main_thread_blocker_
)
{
...
@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
...
@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
// Step1: add feedop and fetchop to main_program
AddFetch
(
fetch_names
);
// prune
// prune
// optmize graph pass
// optmize graph pass
...
@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
...
@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
// convert to run graph
// convert to run graph
}
}
void
InterpreterCore
::
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
auto
*
fetch_holder
=
main_program_
.
MutableBlock
(
0
)
->
Var
(
"fetch_vars"
);
fetch_holder
->
SetType
(
proto
::
VarType
::
FETCH_LIST
);
fetch_holder
->
SetPersistable
(
true
);
int
i
=
0
;
for
(
auto
&
fetch_name
:
fetch_names
)
{
// append fetch op
auto
*
op
=
main_program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"fetch_v2"
);
op
->
SetInput
(
"X"
,
{
fetch_name
});
op
->
SetOutput
(
"Out"
,
{
"fetch_vars"
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
i
++
;
}
}
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
auto
FeedInput
=
[
&
]
{
auto
FeedInput
=
[
&
]
{
...
@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
...
@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
};
};
if
(
is_build_
==
false
)
{
if
(
is_build_
==
false
)
{
paddle
::
framework
::
interpreter
core
::
build_variable_scope
(
main_program
_
,
paddle
::
framework
::
interpreter
::
build_variable_scope
(
*
block
_
,
global_scope_
);
global_scope_
);
FeedInput
();
FeedInput
();
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
main_program
_
,
&
vec_func_list_
,
global_scope_
);
place_
,
*
block
_
,
&
vec_func_list_
,
global_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
();
Convert
();
...
@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
...
@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
}
}
// return Fetch Tensors
// return Fetch Tensors
auto
*
fetch_var
=
global_scope_
->
Var
(
"fetch_vars"
);
auto
*
fetch_var
=
global_scope_
->
Var
(
interpreter
::
kFetchVarName
);
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
}
}
...
@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
...
@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
std
::
vector
<
size_t
>
vec_temp
;
std
::
vector
<
size_t
>
vec_temp
;
for
(
auto
&
item
:
vec_instruction_
[
i
].
Outputs
())
{
for
(
auto
&
item
:
vec_instruction_
[
i
].
Outputs
())
{
for
(
auto
id
:
item
.
second
)
{
for
(
auto
id
:
item
.
second
)
{
vec_temp
=
vec_temp
=
interpreter
::
merge_vector
(
vec_temp
,
input_var2op_info_
[
id
]);
interpretercore
::
merge_vector
(
vec_temp
,
input_var2op_info_
[
id
]);
}
}
}
}
...
@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
...
@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
[
&
,
next_id
]
{
RunInstructionAsync
(
next_id
);
});
[
&
,
next_id
]
{
RunInstructionAsync
(
next_id
);
});
}
}
}
}
auto
direct_run_ops
=
interpreter
core
::
merge_vector
(
auto
direct_run_ops
=
interpreter
::
merge_vector
(
next_instr
.
SyncRunIds
(),
next_instr
.
SyncRunIds
(),
next_instr
.
DirectRunIds
());
next_instr
.
DirectRunIds
());
size_t
first_op
=
0
;
size_t
first_op
=
0
;
for
(
auto
next_id
:
direct_run_ops
)
{
for
(
auto
next_id
:
direct_run_ops
)
{
if
(
IsReady
(
next_id
))
{
if
(
IsReady
(
next_id
))
{
...
@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
...
@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
};
};
if
(
is_build_
==
false
)
{
if
(
is_build_
==
false
)
{
paddle
::
framework
::
interpreter
core
::
build_variable_scope
(
main_program
_
,
paddle
::
framework
::
interpreter
::
build_variable_scope
(
*
block
_
,
global_scope_
);
global_scope_
);
FeedInput
();
FeedInput
();
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
main_program
_
,
&
vec_func_list_
,
global_scope_
);
place_
,
*
block
_
,
&
vec_func_list_
,
global_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
();
Convert
();
...
...
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
a6e99dc7
...
@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
...
@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class
InterpreterCore
{
class
InterpreterCore
{
public:
public:
InterpreterCore
(
const
platform
::
Place
&
place
,
const
ProgramDesc
&
main_prog
,
InterpreterCore
(
const
platform
::
Place
&
place
,
BlockDesc
*
block
,
VariableScope
*
global_scope
,
VariableScope
*
global_scope
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
);
const
std
::
vector
<
std
::
string
>&
fetch_names
);
paddle
::
framework
::
FetchList
Run
(
paddle
::
framework
::
FetchList
Run
(
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
...
@@ -72,15 +71,14 @@ class InterpreterCore {
...
@@ -72,15 +71,14 @@ class InterpreterCore {
void
RunInstructionAsync
(
size_t
instr_id
);
void
RunInstructionAsync
(
size_t
instr_id
);
void
RunNextInstructions
(
const
Instruction
&
instr_id
,
void
RunNextInstructions
(
const
Instruction
&
instr_id
,
std
::
queue
<
size_t
>*
reserved_next_ops
);
std
::
queue
<
size_t
>*
reserved_next_ops
);
void
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
BuildSkipShareLoDInfo
();
void
BuildSkipShareLoDInfo
();
bool
is_build_
;
bool
is_build_
;
const
platform
::
Place
&
place_
;
const
platform
::
Place
&
place_
;
ProgramDesc
main_program_
;
BlockDesc
*
block_
;
// not owned
VariableScope
*
global_scope_
;
VariableScope
*
global_scope_
;
// not owned
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list_
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list_
;
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
...
@@ -88,7 +86,6 @@ class InterpreterCore {
...
@@ -88,7 +86,6 @@ class InterpreterCore {
InstructionInfo
instruction_info_
;
InstructionInfo
instruction_info_
;
std
::
vector
<
size_t
>
dependecy_count_
;
std
::
vector
<
size_t
>
dependecy_count_
;
std
::
vector
<
std
::
vector
<
size_t
>>
input_var2op_info_
;
std
::
vector
<
std
::
vector
<
size_t
>>
input_var2op_info_
;
std
::
vector
<
VariableMetaInfo
>
ref_coun_info_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
std
::
vector
<
std
::
string
>
feed_names_
;
std
::
vector
<
std
::
string
>
feed_names_
;
...
@@ -97,7 +94,7 @@ class InterpreterCore {
...
@@ -97,7 +94,7 @@ class InterpreterCore {
StreamAnalyzer
stream_analyzer_
;
StreamAnalyzer
stream_analyzer_
;
EventManager
event_manager_
;
EventManager
event_manager_
;
EventsWaiter
main_thread_blocker_
;
EventsWaiter
main_thread_blocker_
;
interpreter
core
::
AsyncWorkQueue
async_work_queue_
;
interpreter
::
AsyncWorkQueue
async_work_queue_
;
details
::
ExceptionHolder
exception_holder_
;
details
::
ExceptionHolder
exception_holder_
;
std
::
shared_ptr
<
EventsWaiter
::
EventNotifier
>
exception_notifier_
{
nullptr
};
std
::
shared_ptr
<
EventsWaiter
::
EventNotifier
>
exception_notifier_
{
nullptr
};
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
a6e99dc7
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpreter
core
{
namespace
interpreter
{
using
VariableIdMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
;
using
VariableIdMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
;
AtomicVectorSizeT
&
AsyncWorkQueue
::
PrepareAtomicDeps
(
AtomicVectorSizeT
&
AsyncWorkQueue
::
PrepareAtomicDeps
(
...
@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
...
@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
}
}
}
void
build_variable_scope
(
const
framework
::
ProgramDesc
&
pdesc
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
)
{
auto
&
global_block
=
pdesc
.
Block
(
0
);
for
(
auto
&
var_desc
:
block
.
AllVars
())
{
for
(
auto
&
var_desc
:
global_block
.
AllVars
())
{
auto
var_name
=
var_desc
->
Name
();
auto
var_name
=
var_desc
->
Name
();
if
(
var_name
==
framework
::
kEmptyVarName
)
{
if
(
var_name
==
framework
::
kEmptyVarName
)
{
continue
;
continue
;
...
@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
...
@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
std
::
vector
<
OpFuncNode
>
apply_data_transform
(
std
::
vector
<
OpFuncNode
>
apply_data_transform
(
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
VariableValueMap
&
ins_map_temp
,
VariableScope
*
var_scope
,
VariableValueMap
*
ins_map_temp
,
VariableScope
*
var_scope
,
OpFuncNode
&
op_func_node
)
{
OpFuncNode
*
op_func_node
)
{
auto
&
op_base
=
op_func_node
.
operator_base_
;
auto
&
op_base
=
op_func_node
->
operator_base_
;
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
"op_base is null, please pass a valid "
"op_base is null, please pass a valid "
"op_base in apply_data_transform."
));
"op_base in apply_data_transform."
));
...
@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
no_data_transform_index
;
// record the no need transform variable index.
no_data_transform_index
;
// record the no need transform variable index.
std
::
vector
<
OpFuncNode
>
copy_func_nodes
;
// return all the copy opfuncnode.
std
::
vector
<
OpFuncNode
>
copy_func_nodes
;
// return all the copy opfuncnode.
for
(
auto
&
var_name_item
:
ins_map_temp
)
{
for
(
auto
&
var_name_item
:
*
ins_map_temp
)
{
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
var
=
var_name_item
.
second
[
i
];
auto
var
=
var_name_item
.
second
[
i
];
auto
&
var_name
=
inputs_names
[
var_name_item
.
first
].
at
(
i
);
auto
&
var_name
=
inputs_names
[
var_name_item
.
first
].
at
(
i
);
...
@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
apply_place_transform_for_var
(
apply_place_transform_for_var
(
kernel_type_for_var
,
expected_kernel_key
,
place
,
var_name
,
kernel_type_for_var
,
expected_kernel_key
,
place
,
var_name
,
var_name_item
.
first
,
op_func_node
,
var
,
var_scope
);
var_name_item
.
first
,
*
op_func_node
,
var
,
var_scope
);
op_func_node
.
input_index
[
var_name_item
.
first
][
i
]
=
op_func_node
->
input_index
[
var_name_item
.
first
][
i
]
=
var_scope
->
VarId
(
new_var_name
);
var_scope
->
VarId
(
new_var_name
);
copy_func_nodes
.
push_back
(
copy_op_func_node
);
copy_func_nodes
.
push_back
(
copy_op_func_node
);
var_name_item
.
second
[
i
]
=
var_scope
->
Var
(
new_var_name
);
var_name_item
.
second
[
i
]
=
var_scope
->
Var
(
new_var_name
);
...
@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
}
}
}
}
}
}
op_func_node
.
no_data_transform_index
=
std
::
move
(
no_data_transform_index
);
op_func_node
->
no_data_transform_index
=
std
::
move
(
no_data_transform_index
);
return
copy_func_nodes
;
return
copy_func_nodes
;
}
}
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
ProgramDesc
&
pdesc
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
)
{
auto
&
global_block
=
pdesc
.
Block
(
0
);
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
// Step 1: create all ops for
global
block.
// Step 1: create all ops for
current
block.
auto
ops
=
create_all_ops
(
global_
block
);
auto
ops
=
create_all_ops
(
block
);
auto
unused_var_map
=
get_unused_vars
(
global_
block
,
ops
);
auto
unused_var_map
=
get_unused_vars
(
block
,
ops
);
size_t
ops_index
=
0
;
size_t
ops_index
=
0
;
for
(
auto
&
op
:
global_
block
.
AllOps
())
{
for
(
auto
&
op
:
block
.
AllOps
())
{
VLOG
(
6
)
<<
"Build OpFuncNode from : "
<<
op
->
Type
();
VLOG
(
6
)
<<
"Build OpFuncNode from : "
<<
op
->
Type
();
auto
op_base
=
ops
[
ops_index
++
];
auto
op_base
=
ops
[
ops_index
++
];
...
@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
// apply_data_transform.
op_func_node
.
operator_base_
=
op_base
;
op_func_node
.
operator_base_
=
op_base
;
copy_op_to_insert
=
apply_data_transform
(
copy_op_to_insert
=
apply_data_transform
(
expected_kernel_key
,
place
,
ins_map_temp
,
var_scope
,
op_func_node
);
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
&
op_func_node
);
for
(
auto
&
item
:
copy_op_to_insert
)
{
for
(
auto
&
item
:
copy_op_to_insert
)
{
vec_func_list
->
push_back
(
item
);
vec_func_list
->
push_back
(
item
);
}
}
...
@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
...
@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
}
}
}
}
void
add_fetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
,
framework
::
BlockDesc
*
block
)
{
auto
*
fetch_holder
=
block
->
Var
(
kFetchVarName
);
fetch_holder
->
SetType
(
proto
::
VarType
::
FETCH_LIST
);
fetch_holder
->
SetPersistable
(
true
);
int
i
=
0
;
for
(
auto
&
fetch_name
:
fetch_names
)
{
// append fetch op
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"fetch_v2"
);
op
->
SetInput
(
"X"
,
{
fetch_name
});
op
->
SetOutput
(
"Out"
,
{
kFetchVarName
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
i
++
;
}
}
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
const
std
::
vector
<
size_t
>&
second
)
{
const
std
::
vector
<
size_t
>&
second
)
{
std
::
vector
<
size_t
>
out
(
first
.
size
()
+
second
.
size
());
std
::
vector
<
size_t
>
out
(
first
.
size
()
+
second
.
size
());
...
@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
...
@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
return
out
;
return
out
;
}
}
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
a6e99dc7
...
@@ -48,9 +48,10 @@
...
@@ -48,9 +48,10 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpreter
core
{
namespace
interpreter
{
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
static
constexpr
char
kFetchVarName
[]
=
"fetch_vars"
;
class
AsyncWorkQueue
{
class
AsyncWorkQueue
{
public:
public:
...
@@ -96,17 +97,20 @@ class AsyncWorkQueue {
...
@@ -96,17 +97,20 @@ class AsyncWorkQueue {
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
const
platform
::
Place
&
dst_place
);
const
platform
::
Place
&
dst_place
);
void
build_variable_scope
(
const
framework
::
ProgramDesc
&
pdesc
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
);
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
ProgramDesc
&
pdesc
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
);
void
add_fetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
,
framework
::
BlockDesc
*
block
);
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
const
std
::
vector
<
size_t
>&
second
);
const
std
::
vector
<
size_t
>&
second
);
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
a6e99dc7
...
@@ -776,7 +776,7 @@ class Instruction {
...
@@ -776,7 +776,7 @@ class Instruction {
std
::
vector
<
std
::
pair
<
Variable
*
,
Variable
*>>
vec_inplace_in_to_out_
;
std
::
vector
<
std
::
pair
<
Variable
*
,
Variable
*>>
vec_inplace_in_to_out_
;
};
};
namespace
interpreter
core
{
namespace
interpreter
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
...
@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
...
@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
static
bool
IsMemcpyD2H
(
const
Instruction
&
instr
)
{
static
bool
IsMemcpyD2H
(
const
Instruction
&
instr
)
{
return
instr
.
OpBase
()
->
Type
()
==
kMemcpyD2H
;
return
instr
.
OpBase
()
->
Type
()
==
kMemcpyD2H
;
}
}
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/standalone_executor.cc
浏览文件 @
a6e99dc7
...
@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
...
@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
// run startup program
// run startup program
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
startup_prog
,
&
vec_func_list
,
&
global_scope_
);
place_
,
startup_prog
.
Block
(
0
)
,
&
vec_func_list
,
&
global_scope_
);
}
}
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
...
@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
...
@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if
(
iter
==
interpretercores_
.
end
())
{
if
(
iter
==
interpretercores_
.
end
())
{
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
auto
core
=
std
::
make_shared
<
InterpreterCore
>
(
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
place_
,
main_prog_
,
&
global_scope_
,
feed_names
,
fetch_names
);
// new program.
auto
new_prog
=
std
::
make_shared
<
framework
::
ProgramDesc
>
(
main_prog_
);
auto
*
block
=
new_prog
->
MutableBlock
(
0
);
interpreter
::
add_fetch
(
fetch_names
,
block
);
auto
core
=
std
::
make_shared
<
InterpreterCore
>
(
place_
,
block
,
&
global_scope_
,
feed_names
);
programs_
.
emplace
(
oss
.
str
(),
new_prog
);
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
return
core
;
return
core
;
}
else
{
}
else
{
...
...
paddle/fluid/framework/new_executor/standalone_executor.h
浏览文件 @
a6e99dc7
...
@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
...
@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
Scope
*
outer_scope_
;
Scope
*
outer_scope_
;
VariableScope
global_scope_
;
VariableScope
global_scope_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ProgramDesc
>>
programs_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
interpretercores_
;
interpretercores_
;
};
};
...
...
paddle/fluid/framework/new_executor/stream_analyzer.cc
浏览文件 @
a6e99dc7
...
@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
...
@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const
OpFuncNode
&
op_func_node
)
{
const
OpFuncNode
&
op_func_node
)
{
auto
&
op_type
=
op_func_node
.
operator_base_
->
Type
();
auto
&
op_type
=
op_func_node
.
operator_base_
->
Type
();
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
if
(
op_type
==
interpreter
core
::
kMemcpyH2D
)
{
if
(
op_type
==
interpreter
::
kMemcpyH2D
)
{
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
}
else
if
(
op_type
==
interpreter
core
::
kMemcpyD2H
)
{
}
else
if
(
op_type
==
interpreter
::
kMemcpyD2H
)
{
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
}
}
...
@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
...
@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
bool
StreamAnalyzer
::
IsDirectRun
(
Instruction
&
cur_instr
,
bool
StreamAnalyzer
::
IsDirectRun
(
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
)
{
const
Instruction
&
next_instr
)
{
return
(
&
cur_instr
.
DeviceContext
()
==
&
next_instr
.
DeviceContext
()
||
return
(
&
cur_instr
.
DeviceContext
()
==
&
next_instr
.
DeviceContext
()
||
interpreter
core
::
IsMemcpyD2H
(
cur_instr
)
||
interpreter
::
IsMemcpyD2H
(
cur_instr
)
||
interpreter
core
::
IsMemcpyH2D
(
next_instr
));
interpreter
::
IsMemcpyH2D
(
next_instr
));
}
}
platform
::
DeviceType
StreamAnalyzer
::
GetWaiterType
(
const
Instruction
&
instr
)
{
platform
::
DeviceType
StreamAnalyzer
::
GetWaiterType
(
const
Instruction
&
instr
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录