Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7d5118a9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7d5118a9
编写于
9月 14, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'ups/develop' into refine/op/lstm
上级
ff858d35
23ec966c
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
68 addition
and
35 deletion
+68
-35
benchmark/fluid/args.py
benchmark/fluid/args.py
+6
-0
benchmark/fluid/fluid_benchmark.py
benchmark/fluid/fluid_benchmark.py
+9
-0
benchmark/fluid/models/mnist.py
benchmark/fluid/models/mnist.py
+7
-4
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+2
-1
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+3
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+36
-10
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+2
-2
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+2
-1
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+0
-15
未找到文件。
benchmark/fluid/args.py
浏览文件 @
7d5118a9
...
@@ -140,5 +140,11 @@ def parse_args():
...
@@ -140,5 +140,11 @@ def parse_args():
'--use_lars'
,
'--use_lars'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'If set, use lars for optimizers, ONLY support resnet module.'
)
help
=
'If set, use lars for optimizers, ONLY support resnet module.'
)
parser
.
add_argument
(
'--reduce_strategy'
,
type
=
str
,
choices
=
[
'reduce'
,
'all_reduce'
],
default
=
'all_reduce'
,
help
=
'Specify the reduce strategy, can be reduce, all_reduce'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
benchmark/fluid/fluid_benchmark.py
浏览文件 @
7d5118a9
...
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
...
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
strategy
=
fluid
.
ExecutionStrategy
()
strategy
=
fluid
.
ExecutionStrategy
()
strategy
.
num_threads
=
args
.
cpus
strategy
.
num_threads
=
args
.
cpus
strategy
.
allow_op_delay
=
False
strategy
.
allow_op_delay
=
False
build_strategy
=
fluid
.
BuildStrategy
()
if
args
.
reduce_strategy
==
"reduce"
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
(
).
ReduceStrategy
.
Reduce
else
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
(
).
ReduceStrategy
.
AllReduce
avg_loss
=
train_args
[
0
]
avg_loss
=
train_args
[
0
]
if
args
.
update_method
==
"pserver"
:
if
args
.
update_method
==
"pserver"
:
...
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
...
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
avg_loss
.
name
,
avg_loss
.
name
,
main_program
=
train_prog
,
main_program
=
train_prog
,
exec_strategy
=
strategy
,
exec_strategy
=
strategy
,
build_strategy
=
build_strategy
,
num_trainers
=
num_trainers
,
num_trainers
=
num_trainers
,
trainer_id
=
trainer_id
)
trainer_id
=
trainer_id
)
...
...
benchmark/fluid/models/mnist.py
浏览文件 @
7d5118a9
...
@@ -67,11 +67,14 @@ def cnn_model(data):
...
@@ -67,11 +67,14 @@ def cnn_model(data):
def
get_model
(
args
,
is_train
,
main_prog
,
startup_prog
):
def
get_model
(
args
,
is_train
,
main_prog
,
startup_prog
):
# NOTE: mnist is small, we don't implement data sharding yet.
# NOTE: mnist is small, we don't implement data sharding yet.
filelist
=
[
opt
=
None
os
.
path
.
join
(
args
.
data_path
,
f
)
for
f
in
os
.
listdir
(
args
.
data_path
)
data_file_handle
=
None
]
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
if
args
.
use_reader_op
:
if
args
.
use_reader_op
:
filelist
=
[
os
.
path
.
join
(
args
.
data_path
,
f
)
for
f
in
os
.
listdir
(
args
.
data_path
)
]
data_file_handle
=
fluid
.
layers
.
open_files
(
data_file_handle
=
fluid
.
layers
.
open_files
(
filenames
=
filelist
,
filenames
=
filelist
,
shapes
=
[[
-
1
,
1
,
28
,
28
],
(
-
1
,
1
)],
shapes
=
[[
-
1
,
1
,
28
,
28
],
(
-
1
,
1
)],
...
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
...
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
if
is_train
:
if
is_train
:
opt
=
fluid
.
optimizer
.
AdamOptimizer
(
opt
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
)
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
)
opt
.
minimize
()
opt
.
minimize
(
avg_cost
)
if
args
.
memory_optimize
:
if
args
.
memory_optimize
:
fluid
.
memory_optimize
(
main_prog
)
fluid
.
memory_optimize
(
main_prog
)
...
...
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
7d5118a9
...
@@ -46,7 +46,8 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
...
@@ -46,7 +46,8 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
#endif
void
AllReduceOpHandle
::
RunImpl
()
{
void
AllReduceOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
r
(
"all_reduce"
,
nullptr
);
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
if
(
NoDummyInputSize
()
==
1
)
{
if
(
NoDummyInputSize
()
==
1
)
{
return
;
// No need to all reduce when GPU count = 1;
return
;
// No need to all reduce when GPU count = 1;
}
else
{
}
else
{
...
...
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
7d5118a9
...
@@ -15,12 +15,15 @@
...
@@ -15,12 +15,15 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
BroadcastOpHandle
::
RunImpl
()
{
void
BroadcastOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
if
(
places_
.
size
()
==
1
)
return
;
if
(
places_
.
size
()
==
1
)
return
;
// The input and output may have dummy vars.
// The input and output may have dummy vars.
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
7d5118a9
...
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
size_t
cur_device_id
=
0
;
size_t
cur_device_id
=
0
;
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
is_dist_train
=
false
;
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
CreateRPCOp
(
&
result
,
node
);
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
auto
recv_vars_attr
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE
(
recv_vars_attr
.
size
()
==
2UL
);
// [parameter, gradient]
if
(
recv_vars_attr
[
0
].
find
(
".block"
)
==
std
::
string
::
npos
)
{
bcast_var_name_set
[
op_dev_id
].
emplace
(
recv_vars_attr
[
0
]);
}
}
is_dist_train
=
true
;
}
else
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
}
else
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
CreateDistTrainOp
(
&
result
,
node
);
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set
[
op_dev_id
].
emplace
(
origin_param_name
);
}
}
else
if
(
IsScaleLossOp
(
node
))
{
}
else
if
(
IsScaleLossOp
(
node
))
{
// user can customize loss@grad if not use_default_grad_scale_
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
if
(
strategy_
.
gradient_scale_
!=
...
@@ -414,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -414,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
if
(
!
is_dist_train
)
{
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
}
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
g_name
))
{
if
(
IsSparseGradient
(
g_name
))
{
...
@@ -436,14 +455,19 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -436,14 +455,19 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
}
}
}
}
bool
use_gpu
=
false
;
bool
use_gpu
=
false
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
use_gpu
=
nccl_ctxs_
!=
nullptr
;
use_gpu
=
nccl_ctxs_
!=
nullptr
;
#endif
#endif
if
(
use_gpu
&&
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
// Insert broadcast operators principle:
// Insert BCast Ops
// 1. Broadcast optimized parameters in Reduce strategy;
// 2. No need broadcast optimized parameters in AllReduce strategy because of
// the optimization sub-graph would be run on every GPU;
// 3. Allways broadcast received parameters in Distribute Training.
if
((
use_gpu
&&
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
||
is_dist_train
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
...
@@ -675,8 +699,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -675,8 +699,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return
var
;
return
var
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
...
@@ -719,6 +743,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -719,6 +743,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node
->
Op
()
->
Type
());
node
->
Op
()
->
Type
());
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
return
op_dev_id
;
}
}
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
...
@@ -737,8 +762,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...
@@ -737,8 +762,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
...
@@ -824,6 +849,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -824,6 +849,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
outvar_dev_id
);
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
outvar_dev_id
);
}
}
}
}
return
op_dev_id
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
7d5118a9
...
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
7d5118a9
...
@@ -27,7 +27,8 @@ namespace framework {
...
@@ -27,7 +27,8 @@ namespace framework {
namespace
details
{
namespace
details
{
void
ReduceOpHandle
::
RunImpl
()
{
void
ReduceOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
r
(
"reduce"
,
nullptr
);
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
begin
()
->
second
);
if
(
places_
.
size
()
==
1
)
return
;
if
(
places_
.
size
()
==
1
)
return
;
// the input and output may have dummy var.
// the input and output may have dummy var.
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
...
...
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
浏览文件 @
7d5118a9
...
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
...
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
->
stream
();
->
stream
();
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
VLOG
(
1
)
<<
place_
<<
"RUN Scale loss grad op"
;
VLOG
(
1
0
)
<<
place_
<<
"RUN Scale loss grad op"
;
});
});
#endif
#endif
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
7d5118a9
...
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
size_t
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
size_t
,
size_t
>
())
size_t
>
())
.
def
(
"_bcast_params"
,
&
ParallelExecutor
::
BCastParamsToDevices
)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
7d5118a9
...
@@ -142,11 +142,6 @@ class ParallelExecutor(object):
...
@@ -142,11 +142,6 @@ class ParallelExecutor(object):
main
=
main
if
main
else
framework
.
default_main_program
()
main
=
main
if
main
else
framework
.
default_main_program
()
if
scope
==
None
:
if
scope
==
None
:
scope
=
executor
.
global_scope
()
scope
=
executor
.
global_scope
()
# FIXME(Yancey1989): it's a temporary approach to determinate the distribute
# train program, call self.bcast_param() at the end of each mini-batch.
self
.
is_dist
=
True
if
"recv"
in
[
op
.
type
for
op
in
main
.
global_block
().
ops
]
else
False
if
share_vars_from
and
not
isinstance
(
share_vars_from
,
if
share_vars_from
and
not
isinstance
(
share_vars_from
,
ParallelExecutor
):
ParallelExecutor
):
...
@@ -286,21 +281,11 @@ class ParallelExecutor(object):
...
@@ -286,21 +281,11 @@ class ParallelExecutor(object):
self
.
executor
.
run
(
fetch_list
,
fetch_var_name
)
self
.
executor
.
run
(
fetch_list
,
fetch_var_name
)
arr
=
self
.
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
arr
=
self
.
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
if
self
.
is_dist
:
self
.
_bcast_params
()
if
return_numpy
:
if
return_numpy
:
return
executor
.
as_numpy
(
arr
)
return
executor
.
as_numpy
(
arr
)
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
def
_bcast_params
(
self
):
"""
Broadcast the parameters to other devices. It is used during
distributed training.
"""
self
.
executor
.
_bcast_params
(
set
(
self
.
persistable_vars
))
@
property
@
property
def
device_count
(
self
):
def
device_count
(
self
):
return
len
(
self
.
_act_places
)
return
len
(
self
.
_act_places
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录