Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
79989c90
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79989c90
编写于
3月 21, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add SSA builder
上级
64d7a302
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
199 addition
and
174 deletion
+199
-174
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+199
-170
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-4
未找到文件。
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
79989c90
...
...
@@ -43,14 +43,20 @@ struct SSAGraph {
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
ops_
;
};
/**
class
SSAGraphBuilder
{
public:
virtual
~
SSAGraphBuilder
()
{}
virtual
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
=
0
;
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static
void
PolishGraphToSupportDataHazards
(
SSAGraph
*
graph
)
{
static
void
PolishGraphToSupportDataHazards
(
SSAGraph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
...
...
@@ -83,9 +89,9 @@ static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
}
}
}
}
}
static
VarHandle
*
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
static
VarHandle
*
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
...
...
@@ -103,11 +109,12 @@ static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
var
=
&
var_holder
.
rbegin
()
->
second
;
}
return
var
;
}
}
static
void
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
static
void
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
vars_
[
place_offset
][
each_var_name
];
size_t
version
=
vars
.
size
();
auto
&
var
=
vars
[
version
];
...
...
@@ -115,7 +122,132 @@ static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
var
.
name_
=
each_var_name
;
var
.
place_
=
place
;
op_handle
->
AddOutput
(
&
var
);
}
}
};
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
)
{
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
override
{
SSAGraph
&
result
=
*
graph
;
result
.
vars_
.
resize
(
places_
.
size
());
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
bool
change_forward
=
false
;
if
(
!
is_forwarding
)
{
// FIXME(yy): Do not hard code like this
if
(
op
->
OutputArgumentNames
().
size
()
==
1
&&
op
->
OutputArgumentNames
()[
0
]
==
GradVarName
(
loss_var_name_
))
{
continue
;
// Drop fill 1. for backward coeff;
}
}
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
*
s
=
local_scopes_
[
i
];
result
.
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
s
,
p
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
p
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
auto
var_names
=
op
->
InputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
&
result
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
&
result
,
op_handle
,
each_var_name
,
p
,
i
);
}
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name_
)
{
// Insert ScaleCost OpHandle
op_handle
=
new
ScaleLossGradOpHandle
(
local_scopes_
.
size
(),
s
,
p
,
nccl_ctxs_
->
DevCtx
(
p
));
result
.
ops_
.
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput
(
&
result
,
op_handle
,
GradVarName
(
loss_var_name_
),
p
,
i
);
change_forward
=
true
;
}
}
}
if
(
change_forward
)
{
is_forwarding
=
false
;
}
if
(
!
is_forwarding
)
{
auto
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
og
:
var_names
)
{
if
(
grad_names_
.
count
(
og
)
!=
0
)
{
// is param grad
// Insert NCCL AllReduce Op
result
.
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
,
*
nccl_ctxs_
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
vars
=
result
.
vars_
[
i
][
og
];
if
(
vars
.
empty
())
{
// This device has no data. continue.
continue
;
}
auto
*
prev_grad
=
&
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
);
auto
&
var
=
vars
[
vars
.
size
()];
var
.
place_
=
p
;
var
.
name_
=
og
;
var
.
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
&
var
);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards
(
&
result
);
}
private:
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
platform
::
NCCLContextMap
*
nccl_ctxs_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
};
class
ParallelExecutorPrivate
{
public:
...
...
@@ -123,9 +255,7 @@ class ParallelExecutorPrivate {
const
std
::
vector
<
platform
::
Place
>
&
places
)
:
places_
(
places
),
fetch_dev_ctxs_
(
places
),
pool_
(
num_threads
<=
1
?
nullptr
:
new
ThreadPool
(
num_threads
))
{
graph_
.
vars_
.
resize
(
places
.
size
());
}
pool_
(
num_threads
<=
1
?
nullptr
:
new
ThreadPool
(
num_threads
))
{}
std
::
vector
<
platform
::
Place
>
places_
;
platform
::
DeviceContextPool
fetch_dev_ctxs_
;
...
...
@@ -199,7 +329,10 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
ConstructDependencyGraph
(
params
,
main_program
,
loss_var_name
);
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
());
builder
.
Build
(
main_program
,
&
member_
->
graph_
);
// Step 3. Create vars in each scope;
for
(
auto
*
scope
:
member_
->
local_scopes_
)
{
...
...
@@ -213,110 +346,6 @@ ParallelExecutor::ParallelExecutor(
}
}
void
ParallelExecutor
::
ConstructDependencyGraph
(
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
)
const
{
std
::
unordered_set
<
std
::
string
>
grads
;
for
(
auto
&
each_param
:
params
)
{
grads
.
insert
(
each_param
+
"@GRAD"
);
}
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
main_program
.
Block
(
0
).
AllOps
())
{
bool
change_forward
=
false
;
if
(
!
is_forwarding
)
{
// FIXME(yy): Do not hard code like this
if
(
op
->
OutputArgumentNames
().
size
()
==
1
&&
op
->
OutputArgumentNames
()[
0
]
==
loss_var_name
+
"@GRAD"
)
{
continue
;
// Drop fill 1. for backward coeff;
}
}
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
&
p
=
member_
->
places_
[
i
];
auto
*
s
=
member_
->
local_scopes_
[
i
];
member_
->
graph_
.
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
s
,
p
));
auto
*
op_handle
=
member_
->
graph_
.
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
p
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
auto
var_names
=
op
->
InputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
&
member_
->
graph_
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
&
member_
->
graph_
,
op_handle
,
each_var_name
,
p
,
i
);
}
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name
)
{
// Insert ScaleCost OpHandle
op_handle
=
new
ScaleLossGradOpHandle
(
this
->
member_
->
local_scopes_
.
size
(),
s
,
p
,
member_
->
nccl_ctxs_
->
DevCtx
(
p
));
member_
->
graph_
.
ops_
.
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput
(
&
member_
->
graph_
,
op_handle
,
loss_var_name
+
"@GRAD"
,
p
,
i
);
change_forward
=
true
;
}
}
}
if
(
change_forward
)
{
is_forwarding
=
false
;
}
if
(
!
is_forwarding
)
{
auto
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
og
:
var_names
)
{
if
(
grads
.
count
(
og
)
!=
0
)
{
// is param grad
// Insert NCCL AllReduce Op
member_
->
graph_
.
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
member_
->
local_scopes_
,
member_
->
places_
,
*
member_
->
nccl_ctxs_
));
auto
*
op_handle
=
member_
->
graph_
.
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
&
p
=
member_
->
places_
[
i
];
auto
&
vars
=
member_
->
graph_
.
vars_
[
i
][
og
];
if
(
vars
.
empty
())
{
// This device has no data. continue.
continue
;
}
auto
*
prev_grad
=
&
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
);
auto
&
var
=
vars
[
vars
.
size
()];
var
.
place_
=
p
;
var
.
name_
=
og
;
var
.
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
&
var
);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards
(
&
member_
->
graph_
);
}
void
ParallelExecutor
::
BCastParamsToGPUs
(
const
ProgramDesc
&
startup_program
)
const
{
#ifdef PADDLE_WITH_CUDA
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
79989c90
...
...
@@ -47,10 +47,6 @@ class ParallelExecutor {
void
BCastParamsToGPUs
(
const
ProgramDesc
&
startup_program
)
const
;
void
ConstructDependencyGraph
(
const
std
::
unordered_set
<
std
::
string
>&
params
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
)
const
;
void
BuildNCCLCommunicator
()
const
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录