Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
01bbe532
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
01bbe532
编写于
6月 01, 2018
作者:
C
chengduo
提交者:
GitHub
6月 01, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11079 from chengduoZH/balance_parameter_update
Balance parameter opt
上级
59870579
e330cd03
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
30 addition
and
13 deletion
+30
-13
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+29
-12
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+1
-1
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
01bbe532
...
@@ -11,11 +11,15 @@
...
@@ -11,11 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include
"paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include
<algorithm>
#include <fstream>
#include <fstream>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...
@@ -26,9 +30,6 @@
...
@@ -26,9 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
#endif
#include <string>
#include <vector>
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
"the ssa graph path only print with GLOG_v=10,"
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"
);
"default /tmp/graph.dot"
);
...
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_type
s
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_var
s
;
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
var_types
[
var
->
Name
()]
=
var
->
GetType
()
;
all_vars
[
var
->
Name
()]
=
var
;
}
}
auto
graph
=
new
SSAGraph
();
auto
graph
=
new
SSAGraph
();
...
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
recv_vars
=
FindDistTrainRecvVars
(
program
);
auto
recv_vars
=
FindDistTrainRecvVars
(
program
);
size_t
cur_device_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
var_name_on_devices
.
resize
(
places_
.
size
());
var_name_on_devices
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
size_t
cur_device_id
=
0
;
std
::
vector
<
int64_t
>
balance_grads
(
places_
.
size
(),
0
);
auto
get_appropriate_dev
=
[
&
](
std
::
string
&
g_name
)
->
size_t
{
auto
var_desc
=
all_vars
.
at
(
g_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GE
(
numel
,
0
);
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_grads
),
std
::
end
(
balance_grads
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_grads
),
smallest
));
balance_grads
[
dev_id
]
+=
numel
;
return
dev_id
;
};
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
...
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch
(
strategy_
.
reduce_
)
{
switch
(
strategy_
.
reduce_
)
{
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
get_appropriate_dev
(
g_name
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices
[
cur_device_id
].
emplace
(
g_name
);
var_name_on_devices
[
cur_device_id
].
emplace
(
g_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
cur_device_id
=
(
cur_device_id
+
1
)
%
places_
.
size
();
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
var_type
s
,
g_name
))
{
if
(
IsSparseGradient
(
all_var
s
,
g_name
))
{
CreateReduceOp
(
&
result
,
g_name
,
0
);
CreateReduceOp
(
&
result
,
g_name
,
0
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
}
else
{
}
else
{
...
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
const
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
&
var_type
s
,
const
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
&
all_var
s
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
PADDLE_ENFORCE
(
var_type
s
.
count
(
og
)
!=
0
);
PADDLE_ENFORCE
(
all_var
s
.
count
(
og
)
!=
0
);
if
(
var_types
.
at
(
og
)
==
proto
::
VarType
::
SELECTED_ROWS
)
{
if
(
all_vars
.
at
(
og
)
->
GetType
(
)
==
proto
::
VarType
::
SELECTED_ROWS
)
{
return
true
;
return
true
;
}
}
return
false
;
return
false
;
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
01bbe532
...
@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
bool
IsSparseGradient
(
bool
IsSparseGradient
(
const
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
&
var_type
s
,
const
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
&
all_var
s
,
const
std
::
string
&
og
)
const
;
const
std
::
string
&
og
)
const
;
private:
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录