PaddlePaddle / Paddle

Commit 5ce1a960
Authored Sep 12, 2018 by Yancey1989
Parent: b681537e

move bcast op into pass
Showing 12 changed files with 82 additions and 30 deletions (+82 / -30)
Changed files:
  benchmark/fluid/args.py                                       +6   -0
  benchmark/fluid/fluid_benchmark.py                            +9   -0
  benchmark/fluid/models/mnist.py                               +7   -4
  paddle/fluid/framework/details/all_reduce_op_handle.cc        +6   -1
  paddle/fluid/framework/details/broadcast_op_handle.cc         +7   -0
  paddle/fluid/framework/details/data_balance_op_handle.cc      +7   -0
  paddle/fluid/framework/details/multi_devices_graph_pass.cc    +32  -10
  paddle/fluid/framework/details/multi_devices_graph_pass.h     +2   -2
  paddle/fluid/framework/details/reduce_op_handle.cc            +5   -1
  paddle/fluid/framework/details/scale_loss_grad_op_handle.cc   +1   -1
  paddle/fluid/pybind/pybind.cc                                 +0   -1
  python/paddle/fluid/parallel_executor.py                      +0   -10
benchmark/fluid/args.py
@@ -140,5 +140,11 @@ def parse_args():
         '--use_lars',
         action='store_true',
         help='If set, use lars for optimizers, ONLY support resnet module.')
+    parser.add_argument(
+        '--reduce_strategy',
+        type=str,
+        choices=['reduce', 'all_reduce'],
+        default='all_reduce',
+        help='Specify the reduce strategy, can be reduce, all_reduce')
     args = parser.parse_args()
     return args
benchmark/fluid/fluid_benchmark.py
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
     strategy = fluid.ExecutionStrategy()
     strategy.num_threads = args.cpus
     strategy.allow_op_delay = False
+    build_strategy = fluid.BuildStrategy()
+    if args.reduce_strategy == "reduce":
+        build_strategy.reduce_strategy = fluid.BuildStrategy(
+        ).ReduceStrategy.Reduce
+    else:
+        build_strategy.reduce_strategy = fluid.BuildStrategy(
+        ).ReduceStrategy.AllReduce
+
     avg_loss = train_args[0]
 
     if args.update_method == "pserver":
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         avg_loss.name,
         main_program=train_prog,
         exec_strategy=strategy,
+        build_strategy=build_strategy,
         num_trainers=num_trainers,
         trainer_id=trainer_id)
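For readers skimming the two benchmark hunks above, the following is a minimal sketch, not part of the commit, of how the new --reduce_strategy value is mapped onto a BuildStrategy. Only fluid.BuildStrategy and its ReduceStrategy values are taken from the diff; the function name and calling code are hypothetical.

# Illustrative sketch (not from the commit): map the benchmark's new
# --reduce_strategy choice ('reduce' or 'all_reduce') onto a BuildStrategy,
# mirroring the fluid_benchmark.py hunk above. Assumes a PaddlePaddle Fluid
# build of this era; build_strategy_from_flag is a made-up helper name.
import paddle.fluid as fluid

def build_strategy_from_flag(reduce_strategy):
    build_strategy = fluid.BuildStrategy()
    if reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy().ReduceStrategy.Reduce
    else:
        build_strategy.reduce_strategy = fluid.BuildStrategy().ReduceStrategy.AllReduce
    return build_strategy

The benchmark then hands the resulting object to ParallelExecutor through the build_strategy= argument, as the second hunk shows.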
benchmark/fluid/models/mnist.py
@@ -67,11 +67,14 @@ def cnn_model(data):
 
 def get_model(args, is_train, main_prog, startup_prog):
     # NOTE: mnist is small, we don't implement data sharding yet.
-    filelist = [
-        os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
-    ]
+    opt = None
+    data_file_handle = None
     with fluid.program_guard(main_prog, startup_prog):
         if args.use_reader_op:
+            filelist = [
+                os.path.join(args.data_path, f)
+                for f in os.listdir(args.data_path)
+            ]
             data_file_handle = fluid.layers.open_files(
                 filenames=filelist,
                 shapes=[[-1, 1, 28, 28], (-1, 1)],
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
         if is_train:
             opt = fluid.optimizer.AdamOptimizer(
                 learning_rate=0.001, beta1=0.9, beta2=0.999)
-            opt.minimize()
+            opt.minimize(avg_cost)
         if args.memory_optimize:
             fluid.memory_optimize(main_prog)
paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  platform::RecordEvent r("all_reduce", nullptr);
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
   } else {
paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -15,12 +15,19 @@
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
 void BroadcastOpHandle::RunImpl() {
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   if (places_.size() == 1) return;
 
   // The input and output may have dummy vars.
paddle/fluid/framework/details/data_balance_op_handle.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/data_balance_op_handle.h"
 #include <algorithm>
 #include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace framework {
@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }
 
 void DataBalanceOpHandle::RunImpl() {
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   size_t cur_device_id = 0;
   bool is_forwarding = true;
+  bool is_dist_train = false;
 
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node);
+      PADDLE_ENFORCE(op_dev_id != -1,
+                     "Can not schedule the RPC operator to the right place.");
+      if (node->Op()->Type() == "recv") {
+        auto recv_vars_attr =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
+        if (recv_vars_attr[0].find(".block") == std::string::npos) {
+          bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
+        }
+      }
+      is_dist_train = true;
     } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node);
+      if (node->Op()->Type() == "concat") {
+        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
+        bcast_var_name_set[op_dev_id].emplace(origin_param_name);
+      }
     } else if (IsScaleLossOp(node)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
-          bcast_var_name_set[cur_device_id].emplace(p_name);
+          if (!is_dist_train) {
+            // will send gradients directly when distributed training
+            bcast_var_name_set[cur_device_id].emplace(p_name);
+          }
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {
@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
     }
   }
 
   bool use_gpu = false;
 #ifdef PADDLE_WITH_CUDA
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if (use_gpu ||
-      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+  if ((use_gpu &&
+       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
+      is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
@@ -676,8 +696,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
+int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                                                 ir::Node *node) const {
   int op_dev_id = -1;
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                    node->Op()->Type());
 
   CreateComputationalOp(result, node, op_dev_id);
+  return op_dev_id;
 }
 
 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
@@ -738,8 +759,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
+int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
     }
   }
+  return op_dev_id;
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   bool IsScaleLossOp(ir::Node *node) const;
 
-  void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
-  void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
+  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
+  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
paddle/fluid/framework/details/reduce_op_handle.cc
@@ -27,7 +27,11 @@ namespace framework {
 namespace details {
 
 void ReduceOpHandle::RunImpl() {
-  platform::RecordEvent r("reduce", nullptr);
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
   if (places_.size() == 1) return;
   // the input and output may have dummy var.
   auto in_var_handles = DynamicCast<VarHandle>(inputs_);
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
                      ->stream();
     memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
                  platform::CPUPlace(), &coeff_, sizeof(float), stream);
-    VLOG(1) << place_ << "RUN Scale loss grad op";
+    VLOG(10) << place_ << "RUN Scale loss grad op";
   });
 #endif
 }
paddle/fluid/pybind/pybind.cc
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
            const std::string &, Scope *, std::vector<Scope *> &,
            const ExecutionStrategy &, const BuildStrategy &, size_t,
            size_t>())
-      .def("_bcast_params", &ParallelExecutor::BCastParamsToDevices)
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
       // of vec<Scope*> will be freed by Python GC. We can only return Scope*
python/paddle/fluid/parallel_executor.py
@@ -279,21 +279,11 @@ class ParallelExecutor(object):
         self.executor.run(fetch_list, fetch_var_name)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
 
-        if self.is_dist:
-            self._bcast_params()
-
         if return_numpy:
             return executor.as_numpy(arr)
 
         return [arr[i] for i in range(len(arr))]
 
-    def _bcast_params(self):
-        """
-        Broadcast the parameters to other devices. It is used during
-        distributed training.
-        """
-        self.executor._bcast_params(set(self.persistable_vars))
-
     @property
     def device_count(self):
         return len(self._act_places)
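The user-facing effect of this last hunk, sketched below under stated assumptions: because the multi-devices graph pass now inserts the broadcast ops itself (see the multi_devices_graph_pass.cc hunks above), a distributed trainer's Python loop no longer triggers a _bcast_params step after each run. The names exe, reader and avg_loss are placeholders, not from the commit.

# Illustrative sketch (not from the commit): a pserver-mode trainer's inner
# loop after this change. `exe` is assumed to be a fluid.ParallelExecutor
# built with the BuildStrategy shown earlier; `reader` and `avg_loss` are
# hypothetical. The parameter broadcast formerly done from Python via
# _bcast_params() now happens inside the graph itself.
def train_loop(exe, reader, avg_loss):
    for data in reader():
        loss_value, = exe.run(fetch_list=[avg_loss.name], feed=data)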