Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7722baa8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7722baa8
编写于
5月 04, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments and clean code
上级
c8911895
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
160 addition
and
107 deletion
+160
-107
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+46
-41
paddle/fluid/framework/details/gather_op_handle.cc
paddle/fluid/framework/details/gather_op_handle.cc
+14
-20
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+19
-16
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+22
-19
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+0
-11
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+10
-0
paddle/fluid/framework/details/variable_visitor.cc
paddle/fluid/framework/details/variable_visitor.cc
+46
-0
paddle/fluid/framework/details/variable_visitor.h
paddle/fluid/framework/details/variable_visitor.h
+3
-0
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
7722baa8
...
@@ -22,9 +22,9 @@ namespace details {
...
@@ -22,9 +22,9 @@ namespace details {
void
BroadcastOpHandle
::
RunImpl
()
{
void
BroadcastOpHandle
::
RunImpl
()
{
if
(
places_
.
size
()
==
1
)
return
;
if
(
places_
.
size
()
==
1
)
return
;
// the input and output may have dummy var.
VarHandle
*
in_var_handle
;
// The input and output may have dummy vars.
VarHandle
*
in_var_handle
;
{
{
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
1
,
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
1
,
...
@@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() {
...
@@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() {
Tensor
&
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
Tensor
&
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
if
(
platform
::
is_cpu_place
(
in_tensor
.
place
()))
{
// NOTE(zcd): the Place of input can be get from in_tensor and in_var_handle ,
for
(
auto
*
out
:
out_var_handles
)
{
// maybe they are different, because the Place that getting from in_tensor is
if
(
*
out
==
*
in_var_handle
)
{
// determined at runtime, the other is determined at building SSA graph stage.
// If they are different, DataTransform should be applied. Currently, it has
// not been done yet.
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
*
out_var_handle
==
*
in_var_handle
)
{
continue
;
continue
;
}
}
auto
&
out_p
=
out_var_handle
->
place_
;
auto
&
out_p
=
out
->
place_
;
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
auto
*
out_var
=
var_scopes
.
at
(
out
->
scope_idx_
)
->
FindVar
(
out
->
name_
);
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
PADDLE_ENFORCE_EQ
(
out_p
.
which
(),
in_tensor
.
place
().
which
(),
PADDLE_ENFORCE_EQ
(
"Places must be all on CPU or all on CUDA."
);
out_p
.
which
(),
in_tensor
.
place
().
which
(),
"Currently, Places of input and output must be all on CPU "
"or all on GPU."
);
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
out_p
,
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
out_p
,
in_tensor
.
type
());
in_tensor
.
type
());
}
if
(
platform
::
is_cpu_place
(
in_tensor
.
place
()))
{
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
*
out_var_handle
==
*
in_var_handle
)
{
continue
;
}
auto
&
out_p
=
out_var_handle
->
place_
;
auto
dev_ctx
=
dev_ctxes_
.
at
(
out_p
);
auto
dev_ctx
=
dev_ctxes_
.
at
(
out_p
);
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
RunAndRecordEvent
(
out_p
,
[
in_tensor
,
out_var
,
dev_ctx
,
out_p
]
{
RunAndRecordEvent
(
out_p
,
[
in_tensor
,
out_var
,
dev_ctx
,
out_p
]
{
paddle
::
framework
::
TensorCopy
(
paddle
::
framework
::
TensorCopy
(
in_tensor
,
out_p
,
*
dev_ctx
,
in_tensor
,
out_p
,
*
dev_ctx
,
...
@@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() {
...
@@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() {
}
}
}
else
{
}
else
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
in_tensor
.
place
()));
VarHandle
*
out_handle
=
nullptr
;
VarHandle
*
out_handle
;
int
root_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
in_tensor
.
place
()).
device
;
int
root
=
boost
::
get
<
platform
::
CUDAPlace
>
(
in_tensor
.
place
()).
device
;
std
::
vector
<
std
::
function
<
void
()
>>
broadcast_calls
;
std
::
vector
<
std
::
function
<
void
()
>>
broadcast_calls
;
for
(
size_t
j
=
0
;
j
<
out_var_handles
.
size
();
++
j
)
{
for
(
auto
out_var_handle
:
out_var_handles
)
{
VarHandle
*
out_var_handle
=
out_var_handles
[
j
];
Variable
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
Variable
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
->
FindVar
(
out_var_handle
->
name_
);
if
(
*
out_var_handle
!=
*
in_var_handle
)
{
int
dst_id
=
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
boost
::
get
<
platform
::
CUDAPlace
>
(
out_var_handle
->
place_
).
device
;
PADDLE_ENFORCE_EQ
(
out_var_handle
->
place_
.
which
(),
in_tensor
.
place
().
which
(),
"Places must be all on CPU or all on CUDA."
);
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
out_var_handle
->
place_
,
in_tensor
.
type
());
}
auto
out_p
=
out_var_handle
->
place_
;
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
out_p
).
device
;
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dst_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
void
*
send_recv_buffer
=
nullptr
;
void
*
send_recv_buffer
=
nullptr
;
if
(
root
==
dev
_id
)
{
if
(
root
_id
==
dst
_id
)
{
send_recv_buffer
=
const_cast
<
void
*>
(
in_tensor
.
data
<
void
>
());
send_recv_buffer
=
const_cast
<
void
*>
(
in_tensor
.
data
<
void
>
());
out_handle
=
out_var_handle
;
out_handle
=
out_var_handle
;
}
else
{
}
else
{
...
@@ -116,10 +118,12 @@ void BroadcastOpHandle::RunImpl() {
...
@@ -116,10 +118,12 @@ void BroadcastOpHandle::RunImpl() {
}
}
int
type
=
platform
::
ToNCCLDataType
(
in_tensor
.
type
());
int
type
=
platform
::
ToNCCLDataType
(
in_tensor
.
type
());
broadcast_calls
.
emplace_back
([
=
]
{
size_t
numel
=
static_cast
<
size_t
>
(
in_tensor
.
numel
());
broadcast_calls
.
emplace_back
(
[
send_recv_buffer
,
numel
,
type
,
root_id
,
&
nccl_ctx
]
{
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclBcast
(
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclBcast
(
send_recv_buffer
,
in_tensor
.
numel
(
),
send_recv_buffer
,
numel
,
static_cast
<
ncclDataType_t
>
(
type
),
static_cast
<
ncclDataType_t
>
(
type
),
root
,
comm
,
stream
));
root_id
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
()
));
});
});
}
}
...
@@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() {
...
@@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() {
call
();
call
();
}
}
}
}
// TODO(zcd): Maybe the unequal operator is not appropriate here.
if
(
*
out_handle
!=
*
in_var_handle
)
{
if
(
*
out_handle
!=
*
in_var_handle
)
{
auto
out_var
=
var_scopes
.
at
(
in_var_handle
->
scope_idx_
)
auto
out_var
=
var_scopes
.
at
(
in_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handles
[
0
]
->
name_
);
->
FindVar
(
out_var_handles
[
0
]
->
name_
);
...
@@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() {
...
@@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() {
}
}
});
});
#else
#else
PADDLE_THROW
(
"CUDA is not
support
."
);
PADDLE_THROW
(
"CUDA is not
enabled
."
);
#endif
#endif
}
}
}
}
...
...
paddle/fluid/framework/details/gather_op_handle.cc
浏览文件 @
7722baa8
...
@@ -36,7 +36,6 @@ void GatherOpHandle::RunImpl() {
...
@@ -36,7 +36,6 @@ void GatherOpHandle::RunImpl() {
VarHandle
*
out_var_handle
;
VarHandle
*
out_var_handle
;
{
{
auto
out_var_handles
=
DynamicCast
<
VarHandle
>
(
outputs_
);
auto
out_var_handles
=
DynamicCast
<
VarHandle
>
(
outputs_
);
PADDLE_ENFORCE_EQ
(
out_var_handles
.
size
(),
1
,
PADDLE_ENFORCE_EQ
(
out_var_handles
.
size
(),
1
,
"The number of output should be one."
);
"The number of output should be one."
);
out_var_handle
=
out_var_handles
.
front
();
out_var_handle
=
out_var_handles
.
front
();
...
@@ -51,43 +50,39 @@ void GatherOpHandle::RunImpl() {
...
@@ -51,43 +50,39 @@ void GatherOpHandle::RunImpl() {
auto
pre_in_var
=
auto
pre_in_var
=
var_scopes
.
at
(
in_0_handle
->
scope_idx_
)
->
FindVar
(
in_0_handle
->
name_
);
var_scopes
.
at
(
in_0_handle
->
scope_idx_
)
->
FindVar
(
in_0_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
pre_in_var
);
PADDLE_ENFORCE_NOT_NULL
(
pre_in_var
);
PADDLE_ENFORCE
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
(),
PADDLE_ENFORCE
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
(),
"Currently, gather_op only can gather SelectedRows."
);
"Currently, gather_op only can gather SelectedRows."
);
// Wait input done, this Wait is asynchronous operation
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
(
in_var_handles
);
auto
&
pre_in_value
=
pre_in_var
->
Get
<
framework
::
SelectedRows
>
();
std
::
vector
<
int64_t
>
out_rows
;
std
::
vector
<
int64_t
>
out_rows
;
std
::
vector
<
Tensor
>
in_tensors
;
std
::
vector
<
Tensor
>
in_tensors
;
auto
&
pre_in_value
=
pre_in_var
->
Get
<
framework
::
SelectedRows
>
();
// Gather the inputs
// gather the inputs
for
(
auto
*
in_handle
:
in_var_handles
)
{
for
(
auto
*
in_handle
:
in_var_handles
)
{
auto
*
in_var
=
auto
*
in_var
=
var_scopes
.
at
(
in_handle
->
scope_idx_
)
->
FindVar
(
in_handle
->
name_
);
var_scopes
.
at
(
in_handle
->
scope_idx_
)
->
FindVar
(
in_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
in_var
);
PADDLE_ENFORCE_NOT_NULL
(
in_var
);
VariableVisitor
::
EnforceShapeAndDTypeEQ
(
*
in_var
,
*
pre_in_var
);
auto
&
in_sr_value
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
in_sr_value
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
PADDLE_ENFORCE_EQ
(
in_sr_value
.
place
().
which
(),
pre_in_value
.
place
().
which
(),
"Places must be all on CPU or all on GPU."
);
PADDLE_ENFORCE_EQ
(
in_sr_value
.
value
().
type
(),
pre_in_value
.
value
().
type
(),
"The type of input is not consistent."
);
PADDLE_ENFORCE_EQ
(
in_sr_value
.
height
(),
pre_in_value
.
height
(),
"The height of inputs is not consistent."
);
PADDLE_ENFORCE_EQ
(
in_sr_value
.
GetCompleteDims
(),
pre_in_value
.
GetCompleteDims
(),
"The dims of inputs is not consistent."
);
auto
&
in_sr_rows
=
in_sr_value
.
rows
();
auto
&
in_sr_rows
=
in_sr_value
.
rows
();
out_rows
.
insert
(
out_rows
.
end
(),
in_sr_rows
.
begin
(),
in_sr_rows
.
end
());
out_rows
.
insert
(
out_rows
.
end
(),
in_sr_rows
.
begin
(),
in_sr_rows
.
end
());
in_tensors
.
emplace_back
(
in_sr_value
.
value
());
in_tensors
.
emplace_back
(
in_sr_value
.
value
());
}
}
// write the output
// TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
auto
&
out_place
=
out_var_handle
->
place_
;
auto
&
out_place
=
out_var_handle
->
place_
;
PADDLE_ENFORCE_EQ
(
out_place
.
which
(),
pre_in_value
.
place
().
which
(),
PADDLE_ENFORCE_EQ
(
out_place
.
which
(),
pre_in_value
.
place
().
which
(),
"Places must be all on CPU or all on GPU."
);
"Currently, Places of input and output must be all on CPU "
"or all on GPU."
);
auto
out_var
=
auto
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
...
@@ -97,19 +92,18 @@ void GatherOpHandle::RunImpl() {
...
@@ -97,19 +92,18 @@ void GatherOpHandle::RunImpl() {
size_t
rows
=
out_rows
.
size
();
size_t
rows
=
out_rows
.
size
();
DDim
out_dim
=
pre_in_value
.
GetCompleteDims
();
DDim
out_dim
=
pre_in_value
.
GetCompleteDims
();
out_dim
[
0
]
=
static_cast
<
int64_t
>
(
rows
);
out_dim
[
0
]
=
static_cast
<
int64_t
>
(
rows
);
out_value
->
mutable_value
()
->
Resize
(
out_dim
);
out_value
->
mutable_value
()
->
Resize
(
out_dim
).
mutable_data
(
out_value
->
mutable_value
()
->
mutable_data
(
out_place
,
out_place
,
pre_in_value
.
value
().
type
());
pre_in_value
.
value
().
type
());
Tensor
*
out_tensor
=
out_value
->
mutable_value
();
Tensor
*
out_tensor
=
out_value
->
mutable_value
();
// copy
// copy
auto
dev_ctx
=
dev_ctxes_
[
out_place
];
auto
dev_ctx
=
dev_ctxes_
[
out_place
];
RunAndRecordEvent
(
out_place
,
[
in_tensors
,
out_tensor
,
dev_ctx
,
out_place
]
{
RunAndRecordEvent
(
out_place
,
[
in_tensors
,
out_tensor
,
&
dev_ctx
,
out_place
]
{
int
s
=
0
,
e
=
0
;
int
s
=
0
,
e
=
0
;
for
(
size_t
j
=
0
;
j
<
in_tensors
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
in_tensors
.
size
();
++
j
)
{
e
+=
in_tensors
[
j
].
dims
()[
0
];
e
+=
in_tensors
[
j
].
dims
()[
0
];
auto
sub_out
=
out_tensor
->
Slice
(
s
,
e
);
auto
sub_out
=
out_tensor
->
Slice
(
s
,
e
);
paddle
::
framework
::
TensorCopy
(
in_tensors
[
j
],
out_place
,
*
(
dev_ctx
)
,
paddle
::
framework
::
TensorCopy
(
in_tensors
[
j
],
out_place
,
*
dev_ctx
,
&
sub_out
);
&
sub_out
);
s
=
e
;
s
=
e
;
}
}
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
7722baa8
...
@@ -116,13 +116,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -116,13 +116,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
places_
.
size
());
places_
.
size
());
// size_t cur_device_id = 0;
size_t
cur_update_sparse_gp_dev_id
=
0
;
size_t
update_sparse_gp_device_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
sparse_var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_sparse_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
var_name_on_devices
.
resize
(
places_
.
size
());
sparse_
var_name_on_devices
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_
sparse_
var_name_set
.
resize
(
places_
.
size
());
// Find "send" op first for split is in front of send.
// Find "send" op first for split is in front of send.
OpDesc
*
send_op
=
GetSendOpDesc
(
program
);
OpDesc
*
send_op
=
GetSendOpDesc
(
program
);
...
@@ -142,13 +141,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -142,13 +141,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
var_name_on_devices
,
*
op
);
int
op_dev_id
=
GetOpDeviceID
(
sparse_
var_name_on_devices
,
*
op
);
if
(
op_dev_id
==
-
1
)
{
// var on all device
if
(
op_dev_id
==
-
1
)
{
// var on all device
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
}
else
{
}
else
{
CreateComputationalOp
(
&
result
,
*
op
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
*
op
,
op_dev_id
);
for
(
auto
&
var_name
:
op
->
OutputArgumentNames
())
{
for
(
auto
&
var_name
:
op
->
OutputArgumentNames
())
{
var_name_on_devices
[
op_dev_id
].
emplace
(
var_name
);
sparse_
var_name_on_devices
[
op_dev_id
].
emplace
(
var_name
);
}
}
}
}
...
@@ -158,10 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -158,10 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for
(
auto
&
og
:
op
->
OutputArgumentNames
())
{
for
(
auto
&
og
:
op
->
OutputArgumentNames
())
{
if
(
IsParameterGradientOnce
(
og
,
&
og_has_been_broadcast
))
{
if
(
IsParameterGradientOnce
(
og
,
&
og_has_been_broadcast
))
{
if
(
IsSparseGradient
(
og
))
{
if
(
IsSparseGradient
(
og
))
{
CreateReduceOp
(
&
result
,
update_sparse_gp_device_id
,
og
);
CreateReduceOp
(
&
result
,
cur_update_sparse_gp_dev_id
,
og
);
var_name_on_devices
[
update_sparse_gp_device_id
].
emplace
(
og
);
sparse_var_name_on_devices
[
cur_update_sparse_gp_dev_id
].
emplace
(
bcast_var_name_set
[
update_sparse_gp_device_id
].
emplace
(
og
);
bcast_sparse_var_name_set
[
cur_update_sparse_gp_dev_id
].
emplace
(
og
.
substr
(
0
,
og
.
size
()
-
strlen
(
kGradVarSuffix
)));
og
.
substr
(
0
,
og
.
size
()
-
strlen
(
kGradVarSuffix
)));
cur_update_sparse_gp_dev_id
=
(
cur_update_sparse_gp_dev_id
+
1
)
%
places_
.
size
();
}
else
{
}
else
{
InsertNCCLAllReduceOp
(
&
result
,
og
);
InsertNCCLAllReduceOp
(
&
result
,
og
);
}
}
...
@@ -172,8 +174,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -172,8 +174,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// Insert BCast Ops
// Insert BCast Ops
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_
sparse_
var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
auto
&
to_bcast_set
=
bcast_
sparse_
var_name_set
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
&
result
,
bcast_name
,
dev_id
);
CreateBroadcastOp
(
&
result
,
bcast_name
,
dev_id
);
}
}
...
@@ -206,13 +208,14 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
...
@@ -206,13 +208,14 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
var_name_on_devices
,
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
sparse_var_name_on_devices
,
const
OpDesc
&
op
)
const
{
const
OpDesc
&
op
)
const
{
int
var_dev_id
=
-
1
;
int
var_dev_id
=
-
1
;
for
(
auto
&
var_name
:
op
.
InputArgumentNames
())
{
for
(
auto
&
var_name
:
op
.
InputArgumentNames
())
{
if
(
var_dev_id
!=
-
1
)
break
;
if
(
var_dev_id
!=
-
1
)
break
;
for
(
size_t
i
=
0
;
i
<
var_name_on_devices
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sparse_
var_name_on_devices
.
size
();
++
i
)
{
if
(
var_name_on_devices
[
i
].
count
(
var_name
))
{
if
(
sparse_
var_name_on_devices
[
i
].
count
(
var_name
))
{
var_dev_id
=
static_cast
<
int
>
(
i
);
var_dev_id
=
static_cast
<
int
>
(
i
);
break
;
break
;
}
}
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
7722baa8
...
@@ -52,27 +52,30 @@ void ReduceOpHandle::RunImpl() {
...
@@ -52,27 +52,30 @@ void ReduceOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
(
in_var_handles
);
auto
pre_place
=
in_0_handle
->
place_
;
std
::
vector
<
platform
::
Place
>
in_places
;
// used to get dev_ctx
std
::
vector
<
platform
::
Place
>
in_places
;
// used to get dev_ctx
auto
pre_in_tensor
=
VariableVisitor
::
GetMutableTensor
(
pre_in_var
);
for
(
auto
*
in_handle
:
in_var_handles
)
{
for
(
auto
*
in_handle
:
in_var_handles
)
{
in_places
.
emplace_back
(
in_handle
->
place_
);
in_places
.
emplace_back
(
in_handle
->
place_
);
auto
in_var
=
auto
in_var
=
var_scopes
.
at
(
in_handle
->
scope_idx_
)
->
FindVar
(
in_handle
->
name_
);
var_scopes
.
at
(
in_handle
->
scope_idx_
)
->
FindVar
(
in_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
in_var
);
PADDLE_ENFORCE_NOT_NULL
(
in_var
);
VariableVisitor
::
EnforceShapeAndDTypeEQ
(
*
pre_in_var
,
*
in_var
);
auto
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
PADDLE_ENFORCE_EQ
(
pre_in_tensor
.
place
().
which
(),
in_tensor
.
place
().
which
(),
"Places must be all on CPU or all on GPU."
);
PADDLE_ENFORCE_EQ
(
in_tensor
.
type
(),
pre_in_tensor
.
type
(),
"The type of input is not consistent."
);
}
}
auto
out_var
=
auto
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
// TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
PADDLE_ENFORCE_EQ
(
VariableVisitor
::
GetMutableTensor
(
pre_in_var
).
place
().
which
(),
out_var_handle
->
place_
.
which
(),
"Currently, Places of input and output must be all on CPU or all on "
"GPU."
);
if
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
if
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
vector
<
const
SelectedRows
*>
in_selected_rows
=
std
::
vector
<
const
SelectedRows
*>
in_selected_rows
=
GetInputValues
<
SelectedRows
>
(
in_var_handles
,
var_scopes
);
GetInputValues
<
SelectedRows
>
(
in_var_handles
,
var_scopes
);
...
@@ -96,7 +99,7 @@ void ReduceOpHandle::RunImpl() {
...
@@ -96,7 +99,7 @@ void ReduceOpHandle::RunImpl() {
out_var_handle
->
place_
,
pre_in
.
type
());
out_var_handle
->
place_
,
pre_in
.
type
());
auto
out_p
=
out_var_handle
->
place_
;
auto
out_p
=
out_var_handle
->
place_
;
int
root
=
boost
::
get
<
platform
::
CUDAPlace
>
(
out_p
).
device
;
int
root
_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
out_p
).
device
;
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
for
(
size_t
i
=
0
;
i
<
var_scopes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
var_scopes
.
size
();
++
i
)
{
auto
&
p
=
in_places
[
i
];
auto
&
p
=
in_places
[
i
];
...
@@ -104,22 +107,22 @@ void ReduceOpHandle::RunImpl() {
...
@@ -104,22 +107,22 @@ void ReduceOpHandle::RunImpl() {
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
void
*
buffer
=
const_cast
<
void
*>
(
lod_tensor
.
data
<
void
>
());
void
*
buffer
=
const_cast
<
void
*>
(
lod_tensor
.
data
<
void
>
());
void
*
recvbuffer
=
nullptr
;
void
*
recvbuffer
=
nullptr
;
if
(
root
==
dev_id
)
{
if
(
root
_id
==
dev_id
)
{
recvbuffer
=
recvbuffer
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
mutable_data
(
out_var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
mutable_data
(
out_var_handle
->
place_
);
out_var_handle
->
place_
);
}
}
int
type
=
platform
::
ToNCCLDataType
(
lod_tensor
.
type
());
int
type
=
platform
::
ToNCCLDataType
(
lod_tensor
.
type
());
all_reduce_calls
.
emplace_back
([
=
]
{
size_t
numel
=
static_cast
<
size_t
>
(
lod_tensor
.
numel
());
all_reduce_calls
.
emplace_back
(
[
buffer
,
recvbuffer
,
type
,
numel
,
root_id
,
&
nccl_ctx
]
{
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclReduce
(
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclReduce
(
buffer
,
recvbuffer
,
static_cast
<
size_t
>
(
lod_tensor
.
numel
()
),
buffer
,
recvbuffer
,
numel
,
static_cast
<
ncclDataType_t
>
(
type
),
static_cast
<
ncclDataType_t
>
(
type
),
ncclSum
,
root
,
comm
,
stream
));
ncclSum
,
root_id
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
()
));
});
});
}
}
...
@@ -130,7 +133,7 @@ void ReduceOpHandle::RunImpl() {
...
@@ -130,7 +133,7 @@ void ReduceOpHandle::RunImpl() {
}
}
});
});
#else
#else
PADDLE_THROW
(
"CUDA is not
support
."
);
PADDLE_THROW
(
"CUDA is not
enabled
."
);
#endif
#endif
}
else
{
}
else
{
PADDLE_THROW
(
"Place should be CPUPlace or CUDAPlace."
);
PADDLE_THROW
(
"Place should be CPUPlace or CUDAPlace."
);
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
7722baa8
...
@@ -47,17 +47,6 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
...
@@ -47,17 +47,6 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
}
}
}
}
VarHandle
*
SSAGraphBuilder
::
GetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
vars_
[
place_offset
];
auto
&
var_holder
=
var_holders
[
each_var_name
];
if
(
var_holder
.
empty
())
{
return
nullptr
;
}
return
var_holder
.
rbegin
()
->
get
();
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
7722baa8
...
@@ -62,6 +62,16 @@ struct VarHandle : public VarHandleBase {
...
@@ -62,6 +62,16 @@ struct VarHandle : public VarHandleBase {
std
::
string
name_
;
std
::
string
name_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
// NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
// member
// variables(version_, scope_id_, name_, place_) must be equal. But sometimes
// judging whether the two var_handle is equal is actually to determine
// whether
// the two Variables that represented by var_handle is the same. And the same
// Variable may have many different var_handles, the version_ of these
// var_handles
// is different. So I don't take care of version_ temporarily when overloading
// equal.
bool
operator
==
(
const
VarHandle
&
o
)
const
{
bool
operator
==
(
const
VarHandle
&
o
)
const
{
return
o
.
generated_op_
==
generated_op_
&&
o
.
name_
==
name_
&&
return
o
.
generated_op_
==
generated_op_
&&
o
.
name_
==
name_
&&
o
.
scope_idx_
==
scope_idx_
;
o
.
scope_idx_
==
scope_idx_
;
...
...
paddle/fluid/framework/details/variable_visitor.cc
浏览文件 @
7722baa8
...
@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
...
@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
VisitVariable
(
src
,
&
visitor
);
VisitVariable
(
src
,
&
visitor
);
}
}
struct
EnforceEqualShapeAndDTypeVisitor
{
const
Variable
*
trg_
;
void
operator
()(
const
LoDTensor
&
src
)
{
auto
&
tensor
=
trg_
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
src
.
place
().
which
(),
tensor
.
place
().
which
(),
"The Places of the two Variable must be all on CPU or all on GPU."
);
PADDLE_ENFORCE_EQ
(
src
.
type
(),
tensor
.
type
(),
"The dtype of the two Variable is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
dims
(),
tensor
.
dims
(),
"The dims of the two Variable is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
lod
(),
tensor
.
lod
(),
"The lod of the two Variable is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
layout
(),
tensor
.
layout
(),
"The layout of the two Variable's tensor is not equal."
);
}
void
operator
()(
const
SelectedRows
&
src
)
{
auto
&
selected_rows
=
trg_
->
Get
<
SelectedRows
>
();
PADDLE_ENFORCE_EQ
(
src
.
place
().
which
(),
selected_rows
.
place
().
which
(),
"The Places of the two Variable must be all on CPU or all on GPU."
);
PADDLE_ENFORCE_EQ
(
src
.
value
().
type
(),
selected_rows
.
value
().
type
(),
"The dtype of the two Variable is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
value
().
layout
(),
selected_rows
.
value
().
layout
(),
"The layout of the two Variable's tensor is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
height
(),
selected_rows
.
height
(),
"The height of the two Variable is not equal."
);
PADDLE_ENFORCE_EQ
(
src
.
GetCompleteDims
(),
selected_rows
.
GetCompleteDims
(),
"The dims of the two Variable is not equal."
);
}
template
<
typename
T
>
void
operator
()(
const
T
&
)
{
PADDLE_ENFORCE
(
"EnforceShapeAndDTypeEQ is not supported by type %s"
,
typeid
(
T
).
name
());
}
};
void
VariableVisitor
::
EnforceShapeAndDTypeEQ
(
const
Variable
&
var1
,
const
Variable
&
var2
)
{
EnforceEqualShapeAndDTypeVisitor
visitor
{
&
var1
};
VisitVariable
(
var2
,
&
visitor
);
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/variable_visitor.h
浏览文件 @
7722baa8
...
@@ -26,6 +26,9 @@ class VariableVisitor {
...
@@ -26,6 +26,9 @@ class VariableVisitor {
static
Tensor
&
GetMutableTensor
(
Variable
*
var
);
static
Tensor
&
GetMutableTensor
(
Variable
*
var
);
static
void
ShareDimsAndLoD
(
const
Variable
&
src
,
Variable
*
trg
);
static
void
ShareDimsAndLoD
(
const
Variable
&
src
,
Variable
*
trg
);
static
void
EnforceShapeAndDTypeEQ
(
const
Variable
&
var1
,
const
Variable
&
var2
);
};
};
}
// namespace details
}
// namespace details
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录