Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a4951843
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a4951843
编写于
2月 11, 2020
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
inference feed partial data, test=develop
上级
6dadb5de
变更
22
显示空白变更内容
内联
并排
Showing
22 changed file
with
904 addition
and
106 deletion
+904
-106
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+3
-2
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+5
-8
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+2
-0
paddle/fluid/framework/details/eager_deletion_op_handle.cc
paddle/fluid/framework/details/eager_deletion_op_handle.cc
+3
-1
paddle/fluid/framework/details/eager_deletion_op_handle.h
paddle/fluid/framework/details/eager_deletion_op_handle.h
+4
-1
paddle/fluid/framework/details/multi_devices_helper.cc
paddle/fluid/framework/details/multi_devices_helper.cc
+189
-1
paddle/fluid/framework/details/multi_devices_helper.h
paddle/fluid/framework/details/multi_devices_helper.h
+6
-0
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
...le/fluid/framework/details/parallel_ssa_graph_executor.cc
+104
-21
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
+24
-3
paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc
.../framework/ir/memory_optimize_pass/eager_deletion_pass.cc
+1
-1
paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc
...optimize_pass/test_reference_count_pass_last_lived_ops.cc
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
...luid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_pass.cc
...r/multi_devices_graph_pass/set_reader_device_info_pass.cc
+101
-0
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+22
-18
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+137
-17
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+2
-0
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+4
-0
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+88
-12
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+3
-14
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+13
-5
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_inference_feed_partial_data.py
...sts/test_parallel_executor_inference_feed_partial_data.py
+190
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
a4951843
...
@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
...
@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
cc_library
(
share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor
)
cc_library
(
share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
multi_devices_helper
INTERFACE
SRCS multi_devices_helper.cc DEPS graph graph_helper
)
cc_library
(
multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
...
@@ -65,6 +65,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
...
@@ -65,6 +65,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
multi_devices_helper
sequential_execution_pass
sequential_execution_pass
modify_op_lock_and_record_event_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
all_reduce_deps_pass
...
@@ -72,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
...
@@ -72,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
eager_deletion_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_
count
_pass
)
set_reader_device_
info
_pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
a4951843
...
@@ -66,7 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -66,7 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_fused_graph"
);
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_fused_graph"
);
AppendMultiDevPass
();
AppendMultiDevPass
();
AppendSetReaderDevice
Count
Pass
();
AppendSetReaderDevice
Index
Pass
();
AppendMultiGraphOptPasses
();
AppendMultiGraphOptPasses
();
AppendPassToSetMkldnnAttr
(
"mkldnn_placement_pass"
);
AppendPassToSetMkldnnAttr
(
"mkldnn_placement_pass"
);
...
@@ -225,8 +225,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -225,8 +225,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&
strategy_
);
&
strategy_
);
}
}
void
AppendSetReaderDevice
Count
Pass
()
{
void
AppendSetReaderDevice
Index
Pass
()
{
AppendPass
(
"set_reader_device_
count
_pass"
);
AppendPass
(
"set_reader_device_
index
_pass"
);
}
}
void
AppendPrintGraphPass
(
const
std
::
string
&
pass_name
,
void
AppendPrintGraphPass
(
const
std
::
string
&
pass_name
,
...
@@ -399,12 +399,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -399,12 +399,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped."
;
"GPU, skipped."
;
continue
;
continue
;
}
}
}
else
if
(
pass
->
Type
()
==
"set_reader_device_
count
_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"set_reader_device_
index
_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
}
}
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
graph
=
pass
->
Apply
(
graph
);
graph
=
pass
->
Apply
(
graph
);
...
@@ -441,7 +438,7 @@ USE_PASS(fuse_sgd_op_pass);
...
@@ -441,7 +438,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS
(
fuse_momentum_op_pass
);
USE_PASS
(
fuse_momentum_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
set_reader_device_
count
_pass
);
USE_PASS
(
set_reader_device_
index
_pass
);
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
USE_PASS
(
mkldnn_placement_pass
);
USE_PASS
(
mkldnn_placement_pass
);
#endif
#endif
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
a4951843
...
@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
...
@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
OperatorBase
*
GetOp
()
{
return
op_
.
get
();
}
OperatorBase
*
GetOp
()
{
return
op_
.
get
();
}
const
OperatorBase
*
GetOp
()
const
{
return
op_
.
get
();
}
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
const
Scope
*
GetScope
()
const
{
return
scope_
;
}
const
Scope
*
GetScope
()
const
{
return
scope_
;
}
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.cc
浏览文件 @
a4951843
...
@@ -31,10 +31,12 @@ namespace framework {
...
@@ -31,10 +31,12 @@ namespace framework {
namespace
details
{
namespace
details
{
EagerDeletionOpHandle
::
EagerDeletionOpHandle
(
EagerDeletionOpHandle
::
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
const
platform
::
Place
&
place
,
ir
::
Node
*
node
,
Scope
*
scope
,
size_t
scope_idx
,
const
platform
::
Place
&
place
,
const
std
::
unordered_set
<
ir
::
MemOptVarInfo
*>
&
vars
,
GarbageCollector
*
gc
)
const
std
::
unordered_set
<
ir
::
MemOptVarInfo
*>
&
vars
,
GarbageCollector
*
gc
)
:
OpHandleBase
(
node
),
:
OpHandleBase
(
node
),
scope_
(
scope
),
scope_
(
scope
),
scope_idx_
(
scope_idx
),
place_
(
place
),
place_
(
place
),
var_infos_
(
vars
.
begin
(),
vars
.
end
()),
var_infos_
(
vars
.
begin
(),
vars
.
end
()),
gc_
(
gc
)
{
gc_
(
gc
)
{
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.h
浏览文件 @
a4951843
...
@@ -34,7 +34,7 @@ namespace details {
...
@@ -34,7 +34,7 @@ namespace details {
class
EagerDeletionOpHandle
:
public
OpHandleBase
{
class
EagerDeletionOpHandle
:
public
OpHandleBase
{
public:
public:
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
EagerDeletionOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
size_t
scope_idx
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
unordered_set
<
ir
::
MemOptVarInfo
*>
&
vars
,
const
std
::
unordered_set
<
ir
::
MemOptVarInfo
*>
&
vars
,
GarbageCollector
*
gc
);
GarbageCollector
*
gc
);
...
@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
...
@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
*/
*/
Priority
GetPriority
()
const
override
{
return
kHighest
;
}
Priority
GetPriority
()
const
override
{
return
kHighest
;
}
size_t
GetScopeIdx
()
const
{
return
scope_idx_
;
}
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
...
@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
...
@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
void
CallOnce
();
void
CallOnce
();
Scope
*
scope_
;
Scope
*
scope_
;
size_t
scope_idx_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
std
::
vector
<
ir
::
MemOptVarInfo
*>
var_infos_
;
// not own
std
::
vector
<
ir
::
MemOptVarInfo
*>
var_infos_
;
// not own
GarbageCollector
*
gc_
;
// not own
GarbageCollector
*
gc_
;
// not own
...
...
paddle/fluid/framework/details/multi_devices_helper.cc
浏览文件 @
a4951843
...
@@ -12,9 +12,197 @@
...
@@ -12,9 +12,197 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{}
// namespace details
namespace
details
{
static
constexpr
size_t
kUndefinedDevIdx
=
-
1UL
;
static
std
::
unordered_set
<
std
::
string
>
kMultiDeviceOps
{
"sync_batch_norm"
,
"sync_batch_norm_grad"
,
"allreduce"
,
"c_allreduce_sum"
,
"c_allreduce_prod"
,
"c_allreduce_min"
,
"c_allreduce_max"
,
"c_allgather"
,
"c_reducescatter"
,
"c_broadcast"
,
"c_comm_init"
,
"c_comm_init_all"
,
"c_gen_nccl_id"
,
"c_sync_comm_stream"
,
"send"
,
"recv"
,
"send_barrier"
,
"fetch_barrier"
,
};
static
size_t
GetScopeIdxFromOp
(
const
details
::
OpHandleBase
&
op
)
{
if
(
auto
*
compute_op
=
dynamic_cast
<
const
details
::
ComputationOpHandle
*>
(
&
op
))
{
return
kMultiDeviceOps
.
count
(
compute_op
->
GetOp
()
->
Type
())
==
0
?
compute_op
->
GetScopeIdx
()
:
kUndefinedDevIdx
;
}
else
if
(
auto
*
gc_op
=
dynamic_cast
<
const
details
::
EagerDeletionOpHandle
*>
(
&
op
))
{
return
gc_op
->
GetScopeIdx
();
}
else
if
(
auto
*
share_op
=
dynamic_cast
<
const
details
::
ShareTensorBufferOpHandle
*>
(
&
op
))
{
return
share_op
->
GetScopeIdx
();
}
else
{
return
kUndefinedDevIdx
;
}
}
static
bool
ContainMultiDeviceOp
(
const
ProgramDesc
&
program
,
size_t
begin_block_idx
)
{
for
(
size_t
block_idx
=
begin_block_idx
;
block_idx
<
program
.
Size
();
++
block_idx
)
{
for
(
auto
*
op_desc
:
program
.
Block
(
block_idx
).
AllOps
())
{
if
(
kMultiDeviceOps
.
count
(
op_desc
->
Type
())
>
0
)
{
return
true
;
}
}
}
return
false
;
}
static
size_t
GetUniqueDeviceIdOfOp
(
const
details
::
OpHandleBase
&
op
)
{
size_t
dev_idx
=
GetScopeIdxFromOp
(
op
);
if
(
dev_idx
==
kUndefinedDevIdx
)
{
return
kUndefinedDevIdx
;
}
const
auto
&
ins
=
op
.
Inputs
();
const
auto
&
outs
=
op
.
Outputs
();
auto
in_outs
=
ins
;
in_outs
.
insert
(
in_outs
.
end
(),
outs
.
begin
(),
outs
.
end
());
for
(
auto
*
var
:
in_outs
)
{
auto
*
var_handle
=
dynamic_cast
<
details
::
VarHandle
*>
(
var
);
if
(
var_handle
==
nullptr
)
{
continue
;
}
if
(
dev_idx
!=
var_handle
->
scope_idx
())
{
return
kUndefinedDevIdx
;
}
}
return
dev_idx
;
}
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
TrySeparateToMultipleSingleDeviceGraphs
(
ir
::
Graph
*
graph
)
{
if
(
ContainMultiDeviceOp
(
graph
->
OriginProgram
(),
1
))
{
return
{};
}
size_t
place_num
=
0
;
auto
op_handles
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
if
(
op_handles
.
empty
())
{
return
{};
}
std
::
unordered_map
<
details
::
OpHandleBase
*
,
size_t
>
op_to_dev_idx
;
for
(
auto
&
op
:
op_handles
)
{
auto
dev_idx
=
GetUniqueDeviceIdOfOp
(
*
op
);
if
(
dev_idx
==
kUndefinedDevIdx
)
{
VLOG
(
10
)
<<
"Op "
<<
op
->
Name
()
<<
" is not determined"
;
return
{};
}
place_num
=
std
::
max
(
place_num
,
dev_idx
+
1
);
op_to_dev_idx
[
op
]
=
dev_idx
;
}
for
(
auto
&
op
:
op_handles
)
{
auto
dev_idx
=
op_to_dev_idx
.
at
(
op
);
for
(
auto
&
in_var
:
op
->
Inputs
())
{
if
(
in_var
->
GeneratedOp
())
{
auto
iter
=
op_to_dev_idx
.
find
(
in_var
->
GeneratedOp
());
if
(
iter
==
op_to_dev_idx
.
end
()
||
iter
->
second
!=
dev_idx
)
{
return
{};
}
}
}
for
(
auto
&
out_var
:
op
->
Outputs
())
{
for
(
auto
&
pending_op
:
out_var
->
PendingOps
())
{
auto
iter
=
op_to_dev_idx
.
find
(
pending_op
);
if
(
iter
==
op_to_dev_idx
.
end
()
||
iter
->
second
!=
dev_idx
)
{
return
{};
}
}
}
}
PADDLE_ENFORCE_GE
(
place_num
,
1
,
platform
::
errors
::
NotFound
(
"No place found, this may be a bug"
));
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
(
place_num
);
for
(
auto
&
g
:
graphs
)
{
g
.
reset
(
new
ir
::
Graph
(
ProgramDesc
()));
g
->
Set
(
kGraphVars
,
new
GraphVars
(
1UL
));
g
->
Set
(
kGraphDepVars
,
new
GraphDepVars
());
}
for
(
auto
&
op
:
op_handles
)
{
auto
dev_idx
=
op_to_dev_idx
.
at
(
op
);
auto
*
ret_graph
=
graphs
[
dev_idx
].
get
();
auto
&
ret_vars
=
ret_graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
0
];
auto
&
ret_dummy_vars
=
ret_graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
);
auto
&
origin_vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
dev_idx
];
ret_graph
->
AddNode
(
graph
->
RemoveNode
(
op
->
Node
()).
release
());
auto
handler
=
[
&
](
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
for
(
auto
*
var
:
vars
)
{
if
(
graph
->
Nodes
().
count
(
var
->
Node
())
>
0
)
{
ret_graph
->
AddNode
(
graph
->
RemoveNode
(
var
->
Node
()).
release
());
auto
*
dummy_var
=
dynamic_cast
<
DummyVarHandle
*>
(
var
);
if
(
dummy_var
==
nullptr
)
{
ret_vars
.
emplace
(
var
->
Name
(),
origin_vars
.
at
(
var
->
Name
()));
}
else
{
ret_dummy_vars
.
emplace
(
dummy_var
);
}
}
}
};
handler
(
op
->
Inputs
());
handler
(
op
->
Outputs
());
}
graph
->
Erase
(
kGraphVars
);
graph
->
Erase
(
kGraphDepVars
);
return
graphs
;
}
bool
HasDropLastReadOp
(
const
ir
::
Graph
&
graph
)
{
auto
ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
graph
);
for
(
auto
*
op
:
ops
)
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
&&
compute_op
->
GetOp
()
->
Type
()
==
"read"
&&
compute_op
->
GetOp
()
->
Attr
<
bool
>
(
"drop_last"
))
{
VLOG
(
10
)
<<
"The graph has drop_last=True read op"
;
return
true
;
}
}
VLOG
(
10
)
<<
"The graph does not have drop_last=True read op"
;
return
false
;
}
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_helper.h
浏览文件 @
a4951843
...
@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
...
@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
constexpr
char
kNRanks
[]
=
"nranks"
;
constexpr
char
kNRanks
[]
=
"nranks"
;
constexpr
char
kPlaces
[]
=
"places"
;
constexpr
char
kPlaces
[]
=
"places"
;
constexpr
char
kGlobalScope
[]
=
"global_scope"
;
constexpr
char
kLocalScopes
[]
=
"local_scopes"
;
constexpr
char
kLocalScopes
[]
=
"local_scopes"
;
constexpr
char
kNCCLCtxs
[]
=
"nccl_ctxs"
;
constexpr
char
kNCCLCtxs
[]
=
"nccl_ctxs"
;
constexpr
char
kUseHierarchicalAllReduce
[]
=
"use_hierarchical_allreduce"
;
constexpr
char
kUseHierarchicalAllReduce
[]
=
"use_hierarchical_allreduce"
;
...
@@ -100,6 +101,11 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
...
@@ -100,6 +101,11 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
return
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
iter
->
second
);
return
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
iter
->
second
);
}
}
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
TrySeparateToMultipleSingleDeviceGraphs
(
ir
::
Graph
*
graph
);
bool
HasDropLastReadOp
(
const
ir
::
Graph
&
graph
);
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
浏览文件 @
a4951843
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include <algorithm>
#include <memory>
#include <memory>
#include <utility>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
...
@@ -21,11 +22,11 @@ namespace paddle {
...
@@ -21,11 +22,11 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
st
d
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
st
atic
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
SeparateMultiDevicesGraph
(
ParallelSSAGraphExecutor
::
SeparateMultiDevicesGraph
(
ir
::
Graph
*
graph
)
{
ir
::
Graph
*
graph
,
size_t
place_num
)
{
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
;
graphs
.
reserve
(
place
s_
.
size
()
);
graphs
.
reserve
(
place
_num
);
for
(
size_t
i
=
0
;
i
<
place
s_
.
size
()
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
place
_num
;
++
i
)
{
ProgramDesc
empty
;
ProgramDesc
empty
;
graphs
.
emplace_back
(
std
::
unique_ptr
<
ir
::
Graph
>
(
new
ir
::
Graph
(
empty
)));
graphs
.
emplace_back
(
std
::
unique_ptr
<
ir
::
Graph
>
(
new
ir
::
Graph
(
empty
)));
auto
&
g
=
graphs
.
back
();
auto
&
g
=
graphs
.
back
();
...
@@ -64,7 +65,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
...
@@ -64,7 +65,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
}
}
}
}
for
(
size_t
dev_id
=
0
;
dev_id
<
place
s_
.
size
()
;
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
place
_num
;
++
dev_id
)
{
auto
&
dev_vars
=
graphs
[
dev_id
]
->
Get
<
GraphVars
>
(
kGraphVars
)[
0
];
auto
&
dev_vars
=
graphs
[
dev_id
]
->
Get
<
GraphVars
>
(
kGraphVars
)[
0
];
auto
&
origin_vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
dev_id
];
auto
&
origin_vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
dev_id
];
for
(
auto
&
name_pair
:
origin_vars
)
{
for
(
auto
&
name_pair
:
origin_vars
)
{
...
@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
...
@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
ir
::
Graph
*
graph
)
const
std
::
vector
<
platform
::
Place
>
&
places
,
ir
::
Graph
*
graph
)
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
:
ParallelSSAGraphExecutor
(
strategy
,
local_scopes
,
local_exec_scopes
,
places
,
SeparateMultiDevicesGraph
(
graph
,
places
.
size
()))
{}
ParallelSSAGraphExecutor
::
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
)
:
strategy_
(
std
::
move
(
strategy
)),
:
strategy_
(
std
::
move
(
strategy
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
pool_
(
places
.
size
()
>=
2
?
new
::
ThreadPool
(
places
.
size
())
:
nullptr
),
pool_
(
places
.
size
()
>=
2
?
new
::
ThreadPool
(
places
.
size
())
:
nullptr
),
places_
(
std
::
move
(
places
)),
places_
(
places
),
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
graphs_
(
std
::
move
(
graphs
)),
// attrs.
feed_status_
(
places
.
size
(),
FeedStatus
::
kNone
)
{
graphs_
(
SeparateMultiDevicesGraph
(
graph
))
{
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
graphs_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Graph number does not match place number"
));
PADDLE_ENFORCE_GT
(
places_
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"place number must be larger than 0"
));
auto
seq_allreduce_pass
=
auto
seq_allreduce_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"all_reduce_deps_pass"
);
ir
::
PassRegistry
::
Instance
().
Get
(
"all_reduce_deps_pass"
);
seq_allreduce_pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
false
));
seq_allreduce_pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
false
));
...
@@ -123,22 +143,43 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
...
@@ -123,22 +143,43 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
return
result
;
return
result
;
}
}
enum
ExceptionStatus
{
kSuccess
=
0
,
kEOF
,
kOther
};
FeedFetchList
ParallelSSAGraphExecutor
::
Run
(
FeedFetchList
ParallelSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
size_t
feed_num
=
std
::
count
(
feed_status_
.
begin
(),
feed_status_
.
end
(),
FeedStatus
::
kHasFeed
);
bool
has_feed
=
(
feed_num
>
0
);
VLOG
(
10
)
<<
"Feed num "
<<
feed_num
;
size_t
place_num
=
places_
.
size
();
std
::
vector
<
std
::
future
<
FeedFetchList
>>
run_futures
;
std
::
vector
<
std
::
future
<
FeedFetchList
>>
run_futures
;
std
::
vector
<
ExceptionStatus
>
exception_status
(
place_num
,
ExceptionStatus
::
kSuccess
);
std
::
vector
<
FeedFetchList
>
fetch_data
;
std
::
vector
<
FeedFetchList
>
fetch_data
;
FeedFetchList
ret
;
FeedFetchList
ret
;
fetch_data
.
reserve
(
place
s_
.
size
()
);
fetch_data
.
reserve
(
place
_num
);
ret
.
reserve
(
fetch_tensors
.
size
()
);
ret
.
reserve
(
place_num
);
exception_holder_
.
Clear
();
exception_holder_
.
Clear
();
for
(
size_t
i
=
0
;
i
<
place
s_
.
size
()
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
place
_num
;
++
i
)
{
auto
call
=
[
this
,
i
,
&
fetch_tensors
]()
->
FeedFetchList
{
auto
call
=
[
&
,
i
]()
->
FeedFetchList
{
try
{
try
{
if
(
!
support_partial_feed_
||
!
has_feed
||
feed_status_
[
i
]
==
FeedStatus
::
kHasFeed
)
{
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
}
else
{
return
FeedFetchList
();
}
}
catch
(
platform
::
EOFException
&
)
{
exception_status
[
i
]
=
ExceptionStatus
::
kEOF
;
exception_holder_
.
Catch
(
std
::
current_exception
());
}
catch
(...)
{
}
catch
(...)
{
exception_status
[
i
]
=
ExceptionStatus
::
kOther
;
exception_holder_
.
Catch
(
std
::
current_exception
());
exception_holder_
.
Catch
(
std
::
current_exception
());
}
}
return
FeedFetchList
();
return
FeedFetchList
();
...
@@ -153,21 +194,63 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
...
@@ -153,21 +194,63 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
if
(
pool_
)
{
if
(
pool_
)
{
for
(
auto
&
f
:
run_futures
)
{
for
(
auto
&
f
:
run_futures
)
{
if
(
exception_holder_
.
IsCaught
())
{
f
.
wait
();
}
else
{
fetch_data
.
emplace_back
(
f
.
get
());
fetch_data
.
emplace_back
(
f
.
get
());
}
}
}
}
bool
has_exception
=
exception_holder_
.
IsCaught
();
if
(
!
support_partial_feed_
&&
has_exception
)
{
VLOG
(
10
)
<<
"Exception rethrow because partial feed is not supported"
;
exception_holder_
.
ReThrow
();
}
}
if
(
exception_holder_
.
IsCaught
())
{
std
::
vector
<
bool
>
is_valid
(
place_num
,
true
);
if
(
support_partial_feed_
)
{
if
(
has_feed
)
{
for
(
size_t
i
=
0
;
i
<
place_num
;
++
i
)
{
if
(
feed_status_
[
i
]
==
FeedStatus
::
kNone
)
{
is_valid
[
i
]
=
false
;
}
else
if
(
exception_status
[
i
]
!=
ExceptionStatus
::
kSuccess
)
{
PADDLE_ENFORCE_EQ
(
has_exception
,
true
,
platform
::
errors
::
InvalidArgument
(
"Thread pool raises exception but not caught"
));
VLOG
(
10
)
<<
"Exception rethrow because non-EOF exception raises when "
"feed is given"
;
exception_holder_
.
ReThrow
();
}
}
}
else
{
for
(
size_t
i
=
0
;
i
<
place_num
;
++
i
)
{
if
(
exception_status
[
i
]
==
ExceptionStatus
::
kOther
)
{
PADDLE_ENFORCE_EQ
(
has_exception
,
true
,
platform
::
errors
::
InvalidArgument
(
"Thread pool raises exception but not caught"
));
VLOG
(
10
)
<<
"Exception rethrow because non-EOF exception raises when "
"feed is not given"
;
exception_holder_
.
ReThrow
();
}
else
if
(
exception_status
[
i
]
!=
ExceptionStatus
::
kSuccess
)
{
is_valid
[
i
]
=
false
;
}
}
}
}
if
(
std
::
count
(
is_valid
.
begin
(),
is_valid
.
end
(),
true
)
==
0
)
{
PADDLE_ENFORCE_EQ
(
has_exception
,
true
,
platform
::
errors
::
InvalidArgument
(
"Thread pool raises exception but not caught"
));
VLOG
(
10
)
<<
"Raise exception because there is no success worker"
;
exception_holder_
.
ReThrow
();
exception_holder_
.
ReThrow
();
}
}
for
(
size_t
fetch_idx
=
0
;
fetch_idx
<
fetch_tensors
.
size
();
++
fetch_idx
)
{
for
(
size_t
fetch_idx
=
0
;
fetch_idx
<
fetch_tensors
.
size
();
++
fetch_idx
)
{
std
::
vector
<
const
LoDTensor
*>
lodtensor_ptrs
;
std
::
vector
<
const
LoDTensor
*>
lodtensor_ptrs
;
lodtensor_ptrs
.
reserve
(
local_scopes_
.
size
());
lodtensor_ptrs
.
reserve
(
place_num
);
for
(
size_t
scope_idx
=
0
;
scope_idx
<
local_scopes_
.
size
();
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
place_num
;
++
scope_idx
)
{
if
(
!
is_valid
[
scope_idx
])
{
continue
;
}
lodtensor_ptrs
.
push_back
(
&
fetch_data
.
at
(
scope_idx
).
at
(
fetch_idx
));
lodtensor_ptrs
.
push_back
(
&
fetch_data
.
at
(
scope_idx
).
at
(
fetch_idx
));
}
}
ret
.
emplace_back
();
ret
.
emplace_back
();
...
...
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
浏览文件 @
a4951843
...
@@ -27,12 +27,25 @@ namespace framework {
...
@@ -27,12 +27,25 @@ namespace framework {
namespace
details
{
namespace
details
{
class
ParallelSSAGraphExecutor
:
public
SSAGraphExecutor
{
class
ParallelSSAGraphExecutor
:
public
SSAGraphExecutor
{
public:
enum
FeedStatus
{
kNone
=
0
,
// No feed
kHasFeed
=
1
// Has feed
};
public:
public:
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
ir
::
Graph
*
graph
);
ir
::
Graph
*
graph
);
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
);
~
ParallelSSAGraphExecutor
()
final
=
default
;
~
ParallelSSAGraphExecutor
()
final
=
default
;
const
ir
::
Graph
&
Graph
()
const
override
{
return
*
graphs_
[
0
];
}
const
ir
::
Graph
&
Graph
()
const
override
{
return
*
graphs_
[
0
];
}
...
@@ -41,10 +54,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -41,10 +54,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
private:
void
SetHasFeed
(
size_t
dev_idx
,
bool
has_feed
)
{
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
SeparateMultiDevicesGraph
(
feed_status_
[
dev_idx
]
=
has_feed
?
FeedStatus
::
kHasFeed
:
FeedStatus
::
kNone
;
ir
::
Graph
*
graph
);
}
void
EnablePartialFeedSupport
()
{
support_partial_feed_
=
true
;
}
bool
SupportPartialFeed
()
const
{
return
support_partial_feed_
;
}
private:
ExecutionStrategy
strategy_
;
ExecutionStrategy
strategy_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
...
@@ -54,6 +72,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -54,6 +72,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std
::
vector
<
std
::
unique_ptr
<
details
::
FastThreadedSSAGraphExecutor
>>
std
::
vector
<
std
::
unique_ptr
<
details
::
FastThreadedSSAGraphExecutor
>>
executors_
;
executors_
;
ExceptionHolder
exception_holder_
;
ExceptionHolder
exception_holder_
;
bool
support_partial_feed_
{
false
};
std
::
vector
<
FeedStatus
>
feed_status_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc
浏览文件 @
a4951843
...
@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
...
@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
}
}
auto
*
eager_deletion_op
=
new
details
::
EagerDeletionOpHandle
(
auto
*
eager_deletion_op
=
new
details
::
EagerDeletionOpHandle
(
eager_deletion_node
,
op
->
GetScope
(),
op
->
GetPlace
(),
eager_deletion_node
,
op
->
GetScope
(),
op
->
Get
ScopeIdx
(),
op
->
Get
Place
(),
std
::
move
(
var_info
),
gcs
.
at
(
places
[
op
->
GetScopeIdx
()]).
get
());
std
::
move
(
var_info
),
gcs
.
at
(
places
[
op
->
GetScopeIdx
()]).
get
());
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
...
...
paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc
浏览文件 @
a4951843
...
@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
...
@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
ir
::
kMemOptVarInfoMapList
,
&
mem_opt_var_infos_
);
ref_cnt_pass
->
SetNotOwned
(
ir
::
kMemOptVarInfoMapList
,
&
mem_opt_var_infos_
);
ref_cnt_pass
->
SetNotOwned
(
ir
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars_
);
ref_cnt_pass
->
SetNotOwned
(
ir
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars_
);
ref_cnt_pass
->
Apply
(
&
graph_
);
ref_cnt_pass
->
Apply
(
&
const_cast
<
ir
::
Graph
&>
(
executor_
->
Graph
())
);
}
}
bool
IsLastLivedOps
(
const
std
::
string
&
name
,
bool
IsLastLivedOps
(
const
std
::
string
&
name
,
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
浏览文件 @
a4951843
...
@@ -11,7 +11,7 @@ endif()
...
@@ -11,7 +11,7 @@ endif()
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle
${
ALL_REDUCE_OP_HANDLES
}
reduce_op_handle broadcast_op_handle fused_broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle
${
ALL_REDUCE_OP_HANDLES
}
reduce_op_handle broadcast_op_handle fused_broadcast_op_handle
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
set_reader_device_
count_pass SRCS set_reader_device_count
_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass
)
cc_library
(
set_reader_device_
info_pass SRCS set_reader_device_info
_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass
)
cc_library
(
fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle
)
cc_library
(
fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass
)
cc_library
(
all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass
)
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_
count
_pass.cc
→
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_
info
_pass.cc
浏览文件 @
a4951843
...
@@ -22,35 +22,44 @@ namespace paddle {
...
@@ -22,35 +22,44 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
SetReaderDeviceCountPass
:
public
Pass
{
static
int
GetDeviceCountFromPassAttr
(
const
Pass
&
pass
)
{
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
private:
int
GetDeviceCount
()
const
;
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
const
;
const
Scope
*
GlobalScope
()
const
;
};
int
SetReaderDeviceCountPass
::
GetDeviceCount
()
const
{
return
static_cast
<
int
>
(
return
static_cast
<
int
>
(
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
).
size
());
pass
.
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
).
size
());
}
}
st
d
::
unordered_set
<
std
::
string
>
SetReaderDeviceCountPass
::
ReaderOpSet
()
const
{
st
atic
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
{
return
{
"create_py_reader"
};
return
{
"create_py_reader"
};
}
}
const
Scope
*
SetReaderDeviceCountPass
::
GlobalScope
()
const
{
class
InitReaderDeviceCountPass
:
public
Pass
{
return
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
)[
0
];
protected:
}
void
ApplyImpl
(
Graph
*
graph
)
const
override
{
using
QueueHolder
=
operators
::
reader
::
OrderedMultiDeviceLoDTensorBlockingQueueHolder
;
void
SetReaderDeviceCountPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
auto
dev_cnt
=
GetDeviceCount
();
auto
reader_ops
=
ReaderOpSet
();
auto
reader_ops
=
ReaderOpSet
();
auto
scope
=
GlobalScope
();
auto
dev_cnt
=
GetDeviceCountFromPassAttr
(
*
this
);
const
auto
&
scope
=
Get
<
const
Scope
>
(
details
::
kGlobalScope
);
for
(
auto
&
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
&&
reader_ops
.
count
(
node
->
Op
()
->
Type
())
!=
0
)
{
auto
queue_name
=
node
->
Op
()
->
Input
(
"blocking_queue"
)[
0
];
auto
var
=
scope
.
FindVar
(
queue_name
);
if
(
var
&&
var
->
IsType
<
QueueHolder
>
())
{
VLOG
(
10
)
<<
"Set device count of "
<<
queue_name
<<
" to be "
<<
dev_cnt
;
var
->
GetMutable
<
QueueHolder
>
()
->
GetQueue
()
->
SetDeviceCount
(
dev_cnt
);
}
}
}
}
};
class
SetReaderDeviceIndexPass
:
public
Pass
{
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
{
auto
dev_cnt
=
GetDeviceCountFromPassAttr
(
*
this
);
auto
reader_ops
=
ReaderOpSet
();
size_t
found_op_num
=
0
;
size_t
found_op_num
=
0
;
for
(
auto
&
node
:
graph
->
Nodes
())
{
for
(
auto
&
node
:
graph
->
Nodes
())
{
...
@@ -69,30 +78,24 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
...
@@ -69,30 +78,24 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
auto
queue_name
=
op_handle
.
GetOp
()
->
Input
(
"blocking_queue"
);
auto
var
=
scope
->
FindVar
(
queue_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Blocking queue of DataLoader not found"
));
using
QueueHolder
=
operators
::
reader
::
OrderedMultiDeviceLoDTensorBlockingQueueHolder
;
if
(
var
->
IsType
<
QueueHolder
>
())
{
var
->
GetMutable
<
QueueHolder
>
()
->
GetQueue
()
->
SetDeviceCount
(
dev_cnt
);
}
++
found_op_num
;
++
found_op_num
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
}
}
}
}
VLOG
(
10
)
<<
"Found op number "
<<
found_op_num
;
VLOG
(
10
)
<<
"Found op number "
<<
found_op_num
;
}
}
};
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
set_reader_device_count_pass
,
REGISTER_PASS
(
init_reader_device_count_pass
,
paddle
::
framework
::
ir
::
SetReaderDeviceCountPass
)
paddle
::
framework
::
ir
::
InitReaderDeviceCountPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGlobalScope
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
);
REGISTER_PASS
(
set_reader_device_index_pass
,
paddle
::
framework
::
ir
::
SetReaderDeviceIndexPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
);
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
a4951843
...
@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
...
@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std
::
vector
<
LoDTensor
>
LoDTensor
::
SplitLoDTensor
(
std
::
vector
<
LoDTensor
>
LoDTensor
::
SplitLoDTensor
(
const
std
::
vector
<
platform
::
Place
>
places
)
const
{
const
std
::
vector
<
platform
::
Place
>
places
)
const
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"place number cannot be empty when splitting"
));
check_memory_size
();
check_memory_size
();
int
batch_size
=
size_t
batch_size
=
lod
().
empty
()
?
dims
()[
0
]
:
static_cast
<
int
>
(
lod
()[
0
].
size
())
-
1
;
lod
().
empty
()
?
static_cast
<
size_t
>
(
dims
()[
0
])
:
lod
()[
0
].
size
()
-
1
;
size_t
result_size
=
std
::
min
(
static_cast
<
size_t
>
(
batch_size
),
places
.
size
());
size_t
remainder
=
batch_size
%
places
.
size
();
std
::
vector
<
LoDTensor
>
results
;
// if batch_size is 0, just return #places.size() copys of empty
results
.
reserve
(
result_size
);
// if result_size(batch_size) is 0, just return #places.size() copys of empty
// tensors.
// tensors.
if
(
result_size
==
0
)
{
if
(
batch_size
==
0
)
{
std
::
vector
<
LoDTensor
>
empty_results
;
empty_results
.
reserve
(
places
.
size
());
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
LoDTensor
dst
;
LoDTensor
dst
;
dst
.
Resize
(
dims
());
dst
.
Resize
(
dims
());
...
@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
...
@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if
(
!
lod
().
empty
())
{
if
(
!
lod
().
empty
())
{
dst
.
set_lod
(
lod
());
dst
.
set_lod
(
lod
());
}
}
results
.
emplace_back
(
dst
);
empty_results
.
emplace_back
(
std
::
move
(
dst
)
);
}
}
return
results
;
return
empty_
results
;
}
}
int
step_width
=
static_cast
<
int
>
(
batch_size
/
result_size
);
auto
step_width
=
(
batch_size
+
places
.
size
()
-
1
)
/
places
.
size
();
auto
result_size
=
(
batch_size
+
step_width
-
1
)
/
step_width
;
std
::
vector
<
LoDTensor
>
results
;
results
.
reserve
(
result_size
);
for
(
size_t
i
=
0
;
i
<
result_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
result_size
;
++
i
)
{
int
begin
=
static_cast
<
int
>
(
i
*
step_width
)
;
auto
begin
=
i
*
step_width
;
int
end
=
static_cast
<
int
>
((
i
+
1
)
*
step_width
);
auto
end
=
std
::
min
<
size_t
>
((
i
+
1
)
*
step_width
,
batch_size
);
if
(
i
+
1
==
places
.
size
())
{
// last
PADDLE_ENFORCE_LT
(
begin
,
end
,
end
+=
remainder
;
platform
::
errors
::
InvalidArgument
(
}
"begin must be less than end, this may be a bug"
));
LoDTensor
dst
;
LoDTensor
dst
;
if
(
lod
().
empty
())
{
if
(
lod
().
empty
())
{
...
@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
...
@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
}
}
dst
.
set_lod
(
my_lod
);
dst
.
set_lod
(
my_lod
);
}
}
results
.
emplace_back
(
dst
);
results
.
emplace_back
(
std
::
move
(
dst
)
);
}
}
return
results
;
return
results
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
a4951843
...
@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
...
@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
class
ParallelExecutorPrivate
{
class
ParallelExecutorPrivate
{
public:
public:
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
:
places_
(
places
)
{
Scope
*
global_scope
)
:
places_
(
places
),
global_scope_
(
global_scope
)
{
if
(
!
FLAGS_pe_profile_fname
.
empty
())
{
if
(
!
FLAGS_pe_profile_fname
.
empty
())
{
std
::
call_once
(
gProfileOnce
,
[]
{
std
::
call_once
(
gProfileOnce
,
[]
{
#ifdef WITH_GPERFTOOLS
#ifdef WITH_GPERFTOOLS
...
@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
...
@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
}
}
}
}
void
InitReaderDeviceCount
(
ir
::
Graph
*
graph
)
const
{
auto
pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"init_reader_device_count_pass"
);
pass
->
SetNotOwned
<
const
Scope
>
(
details
::
kGlobalScope
,
global_scope_
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
,
&
places_
);
pass
->
Apply
(
graph
);
}
void
SetHasFeed
(
size_t
dev_idx
,
bool
has_feed
=
true
);
bool
AllowPartialFeed
()
const
;
ir
::
Graph
*
ApplyMemoryOptimizePass
(
ir
::
Graph
*
graph
);
ir
::
Graph
*
ApplyMemoryOptimizePass
(
ir
::
Graph
*
graph
);
inline
bool
HasGarbageCollectors
()
const
{
return
!
gcs_
.
empty
();
}
inline
bool
HasGarbageCollectors
()
const
{
return
!
gcs_
.
empty
();
}
...
@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
...
@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
ir
::
MemOptVarInfoMapList
mem_opt_var_infos_
;
ir
::
MemOptVarInfoMapList
mem_opt_var_infos_
;
ir
::
GarbageCollectorMap
gcs_
;
ir
::
GarbageCollectorMap
gcs_
;
details
::
ParallelSSAGraphExecutor
*
inference_executor_
{
nullptr
};
};
};
void
ParallelExecutorPrivate
::
SetHasFeed
(
size_t
dev_idx
,
bool
has_feed
)
{
if
(
inference_executor_
)
{
inference_executor_
->
SetHasFeed
(
dev_idx
,
has_feed
);
}
}
bool
ParallelExecutorPrivate
::
AllowPartialFeed
()
const
{
return
inference_executor_
&&
inference_executor_
->
SupportPartialFeed
();
}
ir
::
Graph
*
ParallelExecutorPrivate
::
ApplyMemoryOptimizePass
(
ir
::
Graph
*
graph
)
{
ir
::
Graph
*
ParallelExecutorPrivate
::
ApplyMemoryOptimizePass
(
ir
::
Graph
*
graph
)
{
if
(
FLAGS_use_ngraph
)
{
if
(
FLAGS_use_ngraph
)
{
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
...
@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
...
@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
return
graph
;
return
graph
;
}
}
class
ResetHasFeedGuard
{
public:
explicit
ResetHasFeedGuard
(
ParallelExecutorPrivate
*
pe_member
)
:
pe_member_
(
pe_member
)
{}
~
ResetHasFeedGuard
()
{
for
(
size_t
i
=
0
;
i
<
pe_member_
->
places_
.
size
();
++
i
)
{
pe_member_
->
SetHasFeed
(
i
,
false
);
}
}
private:
ParallelExecutorPrivate
*
pe_member_
;
};
size_t
ParallelExecutor
::
DeviceCount
()
const
{
return
member_
->
places_
.
size
();
}
size_t
ParallelExecutor
::
DeviceCount
()
const
{
return
member_
->
places_
.
size
();
}
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
...
@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const
ExecutionStrategy
&
exec_strategy
,
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
,
const
BuildStrategy
&
build_strategy
,
ir
::
Graph
*
graph
)
ir
::
Graph
*
graph
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
:
member_
(
new
ParallelExecutorPrivate
(
places
,
scope
))
{
member_
->
global_scope_
=
scope
;
member_
->
InitReaderDeviceCount
(
graph
)
;
member_
->
use_cuda_
=
exec_strategy
.
use_cuda_
;
member_
->
use_cuda_
=
exec_strategy
.
use_cuda_
;
member_
->
build_strategy_
=
build_strategy
;
member_
->
build_strategy_
=
build_strategy
;
member_
->
use_all_reduce_
=
member_
->
build_strategy_
.
reduce_
==
member_
->
use_all_reduce_
=
member_
->
build_strategy_
.
reduce_
==
...
@@ -605,6 +646,22 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -605,6 +646,22 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
PADDLE_THROW
(
PADDLE_THROW
(
"Paddle should be compiled with CUDA for ParallelGraph Execution."
);
"Paddle should be compiled with CUDA for ParallelGraph Execution."
);
#endif
#endif
}
else
{
bool
has_drop_last_read_op
=
details
::
HasDropLastReadOp
(
*
graph
);
auto
possible_inference_graphs
=
details
::
TrySeparateToMultipleSingleDeviceGraphs
(
graph
);
if
(
!
possible_inference_graphs
.
empty
())
{
VLOG
(
5
)
<<
"Use ParallelSSAGraphExecutor in inference phase"
;
auto
*
pg_exe
=
new
details
::
ParallelSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
local_exec_scopes_
,
member_
->
places_
,
std
::
move
(
possible_inference_graphs
));
if
(
!
has_drop_last_read_op
)
{
VLOG
(
5
)
<<
"Enable partial feed support in inference phase"
;
pg_exe
->
EnablePartialFeedSupport
();
}
final_graphs
=
pg_exe
->
Graphs
();
member_
->
executor_
.
reset
(
pg_exe
);
member_
->
inference_executor_
=
pg_exe
;
}
else
{
}
else
{
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
VLOG
(
3
)
<<
"use ThreadedSSAGraphExecutor"
;
VLOG
(
3
)
<<
"use ThreadedSSAGraphExecutor"
;
...
@@ -619,6 +676,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -619,6 +676,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
}
}
final_graphs
.
emplace_back
(
graph
);
final_graphs
.
emplace_back
(
graph
);
}
}
}
VLOG
(
3
)
<<
"use ScopeBufferedSSAGraphExecutor"
;
VLOG
(
3
)
<<
"use ScopeBufferedSSAGraphExecutor"
;
if
(
!
member_
->
build_strategy_
.
async_mode_
)
{
if
(
!
member_
->
build_strategy_
.
async_mode_
)
{
...
@@ -724,6 +782,8 @@ FeedFetchList ParallelExecutor::Run(
...
@@ -724,6 +782,8 @@ FeedFetchList ParallelExecutor::Run(
platform
::
RecordBlock
b
(
0
);
platform
::
RecordBlock
b
(
0
);
ResetHasFeedGuard
reset_has_feed_guard
(
member_
);
ir
::
SkipMemOptVarsGuard
guard
(
&
(
member_
->
mem_opt_var_infos_
),
fetch_tensors
,
ir
::
SkipMemOptVarsGuard
guard
(
&
(
member_
->
mem_opt_var_infos_
),
fetch_tensors
,
member_
->
HasGarbageCollectors
());
member_
->
HasGarbageCollectors
());
...
@@ -734,10 +794,22 @@ FeedFetchList ParallelExecutor::Run(
...
@@ -734,10 +794,22 @@ FeedFetchList ParallelExecutor::Run(
void
ParallelExecutor
::
FeedTensorsIntoLocalScopes
(
void
ParallelExecutor
::
FeedTensorsIntoLocalScopes
(
const
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
LoDTensor
>>
&
tensors
)
{
const
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
LoDTensor
>>
&
tensors
)
{
PADDLE_ENFORCE_EQ
(
member_
->
local_scopes_
.
size
(),
tensors
.
size
());
if
(
!
member_
->
AllowPartialFeed
())
{
PADDLE_ENFORCE_EQ
(
member_
->
local_scopes_
.
size
(),
tensors
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The feed tensor number does not match the device number"
));
}
else
{
PADDLE_ENFORCE_GE
(
member_
->
local_scopes_
.
size
(),
tensors
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The feed tensor number exceeds the device number"
));
}
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
auto
&
map
=
tensors
[
i
];
auto
&
map
=
tensors
[
i
];
if
(
!
map
.
empty
())
{
member_
->
SetHasFeed
(
i
);
}
for
(
auto
&
pair
:
map
)
{
for
(
auto
&
pair
:
map
)
{
bool
is_persistable
=
member_
->
IsPersistable
(
pair
.
first
);
bool
is_persistable
=
member_
->
IsPersistable
(
pair
.
first
);
if
(
!
is_persistable
)
{
if
(
!
is_persistable
)
{
...
@@ -757,6 +829,11 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
...
@@ -757,6 +829,11 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
void
ParallelExecutor
::
FeedAndSplitTensorIntoLocalScopes
(
void
ParallelExecutor
::
FeedAndSplitTensorIntoLocalScopes
(
const
std
::
unordered_map
<
std
::
string
,
LoDTensor
>
&
tensors
)
{
const
std
::
unordered_map
<
std
::
string
,
LoDTensor
>
&
tensors
)
{
size_t
num_places
=
member_
->
places_
.
size
();
size_t
num_places
=
member_
->
places_
.
size
();
bool
allow_partial_feed
=
member_
->
AllowPartialFeed
();
size_t
persistable_feed_len
=
-
1UL
;
size_t
non_persistable_feed_len
=
-
1UL
;
for
(
auto
&
pair
:
tensors
)
{
for
(
auto
&
pair
:
tensors
)
{
bool
is_persistable
=
member_
->
IsPersistable
(
pair
.
first
);
bool
is_persistable
=
member_
->
IsPersistable
(
pair
.
first
);
VLOG
(
3
)
<<
"Split "
<<
(
is_persistable
?
"persistable"
:
"no persistable"
)
VLOG
(
3
)
<<
"Split "
<<
(
is_persistable
?
"persistable"
:
"no persistable"
)
...
@@ -764,7 +841,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
...
@@ -764,7 +841,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
<<
", place: "
<<
pair
.
second
.
place
();
<<
", place: "
<<
pair
.
second
.
place
();
auto
lod_tensors
=
pair
.
second
.
SplitLoDTensor
(
member_
->
places_
);
auto
lod_tensors
=
pair
.
second
.
SplitLoDTensor
(
member_
->
places_
);
bool
is_cpu_place
=
platform
::
is_cpu_place
(
member_
->
places_
.
front
());
bool
is_cpu_place
=
platform
::
is_cpu_place
(
member_
->
places_
.
front
());
if
(
!
is_persistable
&&
num_places
!=
lod_tensors
.
size
())
{
if
(
!
is_persistable
&&
num_places
!=
lod_tensors
.
size
()
&&
!
allow_partial_feed
)
{
auto
error_info
=
string
::
Sprintf
(
auto
error_info
=
string
::
Sprintf
(
"The number(%d) of samples[%s] of current batch is less than the "
"The number(%d) of samples[%s] of current batch is less than the "
"count(%d) of devices(%s), currently, it is not allowed. "
,
"count(%d) of devices(%s), currently, it is not allowed. "
,
...
@@ -790,7 +868,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
...
@@ -790,7 +868,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
framework
::
TensorCopy
(
pair
.
second
,
member_
->
places_
.
at
(
i
),
&
tmp
);
framework
::
TensorCopy
(
pair
.
second
,
member_
->
places_
.
at
(
i
),
&
tmp
);
}
}
}
}
if
(
lod_tensors
.
size
()
!=
num_places
)
{
if
(
lod_tensors
.
size
()
!=
num_places
&&
!
allow_partial_feed
)
{
auto
error_info
=
string
::
Sprintf
(
auto
error_info
=
string
::
Sprintf
(
"The number(%d) of samples[%s] of the current batch does not match "
"The number(%d) of samples[%s] of the current batch does not match "
"the count(%d) of devices(%s). Because that %s is a persistable "
"the count(%d) of devices(%s). Because that %s is a persistable "
...
@@ -804,7 +882,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
...
@@ -804,7 +882,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
}
}
for
(
size_t
j
=
0
;
j
<
num_places
;
++
j
)
{
if
(
allow_partial_feed
)
{
if
(
is_persistable
)
{
if
(
persistable_feed_len
==
-
1UL
)
{
persistable_feed_len
=
lod_tensors
.
size
();
}
else
{
PADDLE_ENFORCE_EQ
(
persistable_feed_len
,
lod_tensors
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The feeded number of different persistable variables "
"should be the same"
));
}
}
else
{
if
(
non_persistable_feed_len
==
-
1UL
)
{
non_persistable_feed_len
=
lod_tensors
.
size
();
}
else
{
PADDLE_ENFORCE_EQ
(
non_persistable_feed_len
,
lod_tensors
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The feeded number of different non-persistable variables "
"should be the same"
));
}
}
}
for
(
size_t
j
=
0
;
j
<
lod_tensors
.
size
();
++
j
)
{
auto
*
feed_scope
=
is_persistable
?
member_
->
local_scopes_
[
j
]
auto
*
feed_scope
=
is_persistable
?
member_
->
local_scopes_
[
j
]
:
member_
->
local_exec_scopes_
[
j
];
:
member_
->
local_exec_scopes_
[
j
];
auto
*
feed_var
=
feed_scope
->
Var
(
pair
.
first
);
auto
*
feed_var
=
feed_scope
->
Var
(
pair
.
first
);
...
@@ -814,6 +916,19 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
...
@@ -814,6 +916,19 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
t
->
set_lod
(
lod_tensors
[
j
].
lod
());
t
->
set_lod
(
lod_tensors
[
j
].
lod
());
}
}
}
}
if
(
allow_partial_feed
&&
persistable_feed_len
!=
-
1UL
&&
non_persistable_feed_len
!=
-
1UL
)
{
VLOG
(
10
)
<<
"Persistable len "
<<
persistable_feed_len
;
VLOG
(
10
)
<<
"Non persistable len "
<<
non_persistable_feed_len
;
PADDLE_ENFORCE_GE
(
persistable_feed_len
,
non_persistable_feed_len
,
platform
::
errors
::
InvalidArgument
(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"
));
for
(
size_t
i
=
0
;
i
<
non_persistable_feed_len
;
++
i
)
{
member_
->
SetHasFeed
(
i
);
}
}
}
}
ParallelExecutor
::~
ParallelExecutor
()
{
ParallelExecutor
::~
ParallelExecutor
()
{
...
@@ -864,6 +979,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
...
@@ -864,6 +979,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
return
enable_parallel_graph
;
return
enable_parallel_graph
;
}
}
const
ir
::
Graph
&
ParallelExecutor
::
Graph
()
const
{
return
member_
->
executor_
->
Graph
();
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
@@ -871,3 +990,4 @@ USE_PASS(reference_count_pass);
...
@@ -871,3 +990,4 @@ USE_PASS(reference_count_pass);
USE_PASS
(
eager_deletion_pass
);
USE_PASS
(
eager_deletion_pass
);
USE_PASS
(
buffer_shared_inplace_pass
);
USE_PASS
(
buffer_shared_inplace_pass
);
USE_PASS
(
buffer_shared_cross_op_memory_reuse_pass
);
USE_PASS
(
buffer_shared_cross_op_memory_reuse_pass
);
USE_PASS
(
init_reader_device_count_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
a4951843
...
@@ -79,6 +79,8 @@ class ParallelExecutor {
...
@@ -79,6 +79,8 @@ class ParallelExecutor {
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
);
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
);
const
ir
::
Graph
&
Graph
()
const
;
private:
private:
// broadcast the parameters from the 0th device.
// broadcast the parameters from the 0th device.
// trainer_id the trainer index in nccl distributed training.
// trainer_id the trainer index in nccl distributed training.
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
a4951843
...
@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" and it is set by ParallelExecutor instance, not users."
)
" and it is set by ParallelExecutor instance, not users."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"infer_out"
,
""
).
SetDefault
(
true
);
AddAttr
<
bool
>
(
"infer_out"
,
""
).
SetDefault
(
true
);
AddAttr
<
bool
>
(
"drop_last"
,
"Whether to drop last batches whose number is less than CPU "
"cores/GPU cards number"
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Read Operator
Read Operator
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
a4951843
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "Python.h"
#include "Python.h"
#include "boost/optional.hpp"
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
...
@@ -41,6 +42,58 @@ namespace pybind {
...
@@ -41,6 +42,58 @@ namespace pybind {
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
reader
=
operators
::
reader
;
namespace
reader
=
operators
::
reader
;
// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
static
boost
::
optional
<
std
::
vector
<
int64_t
>>
DiffTensorShapeWithVarDesc
(
const
framework
::
LoDTensor
&
tensor
,
const
framework
::
VarDesc
&
var_desc
,
size_t
num_places
)
{
auto
tensor_shape
=
tensor
.
dims
();
auto
desc_shape
=
var_desc
.
GetShape
();
int64_t
rank
=
tensor_shape
.
size
();
if
(
UNLIKELY
(
rank
==
0
))
{
if
(
desc_shape
.
size
()
!=
0
)
{
// Tensor rank = 0 but desc does not match
return
framework
::
vectorize
<
int64_t
>
(
tensor_shape
);
}
else
{
return
boost
::
none
;
}
}
PADDLE_ENFORCE_GE
(
tensor_shape
[
0
],
0
,
platform
::
errors
::
InvalidArgument
(
"Tensor shape must not be less than 0"
));
if
(
!
tensor
.
lod
().
empty
())
{
tensor_shape
[
0
]
=
-
1
;
// unknown shape
}
else
{
int64_t
split_size
=
(
tensor_shape
[
0
]
+
num_places
-
1
)
/
num_places
;
int64_t
remainder
=
(
split_size
==
0
?
0
:
tensor_shape
[
0
]
%
split_size
);
tensor_shape
[
0
]
=
split_size
;
if
(
desc_shape
[
0
]
>=
0
)
{
// need check dim 0
if
(
tensor_shape
[
0
]
!=
desc_shape
[
0
])
{
return
framework
::
vectorize
<
int64_t
>
(
tensor_shape
);
}
if
(
remainder
>
0
)
{
tensor_shape
[
0
]
=
remainder
;
return
framework
::
vectorize
<
int64_t
>
(
tensor_shape
);
}
}
}
for
(
int64_t
idx
=
1
;
idx
<
rank
;
++
idx
)
{
PADDLE_ENFORCE_GE
(
tensor_shape
[
idx
],
0
,
platform
::
errors
::
InvalidArgument
(
"Tensor shape must not be less than 0"
));
if
(
desc_shape
[
idx
]
>=
0
&&
tensor_shape
[
idx
]
!=
desc_shape
[
idx
])
{
return
framework
::
vectorize
<
int64_t
>
(
tensor_shape
);
}
}
return
boost
::
none
;
}
static
const
std
::
shared_ptr
<
reader
::
LoDTensorBlockingQueue
>
&
GetQueue
(
static
const
std
::
shared_ptr
<
reader
::
LoDTensorBlockingQueue
>
&
GetQueue
(
const
std
::
shared_ptr
<
reader
::
LoDTensorBlockingQueue
>
&
queue
,
size_t
idx
)
{
const
std
::
shared_ptr
<
reader
::
LoDTensorBlockingQueue
>
&
queue
,
size_t
idx
)
{
return
queue
;
return
queue
;
...
@@ -66,10 +119,12 @@ class MultiDeviceFeedReader {
...
@@ -66,10 +119,12 @@ class MultiDeviceFeedReader {
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
,
bool
drop_last
)
:
queue_
(
queue
),
:
queue_
(
queue
),
names_
(
names
),
names_
(
names
),
pool_
(
new
::
ThreadPool
(
dst_places
.
size
()))
{
pool_
(
new
::
ThreadPool
(
dst_places
.
size
())),
drop_last_
(
drop_last
)
{
std
::
vector
<
framework
::
DDim
>
dims
;
std
::
vector
<
framework
::
DDim
>
dims
;
for
(
auto
&
shape
:
shapes
)
{
for
(
auto
&
shape
:
shapes
)
{
dims
.
push_back
(
framework
::
make_ddim
(
shape
));
dims
.
push_back
(
framework
::
make_ddim
(
shape
));
...
@@ -113,14 +168,18 @@ class MultiDeviceFeedReader {
...
@@ -113,14 +168,18 @@ class MultiDeviceFeedReader {
ReadAsync
();
ReadAsync
();
}
}
bool
DropLast
()
const
{
return
drop_last_
;
}
ResultDictList
ReadNext
()
{
ResultDictList
ReadNext
()
{
CheckNextStatus
();
CheckNextStatus
();
ResultDictList
result
(
ret_
.
size
());
ResultDictList
result
(
ret_
.
size
());
for
(
size_t
i
=
0
;
i
<
ret_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ret_
.
size
();
++
i
)
{
if
(
!
ret_
[
i
].
empty
())
{
for
(
size_t
j
=
0
;
j
<
names_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
names_
.
size
();
++
j
)
{
result
[
i
].
emplace
(
names_
[
j
],
std
::
move
(
ret_
[
i
][
j
]));
result
[
i
].
emplace
(
names_
[
j
],
std
::
move
(
ret_
[
i
][
j
]));
}
}
}
}
}
ReadAsync
();
ReadAsync
();
return
result
;
return
result
;
}
}
...
@@ -155,24 +214,29 @@ class MultiDeviceFeedReader {
...
@@ -155,24 +214,29 @@ class MultiDeviceFeedReader {
};
};
Status
WaitFutures
(
std
::
exception_ptr
*
excep
)
{
Status
WaitFutures
(
std
::
exception_ptr
*
excep
)
{
bool
is_success
=
true
;
*
excep
=
nullptr
;
*
excep
=
nullptr
;
size_t
success_num
=
0
;
for
(
size_t
i
=
0
;
i
<
futures_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
futures_
.
size
();
++
i
)
{
auto
each_status
=
futures_
[
i
].
get
();
auto
each_status
=
futures_
[
i
].
get
();
if
(
UNLIKELY
(
each_status
!=
Status
::
kSuccess
))
{
if
(
UNLIKELY
(
each_status
!=
Status
::
kSuccess
))
{
is_success
=
false
;
if
(
UNLIKELY
(
each_status
==
Status
::
kException
))
{
if
(
UNLIKELY
(
each_status
==
Status
::
kException
))
{
PADDLE_ENFORCE_NOT_NULL
(
exceptions_
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
exceptions_
[
i
]);
*
excep
=
exceptions_
[
i
];
*
excep
=
exceptions_
[
i
];
exceptions_
[
i
]
=
nullptr
;
exceptions_
[
i
]
=
nullptr
;
}
}
}
else
{
++
success_num
;
}
}
}
}
if
(
UNLIKELY
(
*
excep
))
{
if
(
UNLIKELY
(
*
excep
))
{
return
Status
::
kException
;
return
Status
::
kException
;
}
if
(
drop_last_
)
{
return
success_num
==
futures_
.
size
()
?
Status
::
kSuccess
:
Status
::
kEOF
;
}
else
{
}
else
{
return
is_success
?
Status
::
kSuccess
:
Status
::
kEOF
;
return
success_num
>
0
?
Status
::
kSuccess
:
Status
::
kEOF
;
}
}
}
}
...
@@ -226,6 +290,7 @@ class MultiDeviceFeedReader {
...
@@ -226,6 +290,7 @@ class MultiDeviceFeedReader {
std
::
vector
<
std
::
exception_ptr
>
exceptions_
;
std
::
vector
<
std
::
exception_ptr
>
exceptions_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
ret_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
ret_
;
bool
drop_last_
;
};
};
template
<
typename
QueueType
>
template
<
typename
QueueType
>
...
@@ -270,6 +335,17 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
...
@@ -270,6 +335,17 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
void
BindReader
(
py
::
module
*
module
)
{
void
BindReader
(
py
::
module
*
module
)
{
auto
&
m
=
*
module
;
auto
&
m
=
*
module
;
m
.
def
(
"diff_tensor_shape"
,
[](
const
framework
::
LoDTensor
&
tensor
,
const
framework
::
VarDesc
&
var_desc
,
size_t
num_places
)
->
py
::
object
{
auto
diff
=
DiffTensorShapeWithVarDesc
(
tensor
,
var_desc
,
num_places
);
if
(
diff
)
{
return
py
::
cast
(
std
::
move
(
diff
.
get
()));
}
else
{
return
py
::
cast
(
nullptr
);
}
});
m
.
def
(
"init_lod_tensor_blocking_queue"
,
m
.
def
(
"init_lod_tensor_blocking_queue"
,
[](
framework
::
Variable
&
var
,
size_t
capacity
,
[](
framework
::
Variable
&
var
,
size_t
capacity
,
bool
is_ordered
)
->
py
::
object
{
bool
is_ordered
)
->
py
::
object
{
...
@@ -337,10 +413,10 @@ void BindReader(py::module *module) {
...
@@ -337,10 +413,10 @@ void BindReader(py::module *module) {
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
{
bool
use_double_buffer
,
bool
drop_last
)
{
return
new
MultiDeviceFeedReader
<
reader
::
LoDTensorBlockingQueue
>
(
return
new
MultiDeviceFeedReader
<
reader
::
LoDTensorBlockingQueue
>
(
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
use_double_buffer
);
use_double_buffer
,
drop_last
);
},
},
py
::
return_value_policy
::
take_ownership
);
py
::
return_value_policy
::
take_ownership
);
...
@@ -352,13 +428,13 @@ void BindReader(py::module *module) {
...
@@ -352,13 +428,13 @@ void BindReader(py::module *module) {
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
,
bool
use_double_buffer
)
{
bool
drop_last
)
{
queue
->
SetDeviceCount
(
dst_places
.
size
());
queue
->
SetDeviceCount
(
dst_places
.
size
());
return
new
MultiDeviceFeedReader
<
return
new
MultiDeviceFeedReader
<
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
use_double_buffer
);
use_double_buffer
,
drop_last
);
},
},
py
::
return_value_policy
::
take_ownership
);
py
::
return_value_policy
::
take_ownership
);
}
}
...
...
python/paddle/fluid/executor.py
浏览文件 @
a4951843
...
@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
...
@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
the feed value
the feed value
"""
"""
if
var
.
desc
.
need_check_feed
():
if
var
.
desc
.
need_check_feed
():
feed_shape
=
feed
.
shape
()
diff_shape
=
core
.
diff_tensor_shape
(
feed
,
var
.
desc
,
num_places
)
if
six
.
PY2
:
if
diff_shape
is
not
None
:
feed_shape
[
0
]
=
long
(
feed_shape
[
0
]
/
num_places
)
if
len
(
feed
.
lod
())
==
0
else
-
1
else
:
feed_shape
[
0
]
=
int
(
feed_shape
[
0
]
/
num_places
)
if
len
(
feed
.
lod
())
==
0
else
-
1
if
not
dimension_is_compatible_with
(
feed_shape
,
var
.
shape
):
raise
ValueError
(
raise
ValueError
(
'The feeded Variable %r should have dimensions = %d, shape = '
'The feeded Variable %r should have dimensions = %d, shape = '
'%r, but received feeded shape %r on each device'
%
'%r, but received feeded shape %r on each device'
%
(
var
.
name
,
len
(
var
.
shape
),
var
.
shape
,
feed
_shape
))
(
var
.
name
,
len
(
var
.
shape
),
var
.
shape
,
diff
_shape
))
if
not
dtype_is_compatible_with
(
feed
.
_dtype
(),
var
.
dtype
):
if
not
dtype_is_compatible_with
(
feed
.
_dtype
(),
var
.
dtype
):
var_dtype_format
=
convert_dtype
(
var
.
dtype
)
if
isinstance
(
var_dtype_format
=
convert_dtype
(
var
.
dtype
)
if
isinstance
(
var
.
dtype
,
core
.
VarDesc
.
VarType
)
else
var
.
dtype
var
.
dtype
,
core
.
VarDesc
.
VarType
)
else
var
.
dtype
...
@@ -646,11 +640,6 @@ class Executor(object):
...
@@ -646,11 +640,6 @@ class Executor(object):
exe
.
feed_and_split_tensor_into_local_scopes
(
feed_tensor_dict
)
exe
.
feed_and_split_tensor_into_local_scopes
(
feed_tensor_dict
)
elif
isinstance
(
feed
,
list
)
or
isinstance
(
feed
,
tuple
):
elif
isinstance
(
feed
,
list
)
or
isinstance
(
feed
,
tuple
):
if
len
(
feed
)
!=
len
(
program
.
_places
):
raise
ValueError
(
"Feed a list of tensor, the list should be the same size as places"
)
res
=
list
()
res
=
list
()
for
i
,
each
in
enumerate
(
feed
):
for
i
,
each
in
enumerate
(
feed
):
if
not
isinstance
(
each
,
dict
):
if
not
isinstance
(
each
,
dict
):
...
...
python/paddle/fluid/reader.py
浏览文件 @
a4951843
...
@@ -88,6 +88,7 @@ class DataLoader(object):
...
@@ -88,6 +88,7 @@ class DataLoader(object):
iterable
=
True
,
iterable
=
True
,
return_list
=
False
,
return_list
=
False
,
use_multiprocess
=
False
,
use_multiprocess
=
False
,
drop_last
=
True
,
keep_order
=
False
):
keep_order
=
False
):
"""
"""
Create a DataLoader object for loading data from Python generator.
Create a DataLoader object for loading data from Python generator.
...
@@ -134,6 +135,9 @@ class DataLoader(object):
...
@@ -134,6 +135,9 @@ class DataLoader(object):
can be used in the dygraph mode. In the static graph mode,
can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect.
whether this parameter is set or not has no effect.
The Default value is False.
The Default value is False.
drop_last (bool): whether to drop the last batches when their count is
smaller than the number of CPU cores or GPU cards. The default
value is True.
keep_order (bool): whether to assign the data to CPU cores or GPU
keep_order (bool): whether to assign the data to CPU cores or GPU
cards in order. Supposing that there are 2 batches and we use
cards in order. Supposing that there are 2 batches and we use
2 GPU cards to run the network. If keep_order=True, GPU 0 would
2 GPU cards to run the network. If keep_order=True, GPU 0 would
...
@@ -289,7 +293,7 @@ class DataLoader(object):
...
@@ -289,7 +293,7 @@ class DataLoader(object):
return_list
,
use_multiprocess
)
return_list
,
use_multiprocess
)
else
:
else
:
return
GeneratorLoader
(
feed_list
,
capacity
,
use_double_buffer
,
return
GeneratorLoader
(
feed_list
,
capacity
,
use_double_buffer
,
iterable
,
return_list
,
keep_order
)
iterable
,
return_list
,
drop_last
,
keep_order
)
@
staticmethod
@
staticmethod
def
from_dataset
(
dataset
,
places
,
drop_last
=
True
):
def
from_dataset
(
dataset
,
places
,
drop_last
=
True
):
...
@@ -422,7 +426,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -422,7 +426,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
core
.
Variable
(),
self
.
_capacity
,
False
)
core
.
Variable
(),
self
.
_capacity
,
False
)
self
.
_reader
=
core
.
create_py_reader
(
self
.
_reader
=
core
.
create_py_reader
(
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
,
True
)
def
_start
(
self
):
def
_start
(
self
):
if
self
.
_use_multiprocess
:
if
self
.
_use_multiprocess
:
...
@@ -628,6 +632,7 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -628,6 +632,7 @@ class GeneratorLoader(DataLoaderBase):
use_double_buffer
=
True
,
use_double_buffer
=
True
,
iterable
=
True
,
iterable
=
True
,
return_list
=
False
,
return_list
=
False
,
drop_last
=
True
,
keep_order
=
False
):
keep_order
=
False
):
self
.
_tensor_reader
=
None
self
.
_tensor_reader
=
None
self
.
_places
=
None
self
.
_places
=
None
...
@@ -635,6 +640,8 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -635,6 +640,8 @@ class GeneratorLoader(DataLoaderBase):
self
.
_queue
=
None
self
.
_queue
=
None
self
.
_feed_list
=
feed_list
self
.
_feed_list
=
feed_list
self
.
_exited
=
False
self
.
_exited
=
False
self
.
_drop_last
=
drop_last
self
.
_keep_order
=
keep_order
if
not
capacity
:
if
not
capacity
:
raise
ValueError
(
"Please give value to capacity."
)
raise
ValueError
(
"Please give value to capacity."
)
self
.
_iterable
=
iterable
self
.
_iterable
=
iterable
...
@@ -643,7 +650,6 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -643,7 +650,6 @@ class GeneratorLoader(DataLoaderBase):
raise
Exception
(
"Feed list must be given under static mode."
)
raise
Exception
(
"Feed list must be given under static mode."
)
self
.
_use_double_buffer
=
use_double_buffer
self
.
_use_double_buffer
=
use_double_buffer
self
.
_capacity
=
capacity
self
.
_capacity
=
capacity
self
.
_keep_order
=
keep_order
if
not
self
.
_iterable
:
if
not
self
.
_iterable
:
self
.
_init_non_iterable
()
self
.
_init_non_iterable
()
...
@@ -667,7 +673,8 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -667,7 +673,8 @@ class GeneratorLoader(DataLoaderBase):
core
.
Variable
(),
self
.
_capacity
,
self
.
_keep_order
)
core
.
Variable
(),
self
.
_capacity
,
self
.
_keep_order
)
self
.
_reader
=
core
.
create_py_reader
(
self
.
_reader
=
core
.
create_py_reader
(
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
,
self
.
_drop_last
)
def
_init_non_iterable
(
self
):
def
_init_non_iterable
(
self
):
lod_levels
=
[]
lod_levels
=
[]
...
@@ -744,7 +751,8 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -744,7 +751,8 @@ class GeneratorLoader(DataLoaderBase):
default_main_program
().
current_block
().
append_op
(
default_main_program
().
current_block
().
append_op
(
type
=
'read'
,
type
=
'read'
,
inputs
=
{
'Reader'
:
[
self
.
_reader
]},
inputs
=
{
'Reader'
:
[
self
.
_reader
]},
outputs
=
{
'Out'
:
self
.
_feed_list
})
outputs
=
{
'Out'
:
self
.
_feed_list
},
attrs
=
{
'drop_last'
:
self
.
_drop_last
})
@
property
@
property
def
queue
(
self
):
def
queue
(
self
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
a4951843
...
@@ -355,4 +355,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
...
@@ -355,4 +355,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_dataloader_keep_order
test_optimizer_in_control_flow test_dataloader_keep_order
test_parallel_executor_inference_feed_partial_data
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS
"RUN_TYPE=DIST"
)
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS
"RUN_TYPE=DIST"
)
python/paddle/fluid/tests/unittests/test_parallel_executor_inference_feed_partial_data.py
0 → 100644
浏览文件 @
a4951843
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.fluid
as
fluid
import
numpy
as
np
import
unittest
import
six
class TestInferencePartialFeed(unittest.TestCase):
    """Checks inference-style runs where the fed batch count may differ from
    the number of devices, using both a split feed dict and a per-device
    feed list.
    """

    def setUp(self):
        # Number of times each feed scenario is repeated.
        self.iterations = 10
        # Feature width of the x/y inputs.
        self.size = 10

    def run_network(self, places, use_split):
        """Build a tiny relu network and exercise it on `places`.

        Args:
            places: list of places to run data-parallel on.
            use_split: if True, feed one dict that gets split across places;
                otherwise feed a list with one dict per place.
        """
        startup_prog = fluid.Program()
        main_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
            y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
            lr = fluid.data(name='lr', shape=[1], dtype='float32')
            # persistable so the scalar lr is broadcast rather than split
            # across devices -- presumably; verify against executor behavior.
            lr.persistable = True
            relu_x = fluid.layers.relu(x)
            relu_y = fluid.layers.relu(y)
            relu_lr = fluid.layers.relu(lr)

        exe = fluid.Executor(places[0])
        exe.run(startup_prog)
        prog = fluid.CompiledProgram(main_prog).with_data_parallel(
            places=places)

        # Uniform random input in [-1, 1) so relu zeroes roughly half of it.
        gen_random = lambda shape: np.random.uniform(
            low=-1.0, high=1.0, size=shape).astype('float32')
        # relu(feed) must equal the fetched result exactly.
        assert_result = lambda feed, result: self.assertTrue(
            np.array_equal(np.maximum(0, feed), result))

        def feed_split_test():
            # Feed batch sizes from 1 up to 3x the device count, including
            # values smaller than len(places) (the partial-feed case).
            for place_num in six.moves.range(1, len(places) * 3):
                x_np = gen_random([place_num, self.size])
                y_np = gen_random([place_num, self.size])
                if place_num <= len(places):
                    lr_np = gen_random([place_num])
                else:
                    lr_np = gen_random([1])

                relu_x_np, relu_y_np, relu_lr_np = exe.run(
                    prog,
                    feed={x.name: x_np,
                          y.name: y_np,
                          lr.name: lr_np},
                    fetch_list=[relu_x, relu_y, relu_lr])

                assert_result(x_np, relu_x_np)
                assert_result(y_np, relu_y_np)
                if place_num <= len(places):
                    assert_result(lr_np, relu_lr_np)
                else:
                    # lr had a single element; every device should see the
                    # same broadcast value.
                    expected_relu_lr_np = max(lr_np[0], 0)
                    self.assertTrue(np.all(expected_relu_lr_np == relu_lr_np))

        def feed_list_test():
            # Feed an explicit per-place list of dicts, covering 1..N places.
            for place_num in six.moves.range(1, len(places) + 1):
                x_np_list = []
                y_np_list = []
                lr_np_list = []
                feed_list = []
                for _ in six.moves.range(place_num):
                    x_np = gen_random([1, self.size])
                    y_np = gen_random([1, self.size])
                    lr_np = gen_random([1])
                    x_np_list.append(x_np)
                    y_np_list.append(y_np)
                    lr_np_list.append(lr_np)
                    feed_list.append({
                        x.name: x_np,
                        y.name: y_np,
                        lr.name: lr_np
                    })

                relu_x_np, relu_y_np, relu_lr_np = exe.run(
                    prog, feed=feed_list, fetch_list=[relu_x, relu_y, relu_lr])

                # Fetched results are the per-device outputs concatenated
                # back together, so compare against the concatenated feeds.
                x_np = np.concatenate(x_np_list)
                y_np = np.concatenate(y_np_list)
                lr_np = np.concatenate(lr_np_list)

                assert_result(x_np, relu_x_np)
                assert_result(y_np, relu_y_np)
                assert_result(lr_np, relu_lr_np)

        for _ in six.moves.range(self.iterations):
            if use_split:
                feed_split_test()
            else:
                feed_list_test()

    def test_main(self):
        # Each element of `places` is itself a list of places for one run.
        places = [fluid.cpu_places(4)]
        if fluid.is_compiled_with_cuda():
            places.append(fluid.cuda_places())
        for p in places:
            self.run_network(p, use_split=True)
            self.run_network(p, use_split=False)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
    """Checks DataLoader-driven runs where the number of queued batches is
    not a multiple of the device count, with and without drop_last.
    """

    def setUp(self):
        self.epoch_num = 3
        # A prime batch count guarantees it is never a multiple of the
        # multi-device place count, exercising the partial last round.
        self.batch_num = 101  # a prime number
        self.batch_size = 32

    def create_reader(self):
        """Return a generator function yielding `batch_num` random batches."""

        def __impl__():
            for _ in six.moves.range(self.batch_num):
                # Trailing comma: each yield is a one-element tuple (one slot).
                yield np.random.random(
                    [self.batch_size, 1]).astype('float32'),

        return __impl__

    def run_network(self, iterable, use_cuda, drop_last):
        """Run one loader configuration and count the batches actually seen.

        Args:
            iterable: whether the DataLoader is used as a Python iterable
                (vs. start()/reset() mode).
            use_cuda: run on all GPUs if True, else on 4 CPU places.
            drop_last: drop the final partial multi-device round.
        """
        x = fluid.data(shape=[None, 1], name='x', dtype='float32')
        places = fluid.cuda_places() if use_cuda else fluid.cpu_places(4)
        loader = fluid.io.DataLoader.from_generator(
            feed_list=[x], capacity=16, iterable=iterable, drop_last=drop_last)
        y = fluid.layers.fc(x, size=10)
        loss = fluid.layers.reduce_mean(y)

        exe = fluid.Executor(places[0])
        exe.run(fluid.default_startup_program())

        prog = fluid.CompiledProgram(fluid.default_main_program(
        )).with_data_parallel(places=places, loss_name=loss.name)

        # In non-iterable mode places are bound at construction; pass None.
        loader.set_batch_generator(
            self.create_reader(), places=places if iterable else None)

        for _ in six.moves.range(self.epoch_num):
            actual_batch_num = 0
            if loader.iterable:
                for feed_data in loader():
                    x_data, = exe.run(prog, feed=feed_data, fetch_list=[x])
                    # Fetched x is the concatenation over devices, so its
                    # first dim must be a non-zero multiple of batch_size.
                    self.assertEqual(x_data.shape[0] % self.batch_size, 0)
                    self.assertTrue(x_data.shape[0] != 0)
                    actual_batch_num += int(x_data.shape[0] / self.batch_size)
            else:
                loader.start()
                try:
                    while True:
                        x_data, = exe.run(prog, fetch_list=[x])
                        self.assertEqual(x_data.shape[0] % self.batch_size, 0)
                        self.assertTrue(x_data.shape[0] != 0)
                        actual_batch_num += int(
                            x_data.shape[0] / self.batch_size)
                except fluid.core.EOFException:
                    # Expected end-of-data signal in start()/reset() mode.
                    loader.reset()

            if not drop_last or len(places) == 1:
                # Nothing can be dropped: every batch must be consumed.
                self.assertEqual(self.batch_num, actual_batch_num)
            else:
                # 101 is prime, so the last round is always partial and
                # dropped when drop_last=True with multiple places.
                self.assertGreater(self.batch_num, actual_batch_num)

    def test_main(self):
        use_cuda_list = [False, True] if fluid.is_compiled_with_cuda(
        ) else [False]
        iterable_list = [False, True]
        drop_last_list = [False, True]
        for iterable in iterable_list:
            for use_cuda in use_cuda_list:
                for drop_last in drop_last_list:
                    # Fresh programs and scope per combination to avoid
                    # cross-contamination between runs.
                    with fluid.program_guard(fluid.Program(), fluid.Program()):
                        with fluid.scope_guard(fluid.Scope()):
                            self.run_network(iterable, use_cuda, drop_last)
# Standard unittest entry point when the file is run as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录