Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
baea2cf1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
baea2cf1
编写于
4月 08, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
wip
上级
01c6618d
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
87 addition
and
72 deletion
+87
-72
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+46
-13
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+12
-2
paddle/fluid/framework/details/send_op_handle.cc
paddle/fluid/framework/details/send_op_handle.cc
+13
-50
paddle/fluid/framework/details/send_op_handle.h
paddle/fluid/framework/details/send_op_handle.h
+8
-7
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+7
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
baea2cf1
...
...
@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
nv_library
(
nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
baea2cf1
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
...
...
@@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
platform
::
NCCLContextMap
*
nccl_ctxs
,
bool
distributed
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
distributed_
(
distributed
),
nccl_ctxs_
(
nccl_ctxs
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
)
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
distributed
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
)
{
local_scopes_
(
local_scopes
),
distributed_
(
distributed
)
{
#endif
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
SSAGraph
*
result
,
OpDesc
*
op
,
const
platform
::
Place
&
p
,
const
size_t
&
i
)
const
{
auto
*
op_handle
=
result
->
ops_
.
back
().
get
();
auto
var_names
=
op
->
InputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
result
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
result
,
op_handle
,
each_var_name
,
p
,
i
);
}
}
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
auto
graph
=
new
SSAGraph
();
...
...
@@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// append send op if program is distributed trainer main program.
// always use the first device
if
(
is_forwarding
&&
distributed_
&&
op
->
Type
()
==
"send"
)
{
auto
&
p
=
places_
[
0
];
auto
*
s
=
local_scopes_
[
0
];
size_t
i
=
0
;
result
.
ops_
.
emplace_back
(
new
SendOpHandle
(
*
op
,
s
,
p
));
CreateOpHandleIOs
(
&
result
,
op
,
p
,
i
);
continue
;
}
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
*
s
=
local_scopes_
[
i
];
...
...
@@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
op_handle
->
dev_ctxes_
[
p
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
auto
var_names
=
op
->
InputArgumentNames
();
CreateOpHandleIOs
(
&
result
,
op
,
p
,
i
);
// auto var_names = op->InputArgumentNames();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
&
result
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
//
for (auto &each_var_name : var_names) {
//
VarHandle *var =
//
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
//
op_handle->AddInput(var);
//
}
auto
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
&
result
,
op_handle
,
each_var_name
,
p
,
i
);
}
//
for (auto &each_var_name : var_names) {
//
CreateOpOutput(&result, op_handle, each_var_name, p, i);
//
}
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name_
)
{
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
baea2cf1
...
...
@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace
paddle
{
...
...
@@ -31,21 +34,28 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
);
platform
::
NCCLContextMap
*
nccl_ctxs
,
bool
distributed
=
false
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
);
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
distributed
=
false
);
#endif
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
private:
void
CreateOpHandleIOs
(
SSAGraph
*
result
,
OpDesc
*
op
,
const
platform
::
Place
&
p
,
const
size_t
&
i
)
const
;
private:
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
bool
distributed_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
...
...
paddle/fluid/framework/details/send_op_handle.cc
浏览文件 @
baea2cf1
...
...
@@ -18,61 +18,24 @@ namespace paddle {
namespace
framework
{
namespace
details
{
SendOpHandle
::
SendOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
&
ctxs
)
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{}
SendOpHandle
::
SendOpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
place
)
:
op_
(
framework
::
OpRegistry
::
CreateOp
(
op_desc
)),
local_scope_
(
local_scope
),
place_
(
place
)
{}
void
SendOpHandle
::
RunImpl
()
{
if
(
inputs_
.
size
()
==
1
)
{
return
;
// No need to all reduce when GPU count = 1;
}
else
{
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
}
auto
&
var_name
=
static_cast
<
VarHandle
*>
(
this
->
inputs_
[
0
])
->
name_
;
int
dtype
=
-
1
;
size_t
numel
=
0
;
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
for
(
size_t
i
=
0
;
i
<
local_scopes_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
*
s
=
local_scopes_
[
i
];
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
auto
&
lod_tensor
=
s
->
FindVar
(
var_name
)
->
Get
<
LoDTensor
>
();
void
*
buffer
=
const_cast
<
void
*>
(
lod_tensor
.
data
<
void
>
());
if
(
dtype
==
-
1
)
{
dtype
=
platform
::
ToNCCLDataType
(
lod_tensor
.
type
());
}
if
(
numel
==
0
)
{
numel
=
static_cast
<
size_t
>
(
lod_tensor
.
numel
());
}
auto
&
nccl_ctx
=
nccl_ctxs_
.
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
all_reduce_calls
.
emplace_back
([
=
]
{
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
buffer
,
buffer
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
ncclSum
,
comm
,
stream
));
});
}
platform
::
NCCLGroupGuard
guard
;
for
(
auto
&
call
:
all_reduce_calls
)
{
call
();
}
}
op_
->
Run
(
*
local_scope_
,
place_
);
}
std
::
string
NCCLAllReduceOpHandle
::
Name
()
const
{
return
"nccl_all_reduce
"
;
}
std
::
string
SendOpHandle
::
Name
()
const
{
return
"send
"
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/send_op_handle.h
浏览文件 @
baea2cf1
...
...
@@ -19,6 +19,8 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
...
...
@@ -27,19 +29,18 @@ namespace framework {
namespace
details
{
struct
SendOpHandle
:
public
OpHandleBase
{
const
std
::
vector
<
Scope
*>
&
local_scopes
_
;
const
std
::
vector
<
platform
::
Place
>
&
places
_
;
const
platform
::
NCCLContextMap
&
nccl_ctxs
_
;
std
::
unique_ptr
<
OperatorBase
>
op
_
;
const
Scope
*
local_scope
_
;
const
platform
::
Place
&
place
_
;
SendOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
&
ctxs
);
SendOpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
place
);
std
::
string
Name
()
const
override
;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool
IsMultiDeviceTransfer
()
override
{
return
tru
e
;
};
bool
IsMultiDeviceTransfer
()
override
{
return
fals
e
;
};
protected:
void
RunImpl
()
override
;
...
...
python/paddle/fluid/framework.py
浏览文件 @
baea2cf1
...
...
@@ -951,6 +951,13 @@ class Block(object):
if
var
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
persistable
=
var
.
persistable
,
type
=
var
.
type
)
elif
var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
shape
=
var
.
shape
,
dtype
=
var
.
dtype
,
type
=
var
.
type
,
persistable
=
True
)
else
:
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录