Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3abb2aa0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3abb2aa0
编写于
1月 09, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine MultiDevSSAGraph
test=release/1.2
上级
553df9d3
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
211 addition
and
277 deletion
+211
-277
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+198
-207
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+11
-8
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+0
-59
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+2
-3
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
3abb2aa0
...
@@ -42,6 +42,12 @@ namespace {
...
@@ -42,6 +42,12 @@ namespace {
typedef
std
::
vector
<
OpHandleBase
*>
GraphOps
;
typedef
std
::
vector
<
OpHandleBase
*>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
const
char
kGraphOps
[]
=
"ops"
;
bool
OpHaveRole
(
const
ir
::
Node
&
node
,
const
framework
::
OpRole
&
role
)
{
return
boost
::
get
<
int
>
(
node
.
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
role
);
}
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
...
@@ -150,6 +156,7 @@ void MultiDevSSAGraphBuilder::Init() const {
...
@@ -150,6 +156,7 @@ void MultiDevSSAGraphBuilder::Init() const {
grad_names_
.
insert
(
GradVarName
(
p
));
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
balance_vars_
.
resize
(
places_
.
size
(),
0
);
if
(
strategy_
.
enable_data_balance_
&&
places_
.
size
()
==
1
)
{
if
(
strategy_
.
enable_data_balance_
&&
places_
.
size
()
==
1
)
{
LOG
(
WARNING
)
<<
"It is no need to enable data balance when there is only "
LOG
(
WARNING
)
<<
"It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False."
;
"one place. enable_data_balance is set to False."
;
...
@@ -157,145 +164,16 @@ void MultiDevSSAGraphBuilder::Init() const {
...
@@ -157,145 +164,16 @@ void MultiDevSSAGraphBuilder::Init() const {
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
result
,
input
,
p
,
place_id
);
op_handle
->
AddInput
(
var
);
}
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
ir
::
Node
*
new_node
=
nullptr
;
if
(
output
->
Var
())
{
new_node
=
result
->
CreateVarNode
(
output
->
Var
());
}
else
{
new_node
=
result
->
CreateEmptyNode
(
output
->
Name
(),
ir
::
Node
::
Type
::
kVariable
);
}
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
place_id
);
}
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for
(
auto
&
node
:
nodes
)
{
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if
(
op
->
Type
()
==
"send"
)
{
auto
op_vars
=
op
->
InputArgumentNames
();
send_vars
.
reserve
(
send_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
send_vars
.
insert
(
send_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
}
return
send_vars
;
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
&
node
:
nodes
)
{
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if
(
op
->
Type
()
==
"recv"
)
{
auto
op_vars
=
op
->
OutputArgumentNames
();
recv_vars
.
reserve
(
recv_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
recv_vars
.
insert
(
recv_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
}
return
recv_vars
;
}
size_t
MultiDevSSAGraphBuilder
::
GetAppropriateDeviceID
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
int64_t
numel_sum
=
0
;
for
(
auto
var_name
:
var_names
)
{
if
(
all_vars_
.
find
(
var_name
)
==
all_vars_
.
end
())
continue
;
auto
var_desc
=
all_vars_
.
at
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GT
(
numel
,
0
);
numel_sum
+=
numel
;
}
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_vars_
),
std
::
end
(
balance_vars_
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_vars_
),
smallest
));
balance_vars_
[
dev_id
]
+=
numel_sum
;
return
dev_id
;
}
// Topology sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
// some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes.
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std
::
vector
<
ir
::
Node
*>
SortOpsAndDelayOptimizeOp
(
const
ir
::
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
ret
=
ir
::
TopologySortOperations
(
graph
);
size_t
last_backward
=
0
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kBackward
))
{
last_backward
=
i
;
}
}
std
::
vector
<
ir
::
Node
*>
optimize_ops
;
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
i
<
last_backward
)
{
if
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kOptimize
)))
{
optimize_ops
.
push_back
(
ret
[
i
]);
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
}
}
else
if
(
i
==
last_backward
)
{
sorted_ret
.
push_back
(
ret
[
i
]);
// Verify that no operations before optimize ops depends on optimize ops.
std
::
unordered_set
<
ir
::
Node
*>
optimize_set
(
optimize_ops
.
begin
(),
optimize_ops
.
end
());
for
(
ir
::
Node
*
n
:
sorted_ret
)
{
for
(
ir
::
Node
*
in
:
n
->
inputs
)
{
for
(
ir
::
Node
*
pre_n
:
in
->
inputs
)
{
PADDLE_ENFORCE
(
optimize_set
.
find
(
pre_n
)
==
optimize_set
.
end
(),
"optimize operations cannot be depended by forward "
"or backward node %s -> %s"
,
pre_n
->
Name
(),
n
->
Name
());
}
}
}
sorted_ret
.
insert
(
sorted_ret
.
end
(),
optimize_ops
.
begin
(),
optimize_ops
.
end
());
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
}
}
return
sorted_ret
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
Init
();
// Give the topology sort order and rebuild the graph structure.
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
ir
::
TopologySortOperations
(
*
graph
);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
sorted_ops
=
SortForReduceMode
(
sorted_ops
);
}
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
ir
::
Graph
&
result
=
*
graph
;
ir
::
Graph
&
result
=
*
graph
;
...
@@ -304,31 +182,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -304,31 +182,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
}
}
}
}
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
// find send/recv vars so that we can place the distributed training
// related op in the place 0
auto
send_vars
=
FindDistTrainSendVars
(
sorted_ops
);
auto
recv_vars
=
FindDistTrainRecvVars
(
sorted_ops
);
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
size_t
cur_device_id
=
0
;
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
is_dist_train
=
false
;
bool
is_dist_train
=
false
;
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
boost
::
get
<
int
>
(
if
(
OpHaveRole
(
*
node
,
OpRole
::
kRPC
))
{
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
,
&
sharded_var_device
);
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
,
&
sharded_var_device
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
"Can not schedule the RPC operator to the right place."
);
...
@@ -342,9 +211,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -342,9 +211,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
}
}
is_dist_train
=
true
;
is_dist_train
=
true
;
}
else
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
}
else
if
(
OpHaveRole
(
*
node
,
OpRole
::
kDist
))
{
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kDist
))
{
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
,
&
sharded_var_device
);
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
,
&
sharded_var_device
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
...
@@ -364,7 +231,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -364,7 +231,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
// the block.
// the block.
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
,
sharded_var_device
);
int
op_dev_id
=
GetOpDeviceID
(
node
,
sharded_var_device
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
...
@@ -384,47 +251,48 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -384,47 +251,48 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
if
(
!
is_bk_op
)
continue
;
// Currently, we assume that once gradient is generated, it can be
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
// broadcast, and each gradient is only broadcast once.
if
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
try
{
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
auto
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
static_cast
<
int
>
(
OpRole
::
kBackward
)))
{
node
->
Op
()
->
GetNullableAttr
(
try
{
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
auto
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
auto
&
p_name
=
backward_vars
[
i
];
auto
&
g_name
=
backward_vars
[
i
+
1
];
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
auto
&
p_name
=
backward_vars
[
i
];
size_t
cur_device_id
=
-
1
;
auto
&
g_name
=
backward_vars
[
i
+
1
];
switch
(
strategy_
.
reduce_
)
{
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
switch
(
strategy_
.
reduce_
)
{
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
sharded_var_device
.
emplace
(
g_name
,
cur_device_id
);
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
if
(
!
is_dist_train
)
{
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
sharded_var_device
.
emplace
(
g_name
,
cur_device_id
);
}
if
(
!
is_dist_train
)
{
break
;
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
}
if
(
IsSparseGradient
(
g_name
))
{
break
;
CreateReduceOp
(
&
result
,
g_name
,
0
);
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
if
(
IsSparseGradient
(
g_name
))
{
}
else
{
CreateReduceOp
(
&
result
,
g_name
,
0
);
InsertAllReduceOp
(
&
result
,
g_name
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
}
}
else
{
break
;
InsertAllReduceOp
(
&
result
,
g_name
);
default:
}
LOG
(
FATAL
)
<<
"Unknown reduce strategy "
;
break
;
break
;
default:
LOG
(
FATAL
)
<<
"Unknown reduce strategy "
;
break
;
}
}
}
}
catch
(
boost
::
bad_get
e
)
{
}
}
}
catch
(
boost
::
bad_get
e
)
{
}
}
}
}
}
}
...
@@ -468,12 +336,108 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -468,12 +336,108 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return
graph
;
return
graph
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
const
std
::
string
&
og
)
const
{
std
::
vector
<
ir
::
Node
*>
MultiDevSSAGraphBuilder
::
SortForReduceMode
(
PADDLE_ENFORCE
(
all_vars_
.
count
(
og
)
!=
0
);
const
std
::
vector
<
ir
::
Node
*>
&
topo_ops
)
const
{
if
(
all_vars_
.
at
(
og
)
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
)
{
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
return
true
;
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
delayed_op
;
sorted_ops
.
reserve
(
topo_ops
.
size
());
auto
insert_delayed_op
=
[
&
](
const
std
::
string
&
var_name
,
int
dev_id
)
{
sharded_var_device
.
emplace
(
var_name
,
dev_id
);
if
(
delayed_op
.
count
(
var_name
))
{
auto
&
ops
=
delayed_op
.
at
(
var_name
);
sorted_ops
.
insert
(
sorted_ops
.
end
(),
ops
.
begin
(),
ops
.
end
());
delayed_op
.
at
(
var_name
).
clear
();
}
};
for
(
ir
::
Node
*
node
:
topo_ops
)
{
int
op_dev_id
=
GetOpDeviceID
(
node
,
sharded_var_device
,
&
delayed_op
);
if
(
op_dev_id
>
-
1
)
{
// This op only runs on one specific device.
sorted_ops
.
emplace_back
(
node
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
insert_delayed_op
(
n
->
Name
(),
op_dev_id
);
}
}
else
if
(
op_dev_id
==
-
1
)
{
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops
.
emplace_back
(
node
);
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
if
(
!
is_bk_op
)
continue
;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std
::
vector
<
std
::
string
>
backward_vars
;
try
{
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
}
catch
(
boost
::
bad_get
e
)
{
}
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
auto
&
g_name
=
backward_vars
[
i
+
1
];
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
insert_delayed_op
(
g_name
,
static_cast
<
int
>
(
cur_device_id
));
}
}
else
if
(
op_dev_id
==
-
2
)
{
// The Op on which the Op depends has not yet been generated.
}
}
}
return
false
;
PADDLE_ENFORCE_EQ
(
sorted_ops
.
size
(),
topo_ops
.
size
());
return
sorted_ops
;
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
result
,
input
,
p
,
place_id
);
op_handle
->
AddInput
(
var
);
}
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
ir
::
Node
*
new_node
=
nullptr
;
if
(
output
->
Var
())
{
new_node
=
result
->
CreateVarNode
(
output
->
Var
());
}
else
{
new_node
=
result
->
CreateEmptyNode
(
output
->
Name
(),
ir
::
Node
::
Type
::
kVariable
);
}
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
place_id
);
}
}
size_t
MultiDevSSAGraphBuilder
::
GetAppropriateDeviceID
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
int64_t
numel_sum
=
0
;
for
(
auto
var_name
:
var_names
)
{
if
(
all_vars_
.
find
(
var_name
)
==
all_vars_
.
end
())
continue
;
auto
var_desc
=
all_vars_
.
at
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GT
(
numel
,
0
);
numel_sum
+=
numel
;
}
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_vars_
),
std
::
end
(
balance_vars_
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_vars_
),
smallest
));
balance_vars_
[
dev_id
]
+=
numel_sum
;
return
dev_id
;
}
}
void
MultiDevSSAGraphBuilder
::
SetCommunicationContext
(
void
MultiDevSSAGraphBuilder
::
SetCommunicationContext
(
...
@@ -624,28 +588,52 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
...
@@ -624,28 +588,52 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
],
sharded_var_device
);
if
(
dev_id
==
-
1
)
{
(
*
delay_ops
)[
param_grad
[
1
]].
push_back
(
node
);
return
-
2
;
}
return
dev_id
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
return
-
1
;
}
}
int
op_role
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
if
(
op_role
!=
static_cast
<
int
>
(
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
return
-
1
;
}
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
],
sharded_var_device
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
],
sharded_var_device
);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
return
dev_id
;
}
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
,
const
std
::
string
&
varname
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
auto
got
=
sharded_var_device
.
find
(
varname
);
auto
got
=
sharded_var_device
.
find
(
varname
);
if
(
got
==
sharded_var_device
.
end
())
{
if
(
got
==
sharded_var_device
.
end
())
{
...
@@ -739,8 +727,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
...
@@ -739,8 +727,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
],
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
input_var_names
[
0
],
*
sharded_var_device
);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
...
@@ -751,8 +738,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
...
@@ -751,8 +738,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
],
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
input_var_names
[
0
],
*
sharded_var_device
);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
}
}
...
@@ -793,8 +779,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -793,8 +779,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
(),
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
(),
*
sharded_var_device
);
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
...
@@ -824,8 +809,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -824,8 +809,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
if
(
recv_param_grad
.
size
()
==
2U
)
{
if
(
recv_param_grad
.
size
()
==
2U
)
{
op_dev_id
=
op_dev_id
=
GetVarDeviceID
(
recv_param_grad
[
1
],
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
recv_param_grad
[
1
],
*
sharded_var_device
);
VLOG
(
10
)
<<
"recv param "
<<
recv_param_grad
[
0
]
VLOG
(
10
)
<<
"recv param "
<<
recv_param_grad
[
0
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
<<
" place: "
<<
op_dev_id
;
...
@@ -860,8 +844,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -860,8 +844,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
int
outvar_dev_id
=
op_dev_id
;
int
outvar_dev_id
=
op_dev_id
;
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
outvar_dev_id
=
outvar_dev_id
=
GetVarDeviceID
(
output
->
Name
(),
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
output
->
Name
(),
*
sharded_var_device
);
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
,
"output name %s"
,
output
->
Name
());
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
,
"output name %s"
,
output
->
Name
());
}
}
p
=
places_
[
outvar_dev_id
];
p
=
places_
[
outvar_dev_id
];
...
@@ -878,6 +861,14 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -878,6 +861,14 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
return
op_dev_id
;
return
op_dev_id
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
const
std
::
string
&
og
)
const
{
PADDLE_ENFORCE
(
all_vars_
.
count
(
og
)
!=
0
);
if
(
all_vars_
.
at
(
og
)
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
)
{
return
true
;
}
return
false
;
}
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
return
boost
::
get
<
int
>
(
return
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
3abb2aa0
...
@@ -45,7 +45,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -45,7 +45,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
#endif
#endif
int
GetVarDeviceID
(
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
,
const
std
::
string
&
varname
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
...
@@ -57,12 +57,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -57,12 +57,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
;
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
num_places
)
const
;
size_t
num_places
)
const
;
...
@@ -76,7 +70,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -76,7 +70,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
int
dev_id
)
const
;
int
dev_id
)
const
;
int
GetOpDeviceID
(
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
@@ -99,6 +93,15 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -99,6 +93,15 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
const
platform
::
Place
&
p
)
const
;
const
platform
::
Place
&
p
)
const
;
std
::
vector
<
ir
::
Node
*>
SortForReduceMode
(
const
std
::
vector
<
ir
::
Node
*>
&
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
shared_var_device
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
;
mutable
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
...
...
paddle/fluid/framework/ir/graph.cc
浏览文件 @
3abb2aa0
...
@@ -23,67 +23,8 @@ limitations under the License. */
...
@@ -23,67 +23,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
namespace
{
void
CheckProgram
(
const
ProgramDesc
&
program
)
{
#define _INT(role) static_cast<int>(role)
std
::
map
<
int
,
bool
>
visit
;
for
(
OpDesc
*
op
:
program
.
Block
(
0
).
AllOps
())
{
// For backward compatibility, some program doesn't have role added.
if
(
!
op
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
continue
;
int
role_id
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
visit
[
role_id
]
=
true
;
switch
(
role_id
)
{
case
_INT
(
OpRole
::
kForward
):
if
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
))
!=
visit
.
end
())
{
LOG
(
ERROR
)
<<
"Cannot add backward operator before forward operator %s."
<<
op
->
Type
();
}
break
;
case
_INT
(
OpRole
::
kBackward
):
case
_INT
(
OpRole
::
kBackward
)
|
_INT
(
OpRole
::
kLoss
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kOptimize
))
==
visit
.
end
(),
"Cannot add backward operator %s after optimize operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kForward
)
|
_INT
(
OpRole
::
kLoss
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
)
|
_INT
(
OpRole
::
kLoss
))
==
visit
.
end
(),
"Cannot add backward|loss operator before "
"forward|loss operator %s."
,
op
->
Type
());
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kOptimize
))
==
visit
.
end
(),
"Cannot add forward|loss operator %s after optimize operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kOptimize
):
case
_INT
(
OpRole
::
kOptimize
)
|
_INT
(
OpRole
::
kLRSched
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
))
!=
visit
.
end
(),
"Optimize operators %s must follow backward operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kLRSched
):
case
_INT
(
OpRole
::
kDist
):
case
_INT
(
OpRole
::
kRPC
):
case
_INT
(
OpRole
::
kNotSpecified
):
break
;
default:
LOG
(
FATAL
)
<<
"Unknown operator role. Don't add new role because "
"you don't know what you are doing."
;
}
}
#undef _INT
}
}
// namespace
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
CheckProgram
(
program_
);
auto
var_nodes
=
InitFromProgram
(
program_
);
auto
var_nodes
=
InitFromProgram
(
program_
);
ResolveHazard
(
var_nodes
);
ResolveHazard
(
var_nodes
);
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
3abb2aa0
...
@@ -215,6 +215,7 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -215,6 +215,7 @@ void ParallelExecutor::BCastParamsToDevices(
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std
::
vector
<
void
*>
buffers
;
std
::
vector
<
void
*>
buffers
;
buffers
.
reserve
(
member_
->
places_
.
size
());
size_t
numel
=
main_tensor
.
numel
();
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
...
@@ -248,9 +249,7 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -248,9 +249,7 @@ void ParallelExecutor::BCastParamsToDevices(
#endif
#endif
}
else
{
}
else
{
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
if
(
i
==
0
)
continue
;
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
*
t
=
local_scope
->
Var
(
var
)
->
GetMutable
<
LoDTensor
>
();
auto
*
t
=
local_scope
->
Var
(
var
)
->
GetMutable
<
LoDTensor
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录