Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a6e99dc7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a6e99dc7
编写于
11月 09, 2021
作者:
A
Aurelius84
提交者:
GitHub
11月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor InterpretorCore and Modify into BlockDesc (#37056)
上级
993ec76a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
82 addition
and
81 deletion
+82
-81
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+15
-39
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+5
-8
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+36
-20
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+8
-4
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+2
-2
paddle/fluid/framework/new_executor/standalone_executor.cc
paddle/fluid/framework/new_executor/standalone_executor.cc
+11
-4
paddle/fluid/framework/new_executor/standalone_executor.h
paddle/fluid/framework/new_executor/standalone_executor.h
+1
-0
paddle/fluid/framework/new_executor/stream_analyzer.cc
paddle/fluid/framework/new_executor/stream_analyzer.cc
+4
-4
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
a6e99dc7
...
@@ -33,13 +33,11 @@ namespace framework {
...
@@ -33,13 +33,11 @@ namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it.
// NOTE(Aurelius84): Need a better strategy to determine it.
static
constexpr
size_t
kHostNumThreads
=
4
;
static
constexpr
size_t
kHostNumThreads
=
4
;
InterpreterCore
::
InterpreterCore
(
const
platform
::
Place
&
place
,
InterpreterCore
::
InterpreterCore
(
const
platform
::
Place
&
place
,
BlockDesc
*
block
,
const
ProgramDesc
&
main_prog
,
VariableScope
*
global_scope
,
VariableScope
*
global_scope
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
)
const
std
::
vector
<
std
::
string
>&
fetch_names
)
:
place_
(
place
),
:
place_
(
place
),
main_program_
(
main_prog
),
block_
(
block
),
global_scope_
(
global_scope
),
global_scope_
(
global_scope
),
stream_analyzer_
(
place
),
stream_analyzer_
(
place
),
async_work_queue_
(
kHostNumThreads
,
&
main_thread_blocker_
)
{
async_work_queue_
(
kHostNumThreads
,
&
main_thread_blocker_
)
{
...
@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
...
@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
exception_notifier_
=
main_thread_blocker_
.
RegisterEvent
(
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
kExceptionCaught
,
[
this
]()
{
return
exception_holder_
.
IsCaught
();
});
// Step1: add feedop and fetchop to main_program
AddFetch
(
fetch_names
);
// prune
// prune
// optmize graph pass
// optmize graph pass
...
@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
...
@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
// convert to run graph
// convert to run graph
}
}
void
InterpreterCore
::
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
)
{
auto
*
fetch_holder
=
main_program_
.
MutableBlock
(
0
)
->
Var
(
"fetch_vars"
);
fetch_holder
->
SetType
(
proto
::
VarType
::
FETCH_LIST
);
fetch_holder
->
SetPersistable
(
true
);
int
i
=
0
;
for
(
auto
&
fetch_name
:
fetch_names
)
{
// append fetch op
auto
*
op
=
main_program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"fetch_v2"
);
op
->
SetInput
(
"X"
,
{
fetch_name
});
op
->
SetOutput
(
"Out"
,
{
"fetch_vars"
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
i
++
;
}
}
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
paddle
::
framework
::
FetchList
InterpreterCore
::
Run
(
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
)
{
auto
FeedInput
=
[
&
]
{
auto
FeedInput
=
[
&
]
{
...
@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
...
@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
};
};
if
(
is_build_
==
false
)
{
if
(
is_build_
==
false
)
{
paddle
::
framework
::
interpreter
core
::
build_variable_scope
(
main_program
_
,
paddle
::
framework
::
interpreter
::
build_variable_scope
(
*
block
_
,
global_scope_
);
global_scope_
);
FeedInput
();
FeedInput
();
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
main_program
_
,
&
vec_func_list_
,
global_scope_
);
place_
,
*
block
_
,
&
vec_func_list_
,
global_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
();
Convert
();
...
@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
...
@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
}
}
// return Fetch Tensors
// return Fetch Tensors
auto
*
fetch_var
=
global_scope_
->
Var
(
"fetch_vars"
);
auto
*
fetch_var
=
global_scope_
->
Var
(
interpreter
::
kFetchVarName
);
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
return
*
(
fetch_var
->
GetMutable
<
framework
::
FetchList
>
());
}
}
...
@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
...
@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
std
::
vector
<
size_t
>
vec_temp
;
std
::
vector
<
size_t
>
vec_temp
;
for
(
auto
&
item
:
vec_instruction_
[
i
].
Outputs
())
{
for
(
auto
&
item
:
vec_instruction_
[
i
].
Outputs
())
{
for
(
auto
id
:
item
.
second
)
{
for
(
auto
id
:
item
.
second
)
{
vec_temp
=
vec_temp
=
interpreter
::
merge_vector
(
vec_temp
,
input_var2op_info_
[
id
]);
interpretercore
::
merge_vector
(
vec_temp
,
input_var2op_info_
[
id
]);
}
}
}
}
...
@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
...
@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
[
&
,
next_id
]
{
RunInstructionAsync
(
next_id
);
});
[
&
,
next_id
]
{
RunInstructionAsync
(
next_id
);
});
}
}
}
}
auto
direct_run_ops
=
interpreter
core
::
merge_vector
(
auto
direct_run_ops
=
interpreter
::
merge_vector
(
next_instr
.
SyncRunIds
(),
next_instr
.
SyncRunIds
(),
next_instr
.
DirectRunIds
());
next_instr
.
DirectRunIds
());
size_t
first_op
=
0
;
size_t
first_op
=
0
;
for
(
auto
next_id
:
direct_run_ops
)
{
for
(
auto
next_id
:
direct_run_ops
)
{
if
(
IsReady
(
next_id
))
{
if
(
IsReady
(
next_id
))
{
...
@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
...
@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
};
};
if
(
is_build_
==
false
)
{
if
(
is_build_
==
false
)
{
paddle
::
framework
::
interpreter
core
::
build_variable_scope
(
main_program
_
,
paddle
::
framework
::
interpreter
::
build_variable_scope
(
*
block
_
,
global_scope_
);
global_scope_
);
FeedInput
();
FeedInput
();
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
main_program
_
,
&
vec_func_list_
,
global_scope_
);
place_
,
*
block
_
,
&
vec_func_list_
,
global_scope_
);
is_build_
=
true
;
is_build_
=
true
;
// convert vec func_list to graph
// convert vec func_list to graph
Convert
();
Convert
();
...
...
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
a6e99dc7
...
@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
...
@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class
InterpreterCore
{
class
InterpreterCore
{
public:
public:
InterpreterCore
(
const
platform
::
Place
&
place
,
const
ProgramDesc
&
main_prog
,
InterpreterCore
(
const
platform
::
Place
&
place
,
BlockDesc
*
block
,
VariableScope
*
global_scope
,
VariableScope
*
global_scope
,
const
std
::
vector
<
std
::
string
>&
feed_names
,
const
std
::
vector
<
std
::
string
>&
feed_names
);
const
std
::
vector
<
std
::
string
>&
fetch_names
);
paddle
::
framework
::
FetchList
Run
(
paddle
::
framework
::
FetchList
Run
(
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
const
std
::
vector
<
framework
::
LoDTensor
>&
feed_tensors
);
...
@@ -72,15 +71,14 @@ class InterpreterCore {
...
@@ -72,15 +71,14 @@ class InterpreterCore {
void
RunInstructionAsync
(
size_t
instr_id
);
void
RunInstructionAsync
(
size_t
instr_id
);
void
RunNextInstructions
(
const
Instruction
&
instr_id
,
void
RunNextInstructions
(
const
Instruction
&
instr_id
,
std
::
queue
<
size_t
>*
reserved_next_ops
);
std
::
queue
<
size_t
>*
reserved_next_ops
);
void
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
BuildSkipShareLoDInfo
();
void
BuildSkipShareLoDInfo
();
bool
is_build_
;
bool
is_build_
;
const
platform
::
Place
&
place_
;
const
platform
::
Place
&
place_
;
ProgramDesc
main_program_
;
BlockDesc
*
block_
;
// not owned
VariableScope
*
global_scope_
;
VariableScope
*
global_scope_
;
// not owned
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list_
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list_
;
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
std
::
vector
<
Instruction
>
vec_instruction_
;
// deconstruct before OpFuncNode
...
@@ -88,7 +86,6 @@ class InterpreterCore {
...
@@ -88,7 +86,6 @@ class InterpreterCore {
InstructionInfo
instruction_info_
;
InstructionInfo
instruction_info_
;
std
::
vector
<
size_t
>
dependecy_count_
;
std
::
vector
<
size_t
>
dependecy_count_
;
std
::
vector
<
std
::
vector
<
size_t
>>
input_var2op_info_
;
std
::
vector
<
std
::
vector
<
size_t
>>
input_var2op_info_
;
std
::
vector
<
VariableMetaInfo
>
ref_coun_info_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
std
::
vector
<
VariableMetaInfo
>
vec_meta_info_
;
std
::
vector
<
std
::
string
>
feed_names_
;
std
::
vector
<
std
::
string
>
feed_names_
;
...
@@ -97,7 +94,7 @@ class InterpreterCore {
...
@@ -97,7 +94,7 @@ class InterpreterCore {
StreamAnalyzer
stream_analyzer_
;
StreamAnalyzer
stream_analyzer_
;
EventManager
event_manager_
;
EventManager
event_manager_
;
EventsWaiter
main_thread_blocker_
;
EventsWaiter
main_thread_blocker_
;
interpreter
core
::
AsyncWorkQueue
async_work_queue_
;
interpreter
::
AsyncWorkQueue
async_work_queue_
;
details
::
ExceptionHolder
exception_holder_
;
details
::
ExceptionHolder
exception_holder_
;
std
::
shared_ptr
<
EventsWaiter
::
EventNotifier
>
exception_notifier_
{
nullptr
};
std
::
shared_ptr
<
EventsWaiter
::
EventNotifier
>
exception_notifier_
{
nullptr
};
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
a6e99dc7
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpreter
core
{
namespace
interpreter
{
using
VariableIdMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
;
using
VariableIdMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
;
AtomicVectorSizeT
&
AsyncWorkQueue
::
PrepareAtomicDeps
(
AtomicVectorSizeT
&
AsyncWorkQueue
::
PrepareAtomicDeps
(
...
@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
...
@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
}
}
}
void
build_variable_scope
(
const
framework
::
ProgramDesc
&
pdesc
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
)
{
auto
&
global_block
=
pdesc
.
Block
(
0
);
for
(
auto
&
var_desc
:
block
.
AllVars
())
{
for
(
auto
&
var_desc
:
global_block
.
AllVars
())
{
auto
var_name
=
var_desc
->
Name
();
auto
var_name
=
var_desc
->
Name
();
if
(
var_name
==
framework
::
kEmptyVarName
)
{
if
(
var_name
==
framework
::
kEmptyVarName
)
{
continue
;
continue
;
...
@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
...
@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
std
::
vector
<
OpFuncNode
>
apply_data_transform
(
std
::
vector
<
OpFuncNode
>
apply_data_transform
(
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
const
OpKernelType
&
expected_kernel_key
,
const
platform
::
Place
&
place
,
VariableValueMap
&
ins_map_temp
,
VariableScope
*
var_scope
,
VariableValueMap
*
ins_map_temp
,
VariableScope
*
var_scope
,
OpFuncNode
&
op_func_node
)
{
OpFuncNode
*
op_func_node
)
{
auto
&
op_base
=
op_func_node
.
operator_base_
;
auto
&
op_base
=
op_func_node
->
operator_base_
;
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
op_base
,
platform
::
errors
::
PreconditionNotMet
(
"op_base is null, please pass a valid "
"op_base is null, please pass a valid "
"op_base in apply_data_transform."
));
"op_base in apply_data_transform."
));
...
@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
no_data_transform_index
;
// record the no need transform variable index.
no_data_transform_index
;
// record the no need transform variable index.
std
::
vector
<
OpFuncNode
>
copy_func_nodes
;
// return all the copy opfuncnode.
std
::
vector
<
OpFuncNode
>
copy_func_nodes
;
// return all the copy opfuncnode.
for
(
auto
&
var_name_item
:
ins_map_temp
)
{
for
(
auto
&
var_name_item
:
*
ins_map_temp
)
{
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
var
=
var_name_item
.
second
[
i
];
auto
var
=
var_name_item
.
second
[
i
];
auto
&
var_name
=
inputs_names
[
var_name_item
.
first
].
at
(
i
);
auto
&
var_name
=
inputs_names
[
var_name_item
.
first
].
at
(
i
);
...
@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
std
::
tie
(
new_var_name
,
copy_op_func_node
)
=
apply_place_transform_for_var
(
apply_place_transform_for_var
(
kernel_type_for_var
,
expected_kernel_key
,
place
,
var_name
,
kernel_type_for_var
,
expected_kernel_key
,
place
,
var_name
,
var_name_item
.
first
,
op_func_node
,
var
,
var_scope
);
var_name_item
.
first
,
*
op_func_node
,
var
,
var_scope
);
op_func_node
.
input_index
[
var_name_item
.
first
][
i
]
=
op_func_node
->
input_index
[
var_name_item
.
first
][
i
]
=
var_scope
->
VarId
(
new_var_name
);
var_scope
->
VarId
(
new_var_name
);
copy_func_nodes
.
push_back
(
copy_op_func_node
);
copy_func_nodes
.
push_back
(
copy_op_func_node
);
var_name_item
.
second
[
i
]
=
var_scope
->
Var
(
new_var_name
);
var_name_item
.
second
[
i
]
=
var_scope
->
Var
(
new_var_name
);
...
@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
...
@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
}
}
}
}
}
}
op_func_node
.
no_data_transform_index
=
std
::
move
(
no_data_transform_index
);
op_func_node
->
no_data_transform_index
=
std
::
move
(
no_data_transform_index
);
return
copy_func_nodes
;
return
copy_func_nodes
;
}
}
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
ProgramDesc
&
pdesc
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
)
{
VariableScope
*
var_scope
)
{
auto
&
global_block
=
pdesc
.
Block
(
0
);
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
auto
&
all_op_kernels
=
OperatorWithKernel
::
AllOpKernels
();
// Step 1: create all ops for
global
block.
// Step 1: create all ops for
current
block.
auto
ops
=
create_all_ops
(
global_
block
);
auto
ops
=
create_all_ops
(
block
);
auto
unused_var_map
=
get_unused_vars
(
global_
block
,
ops
);
auto
unused_var_map
=
get_unused_vars
(
block
,
ops
);
size_t
ops_index
=
0
;
size_t
ops_index
=
0
;
for
(
auto
&
op
:
global_
block
.
AllOps
())
{
for
(
auto
&
op
:
block
.
AllOps
())
{
VLOG
(
6
)
<<
"Build OpFuncNode from : "
<<
op
->
Type
();
VLOG
(
6
)
<<
"Build OpFuncNode from : "
<<
op
->
Type
();
auto
op_base
=
ops
[
ops_index
++
];
auto
op_base
=
ops
[
ops_index
++
];
...
@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
// apply_data_transform.
op_func_node
.
operator_base_
=
op_base
;
op_func_node
.
operator_base_
=
op_base
;
copy_op_to_insert
=
apply_data_transform
(
copy_op_to_insert
=
apply_data_transform
(
expected_kernel_key
,
place
,
ins_map_temp
,
var_scope
,
op_func_node
);
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
&
op_func_node
);
for
(
auto
&
item
:
copy_op_to_insert
)
{
for
(
auto
&
item
:
copy_op_to_insert
)
{
vec_func_list
->
push_back
(
item
);
vec_func_list
->
push_back
(
item
);
}
}
...
@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
...
@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
}
}
}
}
void
add_fetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
,
framework
::
BlockDesc
*
block
)
{
auto
*
fetch_holder
=
block
->
Var
(
kFetchVarName
);
fetch_holder
->
SetType
(
proto
::
VarType
::
FETCH_LIST
);
fetch_holder
->
SetPersistable
(
true
);
int
i
=
0
;
for
(
auto
&
fetch_name
:
fetch_names
)
{
// append fetch op
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"fetch_v2"
);
op
->
SetInput
(
"X"
,
{
fetch_name
});
op
->
SetOutput
(
"Out"
,
{
kFetchVarName
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
i
++
;
}
}
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
const
std
::
vector
<
size_t
>&
second
)
{
const
std
::
vector
<
size_t
>&
second
)
{
std
::
vector
<
size_t
>
out
(
first
.
size
()
+
second
.
size
());
std
::
vector
<
size_t
>
out
(
first
.
size
()
+
second
.
size
());
...
@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
...
@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
return
out
;
return
out
;
}
}
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
a6e99dc7
...
@@ -48,9 +48,10 @@
...
@@ -48,9 +48,10 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpreter
core
{
namespace
interpreter
{
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
using
AtomicVectorSizeT
=
std
::
vector
<
std
::
unique_ptr
<
std
::
atomic
<
size_t
>>>
;
static
constexpr
char
kFetchVarName
[]
=
"fetch_vars"
;
class
AsyncWorkQueue
{
class
AsyncWorkQueue
{
public:
public:
...
@@ -96,17 +97,20 @@ class AsyncWorkQueue {
...
@@ -96,17 +97,20 @@ class AsyncWorkQueue {
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
const
platform
::
Place
&
dst_place
);
const
platform
::
Place
&
dst_place
);
void
build_variable_scope
(
const
framework
::
ProgramDesc
&
pdesc
,
void
build_variable_scope
(
const
framework
::
BlockDesc
&
block
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
);
void
build_op_func_list
(
const
platform
::
Place
&
place
,
void
build_op_func_list
(
const
platform
::
Place
&
place
,
const
framework
::
ProgramDesc
&
pdesc
,
const
framework
::
BlockDesc
&
block
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
std
::
vector
<
OpFuncNode
>*
vec_func_list
,
VariableScope
*
var_scope
);
VariableScope
*
var_scope
);
void
add_fetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
,
framework
::
BlockDesc
*
block
);
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
std
::
vector
<
size_t
>
merge_vector
(
const
std
::
vector
<
size_t
>&
first
,
const
std
::
vector
<
size_t
>&
second
);
const
std
::
vector
<
size_t
>&
second
);
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
a6e99dc7
...
@@ -776,7 +776,7 @@ class Instruction {
...
@@ -776,7 +776,7 @@ class Instruction {
std
::
vector
<
std
::
pair
<
Variable
*
,
Variable
*>>
vec_inplace_in_to_out_
;
std
::
vector
<
std
::
pair
<
Variable
*
,
Variable
*>>
vec_inplace_in_to_out_
;
};
};
namespace
interpreter
core
{
namespace
interpreter
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
...
@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
...
@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
static
bool
IsMemcpyD2H
(
const
Instruction
&
instr
)
{
static
bool
IsMemcpyD2H
(
const
Instruction
&
instr
)
{
return
instr
.
OpBase
()
->
Type
()
==
kMemcpyD2H
;
return
instr
.
OpBase
()
->
Type
()
==
kMemcpyD2H
;
}
}
}
// namespace interpreter
core
}
// namespace interpreter
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/standalone_executor.cc
浏览文件 @
a6e99dc7
...
@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
...
@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
// run startup program
// run startup program
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
paddle
::
framework
::
interpreter
core
::
build_op_func_list
(
paddle
::
framework
::
interpreter
::
build_op_func_list
(
place_
,
startup_prog
,
&
vec_func_list
,
&
global_scope_
);
place_
,
startup_prog
.
Block
(
0
)
,
&
vec_func_list
,
&
global_scope_
);
}
}
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
paddle
::
framework
::
FetchList
StandaloneExecutor
::
Run
(
...
@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
...
@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if
(
iter
==
interpretercores_
.
end
())
{
if
(
iter
==
interpretercores_
.
end
())
{
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
VLOG
(
3
)
<<
"create interpreter_core for "
<<
oss
.
str
();
auto
core
=
std
::
make_shared
<
InterpreterCore
>
(
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
place_
,
main_prog_
,
&
global_scope_
,
feed_names
,
fetch_names
);
// new program.
auto
new_prog
=
std
::
make_shared
<
framework
::
ProgramDesc
>
(
main_prog_
);
auto
*
block
=
new_prog
->
MutableBlock
(
0
);
interpreter
::
add_fetch
(
fetch_names
,
block
);
auto
core
=
std
::
make_shared
<
InterpreterCore
>
(
place_
,
block
,
&
global_scope_
,
feed_names
);
programs_
.
emplace
(
oss
.
str
(),
new_prog
);
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
interpretercores_
.
emplace
(
oss
.
str
(),
core
);
return
core
;
return
core
;
}
else
{
}
else
{
...
...
paddle/fluid/framework/new_executor/standalone_executor.h
浏览文件 @
a6e99dc7
...
@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
...
@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
Scope
*
outer_scope_
;
Scope
*
outer_scope_
;
VariableScope
global_scope_
;
VariableScope
global_scope_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ProgramDesc
>>
programs_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
InterpreterCore
>>
interpretercores_
;
interpretercores_
;
};
};
...
...
paddle/fluid/framework/new_executor/stream_analyzer.cc
浏览文件 @
a6e99dc7
...
@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
...
@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const
OpFuncNode
&
op_func_node
)
{
const
OpFuncNode
&
op_func_node
)
{
auto
&
op_type
=
op_func_node
.
operator_base_
->
Type
();
auto
&
op_type
=
op_func_node
.
operator_base_
->
Type
();
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
if
(
op_type
==
interpreter
core
::
kMemcpyH2D
)
{
if
(
op_type
==
interpreter
::
kMemcpyH2D
)
{
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
}
else
if
(
op_type
==
interpreter
core
::
kMemcpyD2H
)
{
}
else
if
(
op_type
==
interpreter
::
kMemcpyD2H
)
{
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
}
}
...
@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
...
@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
bool
StreamAnalyzer
::
IsDirectRun
(
Instruction
&
cur_instr
,
bool
StreamAnalyzer
::
IsDirectRun
(
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
)
{
const
Instruction
&
next_instr
)
{
return
(
&
cur_instr
.
DeviceContext
()
==
&
next_instr
.
DeviceContext
()
||
return
(
&
cur_instr
.
DeviceContext
()
==
&
next_instr
.
DeviceContext
()
||
interpreter
core
::
IsMemcpyD2H
(
cur_instr
)
||
interpreter
::
IsMemcpyD2H
(
cur_instr
)
||
interpreter
core
::
IsMemcpyH2D
(
next_instr
));
interpreter
::
IsMemcpyH2D
(
next_instr
));
}
}
platform
::
DeviceType
StreamAnalyzer
::
GetWaiterType
(
const
Instruction
&
instr
)
{
platform
::
DeviceType
StreamAnalyzer
::
GetWaiterType
(
const
Instruction
&
instr
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录