Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5ce1a960
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5ce1a960
编写于
9月 12, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move bcast op into pass
上级
b681537e
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
82 addition
and
30 deletion
+82
-30
benchmark/fluid/args.py
benchmark/fluid/args.py
+6
-0
benchmark/fluid/fluid_benchmark.py
benchmark/fluid/fluid_benchmark.py
+9
-0
benchmark/fluid/models/mnist.py
benchmark/fluid/models/mnist.py
+7
-4
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+6
-1
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+7
-0
paddle/fluid/framework/details/data_balance_op_handle.cc
paddle/fluid/framework/details/data_balance_op_handle.cc
+7
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+32
-10
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+2
-2
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+5
-1
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+0
-10
未找到文件。
benchmark/fluid/args.py
浏览文件 @
5ce1a960
...
@@ -140,5 +140,11 @@ def parse_args():
...
@@ -140,5 +140,11 @@ def parse_args():
'--use_lars'
,
'--use_lars'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'If set, use lars for optimizers, ONLY support resnet module.'
)
help
=
'If set, use lars for optimizers, ONLY support resnet module.'
)
parser
.
add_argument
(
'--reduce_strategy'
,
type
=
str
,
choices
=
[
'reduce'
,
'all_reduce'
],
default
=
'all_reduce'
,
help
=
'Specify the reduce strategy, can be reduce, all_reduce'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
benchmark/fluid/fluid_benchmark.py
浏览文件 @
5ce1a960
...
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
...
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
strategy
=
fluid
.
ExecutionStrategy
()
strategy
=
fluid
.
ExecutionStrategy
()
strategy
.
num_threads
=
args
.
cpus
strategy
.
num_threads
=
args
.
cpus
strategy
.
allow_op_delay
=
False
strategy
.
allow_op_delay
=
False
build_strategy
=
fluid
.
BuildStrategy
()
if
args
.
reduce_strategy
==
"reduce"
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
(
).
ReduceStrategy
.
Reduce
else
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
(
).
ReduceStrategy
.
AllReduce
avg_loss
=
train_args
[
0
]
avg_loss
=
train_args
[
0
]
if
args
.
update_method
==
"pserver"
:
if
args
.
update_method
==
"pserver"
:
...
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
...
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
avg_loss
.
name
,
avg_loss
.
name
,
main_program
=
train_prog
,
main_program
=
train_prog
,
exec_strategy
=
strategy
,
exec_strategy
=
strategy
,
build_strategy
=
build_strategy
,
num_trainers
=
num_trainers
,
num_trainers
=
num_trainers
,
trainer_id
=
trainer_id
)
trainer_id
=
trainer_id
)
...
...
benchmark/fluid/models/mnist.py
浏览文件 @
5ce1a960
...
@@ -67,11 +67,14 @@ def cnn_model(data):
...
@@ -67,11 +67,14 @@ def cnn_model(data):
def
get_model
(
args
,
is_train
,
main_prog
,
startup_prog
):
def
get_model
(
args
,
is_train
,
main_prog
,
startup_prog
):
# NOTE: mnist is small, we don't implement data sharding yet.
# NOTE: mnist is small, we don't implement data sharding yet.
filelist
=
[
opt
=
None
os
.
path
.
join
(
args
.
data_path
,
f
)
for
f
in
os
.
listdir
(
args
.
data_path
)
data_file_handle
=
None
]
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
if
args
.
use_reader_op
:
if
args
.
use_reader_op
:
filelist
=
[
os
.
path
.
join
(
args
.
data_path
,
f
)
for
f
in
os
.
listdir
(
args
.
data_path
)
]
data_file_handle
=
fluid
.
layers
.
open_files
(
data_file_handle
=
fluid
.
layers
.
open_files
(
filenames
=
filelist
,
filenames
=
filelist
,
shapes
=
[[
-
1
,
1
,
28
,
28
],
(
-
1
,
1
)],
shapes
=
[[
-
1
,
1
,
28
,
28
],
(
-
1
,
1
)],
...
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
...
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
if
is_train
:
if
is_train
:
opt
=
fluid
.
optimizer
.
AdamOptimizer
(
opt
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
)
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
)
opt
.
minimize
()
opt
.
minimize
(
avg_cost
)
if
args
.
memory_optimize
:
if
args
.
memory_optimize
:
fluid
.
memory_optimize
(
main_prog
)
fluid
.
memory_optimize
(
main_prog
)
...
...
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
5ce1a960
...
@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
...
@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
#endif
void
AllReduceOpHandle
::
RunImpl
()
{
void
AllReduceOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
r
(
"all_reduce"
,
nullptr
);
if
(
dev_ctxes_
.
size
()
>
0UL
)
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
}
else
{
platform
::
RecordEvent
record_event
(
Name
(),
nullptr
);
}
if
(
NoDummyInputSize
()
==
1
)
{
if
(
NoDummyInputSize
()
==
1
)
{
return
;
// No need to all reduce when GPU count = 1;
return
;
// No need to all reduce when GPU count = 1;
}
else
{
}
else
{
...
...
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
5ce1a960
...
@@ -15,12 +15,19 @@
...
@@ -15,12 +15,19 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
BroadcastOpHandle
::
RunImpl
()
{
void
BroadcastOpHandle
::
RunImpl
()
{
if
(
dev_ctxes_
.
size
()
>
0UL
)
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
}
else
{
platform
::
RecordEvent
record_event
(
Name
(),
nullptr
);
}
if
(
places_
.
size
()
==
1
)
return
;
if
(
places_
.
size
()
==
1
)
return
;
// The input and output may have dummy vars.
// The input and output may have dummy vars.
...
...
paddle/fluid/framework/details/data_balance_op_handle.cc
浏览文件 @
5ce1a960
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm>
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
...
@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
}
}
void
DataBalanceOpHandle
::
RunImpl
()
{
void
DataBalanceOpHandle
::
RunImpl
()
{
if
(
dev_ctxes_
.
size
()
>
0UL
)
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
}
else
{
platform
::
RecordEvent
record_event
(
Name
(),
nullptr
);
}
PADDLE_ENFORCE_GT
(
places_
.
size
(),
1
,
PADDLE_ENFORCE_GT
(
places_
.
size
(),
1
,
"Data balance can only be enabled when the number of "
"Data balance can only be enabled when the number of "
"places to run larger than 1."
);
"places to run larger than 1."
);
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
5ce1a960
...
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
size_t
cur_device_id
=
0
;
size_t
cur_device_id
=
0
;
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
is_dist_train
=
false
;
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
CreateRPCOp
(
&
result
,
node
);
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
auto
recv_vars_attr
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE
(
recv_vars_attr
.
size
()
==
2UL
);
// [parameter, gradient]
if
(
recv_vars_attr
[
0
].
find
(
".block"
)
==
std
::
string
::
npos
)
{
bcast_var_name_set
[
op_dev_id
].
emplace
(
recv_vars_attr
[
0
]);
}
}
is_dist_train
=
true
;
}
else
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
}
else
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
CreateDistTrainOp
(
&
result
,
node
);
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set
[
op_dev_id
].
emplace
(
origin_param_name
);
}
}
else
if
(
IsScaleLossOp
(
node
))
{
}
else
if
(
IsScaleLossOp
(
node
))
{
// user can customize loss@grad if not use_default_grad_scale_
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
if
(
strategy_
.
gradient_scale_
!=
...
@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
.
emplace
(
g_name
,
cur_device_id
);
if
(
!
is_dist_train
)
{
// will send gradients directly when distributed training
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
}
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
g_name
))
{
if
(
IsSparseGradient
(
g_name
))
{
...
@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
}
}
}
}
bool
use_gpu
=
false
;
bool
use_gpu
=
false
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
use_gpu
=
nccl_ctxs_
!=
nullptr
;
use_gpu
=
nccl_ctxs_
!=
nullptr
;
#endif
#endif
if
(
use_gpu
||
if
((
use_gpu
&&
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
||
is_dist_train
)
{
// Insert BCast Ops
// Insert BCast Ops
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
...
@@ -676,7 +696,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -676,7 +696,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return
var
;
return
var
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
input_var_names
;
...
@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node
->
Op
()
->
Type
());
node
->
Op
()
->
Type
());
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
return
op_dev_id
;
}
}
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
...
@@ -738,7 +759,7 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...
@@ -738,7 +759,7 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
...
@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
outvar_dev_id
);
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
outvar_dev_id
);
}
}
}
}
return
op_dev_id
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
5ce1a960
...
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
5ce1a960
...
@@ -27,7 +27,11 @@ namespace framework {
...
@@ -27,7 +27,11 @@ namespace framework {
namespace
details
{
namespace
details
{
void
ReduceOpHandle
::
RunImpl
()
{
void
ReduceOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
r
(
"reduce"
,
nullptr
);
if
(
dev_ctxes_
.
size
()
>
0UL
)
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
}
else
{
platform
::
RecordEvent
record_event
(
Name
(),
nullptr
);
}
if
(
places_
.
size
()
==
1
)
return
;
if
(
places_
.
size
()
==
1
)
return
;
// the input and output may have dummy var.
// the input and output may have dummy var.
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
...
...
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
浏览文件 @
5ce1a960
...
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
...
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
->
stream
();
->
stream
();
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
VLOG
(
1
)
<<
place_
<<
"RUN Scale loss grad op"
;
VLOG
(
1
0
)
<<
place_
<<
"RUN Scale loss grad op"
;
});
});
#endif
#endif
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
5ce1a960
...
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
size_t
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
size_t
,
size_t
>
())
size_t
>
())
.
def
(
"_bcast_params"
,
&
ParallelExecutor
::
BCastParamsToDevices
)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
5ce1a960
...
@@ -279,21 +279,11 @@ class ParallelExecutor(object):
...
@@ -279,21 +279,11 @@ class ParallelExecutor(object):
self
.
executor
.
run
(
fetch_list
,
fetch_var_name
)
self
.
executor
.
run
(
fetch_list
,
fetch_var_name
)
arr
=
self
.
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
arr
=
self
.
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
if
self
.
is_dist
:
self
.
_bcast_params
()
if
return_numpy
:
if
return_numpy
:
return
executor
.
as_numpy
(
arr
)
return
executor
.
as_numpy
(
arr
)
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
def
_bcast_params
(
self
):
"""
Broadcast the parameters to other devices. It is used during
distributed training.
"""
self
.
executor
.
_bcast_params
(
set
(
self
.
persistable_vars
))
@
property
@
property
def
device_count
(
self
):
def
device_count
(
self
):
return
len
(
self
.
_act_places
)
return
len
(
self
.
_act_places
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录