Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
56e758fc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
56e758fc
编写于
1月 09, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
trainer ok
上级
f35c5606
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
13 addition
and
7 deletion
+13
-7
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+13
-7
未找到文件。
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
56e758fc
...
...
@@ -56,6 +56,8 @@ def split_dense_variable(var_list,
(
block_id
)
*
block_size
))
block
=
VarBlock
(
var
.
name
,
block_id
,
curr_block_size
)
blocks
.
append
(
str
(
block
))
print
(
"$$ splited var: "
,
var
.
name
,
var
.
shape
,
split_count
,
len
(
blocks
),
block_size
)
return
blocks
...
...
@@ -132,10 +134,12 @@ class DistributeTranspiler:
# step4
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
concat
=
program
.
global_block
().
append_op
(
type
=
"concat"
,
inputs
=
{
"X"
:
s
end_outputs
},
inputs
=
{
"X"
:
s
plited_var
},
outputs
=
{
"Out"
:
orig_param
},
attrs
=
{
"axis"
:
0
})
...
...
@@ -147,28 +151,29 @@ class DistributeTranspiler:
if
not
block_map
.
has_key
(
varname
):
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
long
(
offset
),
long
(
size
)))
for
varname
,
splited
in
block_map
.
iteritems
():
orig_var
=
program
.
global_block
().
vars
[
varname
]
var_mapping
[
varname
]
=
[]
if
len
(
splited
)
==
1
:
var_mapping
[
varname
]
=
[
orig_var
]
continue
orig_shape
=
orig_var
.
shape
orig_dim1_flatten
=
1
if
len
(
orig_shape
)
>=
2
:
orig_dim1_flatten
=
reduce
(
lambda
x
,
y
:
x
*
y
,
orig_shape
[
1
:])
var_list
=
[]
for
i
,
block
in
enumerate
(
splited
):
size
=
block
[
1
]
rows
=
size
/
orig_dim1_flatten
splited_shape
=
[
rows
]
if
len
(
orig_shape
)
>=
2
:
splited_shape
.
extend
(
orig_shape
[
1
:])
print
(
"block, splited shape:"
,
block
,
splited_shape
)
var
=
program
.
global_block
().
create_var
(
name
=
"%s.block%d"
%
(
varname
,
i
),
psersistable
=
False
,
dtype
=
orig_var
.
dtype
,
shape
=
splited_shape
)
# flattend splited var
var_list
.
append
(
var
)
var_mapping
[
varname
]
=
var_list
var_mapping
[
varname
].
append
(
var
)
return
var_mapping
def
_clone_param
(
self
,
block
,
v
):
...
...
@@ -199,7 +204,8 @@ class DistributeTranspiler:
def
_append_split_op
(
self
,
program
,
gradblocks
):
var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
gradblocks
)
for
varname
,
splited_vars
in
var_mapping
.
iteritems
():
if
len
(
splited_vars
)
==
1
:
# variable that don't need to split have empty splited_vars
if
len
(
splited_vars
)
<=
1
:
continue
orig_var
=
program
.
global_block
().
vars
[
varname
]
sections
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录