Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5d4238cd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5d4238cd
编写于
8月 10, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix six.iteritems problem
上级
e4e9450e
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
37 addition
and
40 deletion
+37
-40
python/paddle/dataset/imdb.py
python/paddle/dataset/imdb.py
+1
-1
python/paddle/dataset/imikolov.py
python/paddle/dataset/imikolov.py
+1
-1
python/paddle/dataset/sentiment.py
python/paddle/dataset/sentiment.py
+1
-1
python/paddle/dataset/wmt14.py
python/paddle/dataset/wmt14.py
+2
-2
python/paddle/dataset/wmt16.py
python/paddle/dataset/wmt16.py
+1
-2
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+5
-5
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-1
python/paddle/fluid/graphviz.py
python/paddle/fluid/graphviz.py
+4
-5
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+1
-1
python/paddle/fluid/metrics.py
python/paddle/fluid/metrics.py
+3
-3
python/paddle/fluid/tests/unittests/benchmark.py
python/paddle/fluid/tests/unittests/benchmark.py
+2
-2
python/paddle/fluid/tests/unittests/test_detection_map_op.py
python/paddle/fluid/tests/unittests/test_detection_map_op.py
+1
-1
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
+1
-1
python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
...e/fluid/tests/unittests/test_positive_negative_pair_op.py
+1
-1
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+2
-2
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+10
-11
未找到文件。
python/paddle/dataset/imdb.py
浏览文件 @
5d4238cd
...
@@ -64,7 +64,7 @@ def build_dict(pattern, cutoff):
...
@@ -64,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq
[
word
]
+=
1
word_freq
[
word
]
+=
1
# Not sure if we should prune less-frequent words here.
# Not sure if we should prune less-frequent words here.
word_freq
=
[
x
for
x
in
six
.
moves
.
iteritems
(
word_freq
)
if
x
[
1
]
>
cutoff
]
word_freq
=
[
x
for
x
in
six
.
iteritems
(
word_freq
)
if
x
[
1
]
>
cutoff
]
dictionary
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
dictionary
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
dictionary
))
words
,
_
=
list
(
zip
(
*
dictionary
))
...
...
python/paddle/dataset/imikolov.py
浏览文件 @
5d4238cd
...
@@ -66,7 +66,7 @@ def build_dict(min_word_freq=50):
...
@@ -66,7 +66,7 @@ def build_dict(min_word_freq=50):
del
word_freq
[
'<unk>'
]
del
word_freq
[
'<unk>'
]
word_freq
=
[
word_freq
=
[
x
for
x
in
six
.
moves
.
iteritems
(
word_freq
)
if
x
[
1
]
>
min_word_freq
x
for
x
in
six
.
iteritems
(
word_freq
)
if
x
[
1
]
>
min_word_freq
]
]
word_freq_sorted
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
word_freq_sorted
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
...
...
python/paddle/dataset/sentiment.py
浏览文件 @
5d4238cd
...
@@ -65,7 +65,7 @@ def get_word_dict():
...
@@ -65,7 +65,7 @@ def get_word_dict():
for
field
in
movie_reviews
.
fileids
(
category
):
for
field
in
movie_reviews
.
fileids
(
category
):
for
words
in
movie_reviews
.
words
(
field
):
for
words
in
movie_reviews
.
words
(
field
):
word_freq_dict
[
words
]
+=
1
word_freq_dict
[
words
]
+=
1
words_sort_list
=
six
.
moves
.
iteritems
(
word_freq_dict
)
words_sort_list
=
six
.
iteritems
(
word_freq_dict
)
words_sort_list
.
sort
(
cmp
=
lambda
a
,
b
:
b
[
1
]
-
a
[
1
])
words_sort_list
.
sort
(
cmp
=
lambda
a
,
b
:
b
[
1
]
-
a
[
1
])
for
index
,
word
in
enumerate
(
words_sort_list
):
for
index
,
word
in
enumerate
(
words_sort_list
):
words_freq_sorted
.
append
((
word
[
0
],
index
))
words_freq_sorted
.
append
((
word
[
0
],
index
))
...
...
python/paddle/dataset/wmt14.py
浏览文件 @
5d4238cd
...
@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
...
@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file
=
paddle
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
tar_file
=
paddle
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
src_dict
,
trg_dict
=
__read_to_dict
(
tar_file
,
dict_size
)
src_dict
,
trg_dict
=
__read_to_dict
(
tar_file
,
dict_size
)
if
reverse
:
if
reverse
:
src_dict
=
{
v
:
k
for
k
,
v
in
six
.
moves
.
iteritems
(
src_dict
)}
src_dict
=
{
v
:
k
for
k
,
v
in
six
.
iteritems
(
src_dict
)}
trg_dict
=
{
v
:
k
for
k
,
v
in
six
.
moves
.
iteritems
(
trg_dict
)}
trg_dict
=
{
v
:
k
for
k
,
v
in
six
.
iteritems
(
trg_dict
)}
return
src_dict
,
trg_dict
return
src_dict
,
trg_dict
...
...
python/paddle/dataset/wmt16.py
浏览文件 @
5d4238cd
...
@@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
...
@@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout
.
write
(
"%s
\n
%s
\n
%s
\n
"
%
(
START_MARK
,
END_MARK
,
UNK_MARK
))
fout
.
write
(
"%s
\n
%s
\n
%s
\n
"
%
(
START_MARK
,
END_MARK
,
UNK_MARK
))
for
idx
,
word
in
enumerate
(
for
idx
,
word
in
enumerate
(
sorted
(
sorted
(
six
.
moves
.
iteritems
(
word_dict
),
six
.
iteritems
(
word_dict
),
key
=
lambda
x
:
x
[
1
],
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)):
reverse
=
True
)):
if
idx
+
3
==
dict_size
:
break
if
idx
+
3
==
dict_size
:
break
fout
.
write
(
"%s
\n
"
%
(
word
[
0
]))
fout
.
write
(
"%s
\n
"
%
(
word
[
0
]))
...
...
python/paddle/fluid/backward.py
浏览文件 @
5d4238cd
...
@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
...
@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
"""
op_desc
=
core
.
OpDesc
()
op_desc
=
core
.
OpDesc
()
op_desc
.
set_type
(
op_type
)
op_desc
.
set_type
(
op_type
)
for
para
,
args
in
six
.
moves
.
iteritems
(
inputs
):
for
para
,
args
in
six
.
iteritems
(
inputs
):
op_desc
.
set_input
(
op_desc
.
set_input
(
para
,
para
,
list
(
list
(
map
(
lambda
arg
:
arg
.
decode
()
if
isinstance
(
arg
,
six
.
binary_type
)
else
arg
,
map
(
lambda
arg
:
arg
.
decode
()
if
isinstance
(
arg
,
six
.
binary_type
)
else
arg
,
args
)))
args
)))
for
para
,
args
in
six
.
moves
.
iteritems
(
outputs
):
for
para
,
args
in
six
.
iteritems
(
outputs
):
op_desc
.
set_output
(
op_desc
.
set_output
(
para
,
para
,
list
(
list
(
...
@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
...
@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if
op_role_attr_name
not
in
attrs
:
if
op_role_attr_name
not
in
attrs
:
attrs
[
attrs
[
op_role_attr_name
]
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
op_role_attr_name
]
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
for
name
,
val
in
six
.
moves
.
iteritems
(
attrs
):
for
name
,
val
in
six
.
iteritems
(
attrs
):
if
isinstance
(
val
,
framework
.
Block
):
if
isinstance
(
val
,
framework
.
Block
):
op_desc
.
set_block_attr
(
name
,
val
.
desc
)
op_desc
.
set_block_attr
(
name
,
val
.
desc
)
else
:
else
:
...
@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
...
@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc
.
set_output
(
param_name
,
arg_names
)
op_desc
.
set_output
(
param_name
,
arg_names
)
renamed_vars
[
var_name
].
append
(
new_name
)
renamed_vars
[
var_name
].
append
(
new_name
)
for
var_name
,
inputs
in
six
.
moves
.
iteritems
(
renamed_vars
):
for
var_name
,
inputs
in
six
.
iteritems
(
renamed_vars
):
if
len
(
inputs
)
>
1
:
if
len
(
inputs
)
>
1
:
pending_sum_ops
.
append
(
pending_sum_ops
.
append
(
(
_create_op_desc_
(
"sum"
,
{
"X"
:
inputs
},
{
"Out"
:
[
var_name
]},
(
_create_op_desc_
(
"sum"
,
{
"X"
:
inputs
},
{
"Out"
:
[
var_name
]},
...
@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
...
@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc
.
rename_output
(
name
,
new_name
)
op_desc
.
rename_output
(
name
,
new_name
)
var_map
[
name
]
=
new_name
var_map
[
name
]
=
new_name
for
g
,
ng
in
six
.
moves
.
iteritems
(
var_map
):
for
g
,
ng
in
six
.
iteritems
(
var_map
):
if
g
in
grad_to_var
:
if
g
in
grad_to_var
:
grad_to_var
[
ng
]
=
grad_to_var
[
g
]
grad_to_var
[
ng
]
=
grad_to_var
[
g
]
grad_to_var
.
pop
(
g
)
grad_to_var
.
pop
(
g
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
5d4238cd
...
@@ -958,7 +958,7 @@ class Block(object):
...
@@ -958,7 +958,7 @@ class Block(object):
return
list
(
self
.
iter_parameters
())
return
list
(
self
.
iter_parameters
())
def
iter_parameters
(
self
):
def
iter_parameters
(
self
):
return
(
item
[
1
]
for
item
in
six
.
moves
.
iteritems
(
self
.
vars
)
return
(
item
[
1
]
for
item
in
six
.
iteritems
(
self
.
vars
)
if
isinstance
(
item
[
1
],
Parameter
))
if
isinstance
(
item
[
1
],
Parameter
))
def
create_var
(
self
,
*
args
,
**
kwargs
):
def
create_var
(
self
,
*
args
,
**
kwargs
):
...
...
python/paddle/fluid/graphviz.py
浏览文件 @
5d4238cd
...
@@ -106,7 +106,7 @@ class Graph(object):
...
@@ -106,7 +106,7 @@ class Graph(object):
def
_rank_repr
(
self
):
def
_rank_repr
(
self
):
ranks
=
sorted
(
ranks
=
sorted
(
six
.
moves
.
iteritems
(
self
.
rank_groups
),
six
.
iteritems
(
self
.
rank_groups
),
key
=
functools
.
cmp_to_key
(
key
=
functools
.
cmp_to_key
(
lambda
a
,
b
:
a
[
1
].
priority
>
b
[
1
].
priority
))
lambda
a
,
b
:
a
[
1
].
priority
>
b
[
1
].
priority
))
repr
=
[]
repr
=
[]
...
@@ -150,9 +150,8 @@ class Node(object):
...
@@ -150,9 +150,8 @@ class Node(object):
reprs
=
'{name} [label={label} {extra} ];'
.
format
(
reprs
=
'{name} [label={label} {extra} ];'
.
format
(
name
=
self
.
name
,
name
=
self
.
name
,
label
=
self
.
label
,
label
=
self
.
label
,
extra
=
','
+
','
.
join
(
extra
=
','
+
','
.
join
(
"%s=%s"
%
(
key
,
crepr
(
value
))
"%s=%s"
%
(
key
,
crepr
(
value
))
for
key
,
value
in
six
.
iteritems
(
self
.
attrs
))
for
key
,
value
in
six
.
moves
.
iteritems
(
self
.
attrs
))
if
self
.
attrs
else
""
)
if
self
.
attrs
else
""
)
return
reprs
return
reprs
...
@@ -176,7 +175,7 @@ class Edge(object):
...
@@ -176,7 +175,7 @@ class Edge(object):
target
=
self
.
target
.
name
,
target
=
self
.
target
.
name
,
extra
=
""
if
not
self
.
attrs
else
extra
=
""
if
not
self
.
attrs
else
"["
+
','
.
join
(
"{}={}"
.
format
(
attr
[
0
],
crepr
(
attr
[
1
]))
"["
+
','
.
join
(
"{}={}"
.
format
(
attr
[
0
],
crepr
(
attr
[
1
]))
for
attr
in
six
.
moves
.
iteritems
(
self
.
attrs
))
+
"]"
)
for
attr
in
six
.
iteritems
(
self
.
attrs
))
+
"]"
)
return
repr
return
repr
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
5d4238cd
...
@@ -603,7 +603,7 @@ class StaticRNN(object):
...
@@ -603,7 +603,7 @@ class StaticRNN(object):
boot_memories
=
[]
boot_memories
=
[]
pre_memories
=
[]
pre_memories
=
[]
memories
=
[]
memories
=
[]
for
_
,
mem
in
six
.
moves
.
iteritems
(
self
.
memories
):
for
_
,
mem
in
six
.
iteritems
(
self
.
memories
):
boot_memories
.
append
(
mem
.
init
)
boot_memories
.
append
(
mem
.
init
)
pre_memories
.
append
(
mem
.
pre_mem
.
name
)
pre_memories
.
append
(
mem
.
pre_mem
.
name
)
mem_var
=
rnn_block
.
var
(
mem
.
mem
.
name
)
mem_var
=
rnn_block
.
var
(
mem
.
mem
.
name
)
...
...
python/paddle/fluid/metrics.py
浏览文件 @
5d4238cd
...
@@ -80,10 +80,10 @@ class MetricBase(object):
...
@@ -80,10 +80,10 @@ class MetricBase(object):
"""
"""
states
=
{
states
=
{
attr
:
value
attr
:
value
for
attr
,
value
in
six
.
moves
.
iteritems
(
self
.
__dict__
)
for
attr
,
value
in
six
.
iteritems
(
self
.
__dict__
)
if
not
attr
.
startswith
(
"_"
)
if
not
attr
.
startswith
(
"_"
)
}
}
for
attr
,
value
in
six
.
moves
.
iteritems
(
states
):
for
attr
,
value
in
six
.
iteritems
(
states
):
if
isinstance
(
value
,
int
):
if
isinstance
(
value
,
int
):
setattr
(
self
,
attr
,
0
)
setattr
(
self
,
attr
,
0
)
elif
isinstance
(
value
,
float
):
elif
isinstance
(
value
,
float
):
...
@@ -106,7 +106,7 @@ class MetricBase(object):
...
@@ -106,7 +106,7 @@ class MetricBase(object):
"""
"""
states
=
{
states
=
{
attr
:
value
attr
:
value
for
attr
,
value
in
six
.
moves
.
iteritems
(
self
.
__dict__
)
for
attr
,
value
in
six
.
iteritems
(
self
.
__dict__
)
if
not
attr
.
startswith
(
"_"
)
if
not
attr
.
startswith
(
"_"
)
}
}
config
=
{}
config
=
{}
...
...
python/paddle/fluid/tests/unittests/benchmark.py
浏览文件 @
5d4238cd
...
@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
...
@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def
_get_input_names
(
self
):
def
_get_input_names
(
self
):
inputs
=
[]
inputs
=
[]
for
name
,
value
in
six
.
moves
.
iteritems
(
self
.
inputs
):
for
name
,
value
in
six
.
iteritems
(
self
.
inputs
):
if
isinstance
(
value
,
list
):
if
isinstance
(
value
,
list
):
inputs
.
extend
([
sub_name
for
sub_name
,
_
in
value
])
inputs
.
extend
([
sub_name
for
sub_name
,
_
in
value
])
inputs
.
append
(
name
)
inputs
.
append
(
name
)
...
@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
...
@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def
_get_output_names
(
self
):
def
_get_output_names
(
self
):
outputs
=
[]
outputs
=
[]
for
var_name
,
var
in
six
.
moves
.
iteritems
(
self
.
outputs
):
for
var_name
,
var
in
six
.
iteritems
(
self
.
outputs
):
if
isinstance
(
var
,
list
):
if
isinstance
(
var
,
list
):
for
sub_var_name
,
sub_var
in
var
:
for
sub_var_name
,
sub_var
in
var
:
outputs
.
append
(
sub_var_name
)
outputs
.
append
(
sub_var_name
)
...
...
python/paddle/fluid/tests/unittests/test_detection_map_op.py
浏览文件 @
5d4238cd
...
@@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
...
@@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos
[
label
].
append
([
score
,
tp
])
true_pos
[
label
].
append
([
score
,
tp
])
false_pos
[
label
].
append
([
score
,
fp
])
false_pos
[
label
].
append
([
score
,
fp
])
for
(
label
,
label_pos_num
)
in
six
.
moves
.
iteritems
(
label_count
):
for
(
label
,
label_pos_num
)
in
six
.
iteritems
(
label_count
):
if
label_pos_num
==
0
or
label
not
in
true_pos
:
continue
if
label_pos_num
==
0
or
label
not
in
true_pos
:
continue
label_true_pos
=
true_pos
[
label
]
label_true_pos
=
true_pos
[
label
]
label_false_pos
=
false_pos
[
label
]
label_false_pos
=
false_pos
[
label
]
...
...
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
浏览文件 @
5d4238cd
...
@@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
...
@@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe
.
run
(
scope
=
scope
,
feed
=
{
'x'
:
tensor
})
exe
.
run
(
scope
=
scope
,
feed
=
{
'x'
:
tensor
})
var
=
scope
.
find_var
(
rank_table
.
name
)
var
=
scope
.
find_var
(
rank_table
.
name
)
table
=
var
.
get_lod_rank_table
()
table
=
var
.
get_lod_rank_table
()
self
.
assertEqual
([(
0
,
5
),
(
1
,
1
),
(
2
,
1
)],
six
.
moves
.
iteritems
(
table
))
self
.
assertEqual
([(
0
,
5
),
(
1
,
1
),
(
2
,
1
)],
six
.
iteritems
(
table
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
浏览文件 @
5d4238cd
...
@@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
...
@@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics
# accumulate statistics
pos
,
neg
,
neu
=
0
,
0
,
0
pos
,
neg
,
neu
=
0
,
0
,
0
for
_
,
ranks
in
six
.
moves
.
iteritems
(
predictions
):
for
_
,
ranks
in
six
.
iteritems
(
predictions
):
for
e1
,
e2
in
itertools
.
combinations
(
ranks
,
2
):
for
e1
,
e2
in
itertools
.
combinations
(
ranks
,
2
):
s1
,
s2
,
l1
,
l2
,
w1
,
w2
=
e1
[
0
],
e2
[
0
],
e1
[
1
],
e2
[
1
],
e1
[
2
],
e2
[
2
]
s1
,
s2
,
l1
,
l2
,
w1
,
w2
=
e1
[
0
],
e2
[
0
],
e1
[
1
],
e2
[
1
],
e1
[
2
],
e2
[
2
]
w
=
(
w1
+
w2
)
*
0.5
w
=
(
w1
+
w2
)
*
0.5
...
...
python/paddle/fluid/trainer.py
浏览文件 @
5d4238cd
...
@@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order):
...
@@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
)
sorted_pair_list
=
sorted
(
sorted_pair_list
=
sorted
(
six
.
moves
.
iteritems
(
feed_order
),
key
=
lambda
item
:
item
[
1
])
six
.
iteritems
(
feed_order
),
key
=
lambda
item
:
item
[
1
])
feed_var_list
=
[
feed_var_list
=
[
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
]
]
...
@@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
...
@@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
six
.
moves
.
iteritems
(
trainer_args
):
for
name
,
value
in
six
.
iteritems
(
trainer_args
):
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
f
.
write
(
str
(
value
))
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
5d4238cd
...
@@ -218,8 +218,7 @@ class DistributeTranspiler(object):
...
@@ -218,8 +218,7 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above
# shuffle the map will avoid the uneven distribution above
grad_var_mapping_items
=
list
(
grad_var_mapping_items
=
list
(
six
.
iteritems
(
self
.
grad_var_mapping
))
six
.
moves
.
iteritems
(
self
.
grad_var_mapping
))
if
not
self
.
config
.
slice_var_up
:
if
not
self
.
config
.
slice_var_up
:
random
.
seed
(
self
.
origin_program
.
random_seed
)
random
.
seed
(
self
.
origin_program
.
random_seed
)
...
@@ -280,7 +279,7 @@ class DistributeTranspiler(object):
...
@@ -280,7 +279,7 @@ class DistributeTranspiler(object):
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
six
.
moves
.
iteritems
(
self
.
param_var_mapping
):
for
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
eps
=
[]
eps
=
[]
for
var
in
splited_var
:
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
...
@@ -304,7 +303,7 @@ class DistributeTranspiler(object):
...
@@ -304,7 +303,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
})
for
varname
,
splited_var
in
six
.
moves
.
iteritems
(
self
.
param_var_mapping
):
for
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
continue
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
orig_param
=
program
.
global_block
().
vars
[
varname
]
...
@@ -561,7 +560,7 @@ class DistributeTranspiler(object):
...
@@ -561,7 +560,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program
# 1. create vars in pserver program to startup program
pserver_vars
=
pserver_program
.
global_block
().
vars
pserver_vars
=
pserver_program
.
global_block
().
vars
created_var_map
=
collections
.
OrderedDict
()
created_var_map
=
collections
.
OrderedDict
()
for
_
,
var
in
six
.
moves
.
iteritems
(
pserver_vars
):
for
_
,
var
in
six
.
iteritems
(
pserver_vars
):
tmpvar
=
s_prog
.
global_block
().
_clone_variable
(
var
)
tmpvar
=
s_prog
.
global_block
().
_clone_variable
(
var
)
created_var_map
[
var
.
name
]
=
tmpvar
created_var_map
[
var
.
name
]
=
tmpvar
...
@@ -998,7 +997,7 @@ class DistributeTranspiler(object):
...
@@ -998,7 +997,7 @@ class DistributeTranspiler(object):
block_map
[
varname
]
=
[]
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
int
(
offset
),
int
(
size
)))
block_map
[
varname
].
append
((
int
(
offset
),
int
(
size
)))
for
varname
,
splited
in
six
.
moves
.
iteritems
(
block_map
):
for
varname
,
splited
in
six
.
iteritems
(
block_map
):
orig_var
=
program
.
global_block
().
var
(
varname
)
orig_var
=
program
.
global_block
().
var
(
varname
)
if
len
(
splited
)
==
1
:
if
len
(
splited
)
==
1
:
if
self
.
sync_mode
and
add_trainer_suffix
:
if
self
.
sync_mode
and
add_trainer_suffix
:
...
@@ -1249,7 +1248,7 @@ class DistributeTranspiler(object):
...
@@ -1249,7 +1248,7 @@ class DistributeTranspiler(object):
def
_is_splited_grad_var
(
self
,
var
,
var_dict
):
def
_is_splited_grad_var
(
self
,
var
,
var_dict
):
grad_block
=
None
grad_block
=
None
for
_
,
g
in
six
.
moves
.
iteritems
(
var_dict
):
for
_
,
g
in
six
.
iteritems
(
var_dict
):
if
self
.
_orig_varname
(
g
.
name
)
==
self
.
_orig_varname
(
var
.
name
):
if
self
.
_orig_varname
(
g
.
name
)
==
self
.
_orig_varname
(
var
.
name
):
if
g
.
name
.
find
(
".trainer_"
)
==
-
1
:
if
g
.
name
.
find
(
".trainer_"
)
==
-
1
:
grad_block
=
g
grad_block
=
g
...
@@ -1259,7 +1258,7 @@ class DistributeTranspiler(object):
...
@@ -1259,7 +1258,7 @@ class DistributeTranspiler(object):
def
_clone_lr_op
(
self
,
program
,
block
,
op
):
def
_clone_lr_op
(
self
,
program
,
block
,
op
):
inputs
=
self
.
_get_input_map_from_op
(
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
self
.
origin_program
.
global_block
().
vars
,
op
)
for
key
,
varlist
in
six
.
moves
.
iteritems
(
inputs
):
for
key
,
varlist
in
six
.
iteritems
(
inputs
):
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
...
@@ -1268,7 +1267,7 @@ class DistributeTranspiler(object):
...
@@ -1268,7 +1267,7 @@ class DistributeTranspiler(object):
outputs
=
self
.
_get_output_map_from_op
(
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
self
.
origin_program
.
global_block
().
vars
,
op
)
for
key
,
varlist
in
six
.
moves
.
iteritems
(
outputs
):
for
key
,
varlist
in
six
.
iteritems
(
outputs
):
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
...
@@ -1283,7 +1282,7 @@ class DistributeTranspiler(object):
...
@@ -1283,7 +1282,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated
# Append the ops for parameters that do not need to be optimized/updated
inputs
=
self
.
_get_input_map_from_op
(
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
key
,
varlist
in
six
.
moves
.
iteritems
(
inputs
):
for
key
,
varlist
in
six
.
iteritems
(
inputs
):
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
...
@@ -1302,7 +1301,7 @@ class DistributeTranspiler(object):
...
@@ -1302,7 +1301,7 @@ class DistributeTranspiler(object):
outputs
=
self
.
_get_output_map_from_op
(
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
key
,
varlist
in
six
.
moves
.
iteritems
(
outputs
):
for
key
,
varlist
in
six
.
iteritems
(
outputs
):
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录