提交 5d4238cd 编写于 作者: M minqiyang

Fix six.iteritems problem

上级 e4e9450e
...@@ -64,7 +64,7 @@ def build_dict(pattern, cutoff): ...@@ -64,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1 word_freq[word] += 1
# Not sure if we should prune less-frequent words here. # Not sure if we should prune less-frequent words here.
word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff] word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary)) words, _ = list(zip(*dictionary))
......
...@@ -66,7 +66,7 @@ def build_dict(min_word_freq=50): ...@@ -66,7 +66,7 @@ def build_dict(min_word_freq=50):
del word_freq['<unk>'] del word_freq['<unk>']
word_freq = [ word_freq = [
x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq x for x in six.iteritems(word_freq) if x[1] > min_word_freq
] ]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
......
...@@ -65,7 +65,7 @@ def get_word_dict(): ...@@ -65,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category): for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field): for words in movie_reviews.words(field):
word_freq_dict[words] += 1 word_freq_dict[words] += 1
words_sort_list = six.moves.iteritems(word_freq_dict) words_sort_list = six.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list): for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index)) words_freq_sorted.append((word[0], index))
......
...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True): ...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in six.moves.iteritems(src_dict)} src_dict = {v: k for k, v in six.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)} trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
return src_dict, trg_dict return src_dict, trg_dict
......
...@@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate( for idx, word in enumerate(
sorted( sorted(
six.moves.iteritems(word_dict), six.iteritems(word_dict), key=lambda x: x[1],
key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write("%s\n" % (word[0]))
......
...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
""" """
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for para, args in six.moves.iteritems(inputs): for para, args in six.iteritems(inputs):
op_desc.set_input( op_desc.set_input(
para, para,
list( list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args))) args)))
for para, args in six.moves.iteritems(outputs): for para, args in six.iteritems(outputs):
op_desc.set_output( op_desc.set_output(
para, para,
list( list(
...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs: if op_role_attr_name not in attrs:
attrs[ attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in six.moves.iteritems(attrs): for name, val in six.iteritems(attrs):
if isinstance(val, framework.Block): if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc) op_desc.set_block_attr(name, val.desc)
else: else:
...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in six.moves.iteritems(renamed_vars): for var_name, inputs in six.iteritems(renamed_vars):
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): ...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name) op_desc.rename_output(name, new_name)
var_map[name] = new_name var_map[name] = new_name
for g, ng in six.moves.iteritems(var_map): for g, ng in six.iteritems(var_map):
if g in grad_to_var: if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g] grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g) grad_to_var.pop(g)
......
...@@ -958,7 +958,7 @@ class Block(object): ...@@ -958,7 +958,7 @@ class Block(object):
return list(self.iter_parameters()) return list(self.iter_parameters())
def iter_parameters(self): def iter_parameters(self):
return (item[1] for item in six.moves.iteritems(self.vars) return (item[1] for item in six.iteritems(self.vars)
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
......
...@@ -106,7 +106,7 @@ class Graph(object): ...@@ -106,7 +106,7 @@ class Graph(object):
def _rank_repr(self): def _rank_repr(self):
ranks = sorted( ranks = sorted(
six.moves.iteritems(self.rank_groups), six.iteritems(self.rank_groups),
key=functools.cmp_to_key( key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority)) lambda a, b: a[1].priority > b[1].priority))
repr = [] repr = []
...@@ -150,9 +150,8 @@ class Node(object): ...@@ -150,9 +150,8 @@ class Node(object):
reprs = '{name} [label={label} {extra} ];'.format( reprs = '{name} [label={label} {extra} ];'.format(
name=self.name, name=self.name,
label=self.label, label=self.label,
extra=',' + ','.join( extra=',' + ','.join("%s=%s" % (key, crepr(value))
"%s=%s" % (key, crepr(value)) for key, value in six.iteritems(self.attrs))
for key, value in six.moves.iteritems(self.attrs))
if self.attrs else "") if self.attrs else "")
return reprs return reprs
...@@ -176,7 +175,7 @@ class Edge(object): ...@@ -176,7 +175,7 @@ class Edge(object):
target=self.target.name, target=self.target.name,
extra="" if not self.attrs else extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) "[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in six.moves.iteritems(self.attrs)) + "]") for attr in six.iteritems(self.attrs)) + "]")
return repr return repr
......
...@@ -603,7 +603,7 @@ class StaticRNN(object): ...@@ -603,7 +603,7 @@ class StaticRNN(object):
boot_memories = [] boot_memories = []
pre_memories = [] pre_memories = []
memories = [] memories = []
for _, mem in six.moves.iteritems(self.memories): for _, mem in six.iteritems(self.memories):
boot_memories.append(mem.init) boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name) pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name) mem_var = rnn_block.var(mem.mem.name)
......
...@@ -80,10 +80,10 @@ class MetricBase(object): ...@@ -80,10 +80,10 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in six.moves.iteritems(self.__dict__) for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
for attr, value in six.moves.iteritems(states): for attr, value in six.iteritems(states):
if isinstance(value, int): if isinstance(value, int):
setattr(self, attr, 0) setattr(self, attr, 0)
elif isinstance(value, float): elif isinstance(value, float):
...@@ -106,7 +106,7 @@ class MetricBase(object): ...@@ -106,7 +106,7 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in six.moves.iteritems(self.__dict__) for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
config = {} config = {}
......
...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest): ...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self): def _get_input_names(self):
inputs = [] inputs = []
for name, value in six.moves.iteritems(self.inputs): for name, value in six.iteritems(self.inputs):
if isinstance(value, list): if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value]) inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name) inputs.append(name)
...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest): ...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self): def _get_output_names(self):
outputs = [] outputs = []
for var_name, var in six.moves.iteritems(self.outputs): for var_name, var in six.iteritems(self.outputs):
if isinstance(var, list): if isinstance(var, list):
for sub_var_name, sub_var in var: for sub_var_name, sub_var in var:
outputs.append(sub_var_name) outputs.append(sub_var_name)
......
...@@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp]) true_pos[label].append([score, tp])
false_pos[label].append([score, fp]) false_pos[label].append([score, fp])
for (label, label_pos_num) in six.moves.iteritems(label_count): for (label, label_pos_num) in six.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label] label_true_pos = true_pos[label]
label_false_pos = false_pos[label] label_false_pos = false_pos[label]
......
...@@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): ...@@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe.run(scope=scope, feed={'x': tensor}) exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name) var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table() table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table)) self.assertEqual([(0, 5), (1, 1), (2, 1)], six.iteritems(table))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): ...@@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics # accumulate statistics
pos, neg, neu = 0, 0, 0 pos, neg, neu = 0, 0, 0
for _, ranks in six.moves.iteritems(predictions): for _, ranks in six.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2): for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5 w = (w1 + w2) * 0.5
......
...@@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order): ...@@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))" "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
) )
sorted_pair_list = sorted( sorted_pair_list = sorted(
six.moves.iteritems(feed_order), key=lambda item: item[1]) six.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [ feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list program.global_block().var(pair[0]) for pair in sorted_pair_list
] ]
...@@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id) cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in six.moves.iteritems(trainer_args): for name, value in six.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name) args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f: with open(args_file, 'w') as f:
f.write(str(value)) f.write(str(value))
......
...@@ -218,8 +218,7 @@ class DistributeTranspiler(object): ...@@ -218,8 +218,7 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1 # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above # shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list( grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))
six.moves.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up: if not self.config.slice_var_up:
random.seed(self.origin_program.random_seed) random.seed(self.origin_program.random_seed)
...@@ -280,7 +279,7 @@ class DistributeTranspiler(object): ...@@ -280,7 +279,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in six.moves.iteritems(self.param_var_mapping): for varname, splited_var in six.iteritems(self.param_var_mapping):
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
...@@ -304,7 +303,7 @@ class DistributeTranspiler(object): ...@@ -304,7 +303,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
for varname, splited_var in six.moves.iteritems(self.param_var_mapping): for varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
...@@ -561,7 +560,7 @@ class DistributeTranspiler(object): ...@@ -561,7 +560,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program # 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars pserver_vars = pserver_program.global_block().vars
created_var_map = collections.OrderedDict() created_var_map = collections.OrderedDict()
for _, var in six.moves.iteritems(pserver_vars): for _, var in six.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var) tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
...@@ -998,7 +997,7 @@ class DistributeTranspiler(object): ...@@ -998,7 +997,7 @@ class DistributeTranspiler(object):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((int(offset), int(size))) block_map[varname].append((int(offset), int(size)))
for varname, splited in six.moves.iteritems(block_map): for varname, splited in six.iteritems(block_map):
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
...@@ -1249,7 +1248,7 @@ class DistributeTranspiler(object): ...@@ -1249,7 +1248,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict): def _is_splited_grad_var(self, var, var_dict):
grad_block = None grad_block = None
for _, g in six.moves.iteritems(var_dict): for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name): if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1: if g.name.find(".trainer_") == -1:
grad_block = g grad_block = g
...@@ -1259,7 +1258,7 @@ class DistributeTranspiler(object): ...@@ -1259,7 +1258,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op): def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in six.moves.iteritems(inputs): for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1268,7 +1267,7 @@ class DistributeTranspiler(object): ...@@ -1268,7 +1267,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in six.moves.iteritems(outputs): for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1283,7 +1282,7 @@ class DistributeTranspiler(object): ...@@ -1283,7 +1282,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in six.moves.iteritems(inputs): for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1302,7 +1301,7 @@ class DistributeTranspiler(object): ...@@ -1302,7 +1301,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in six.moves.iteritems(outputs): for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册