提交 6dc07e7f 编写于 作者: M minqiyang

Replace items() with six.moves.iteritems() to improve memory usage

上级 9cd59990
......@@ -47,7 +47,8 @@ def tokenize(pattern):
while tf != None:
if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(
yield tarf.extractfile(tf).read().rstrip(six.b(
"\n\r")).translate(
None, six.b(string.punctuation)).lower().split()
tf = tarf.next()
......@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1
# Not sure if we should prune less-frequent words here.
word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff]
word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
......
......@@ -21,7 +21,7 @@ into paddle reader creators.
import paddle.dataset.common
import collections
import tarfile
from six.moves import range
import six
__all__ = ['train', 'test', 'build_dict', 'convert']
......@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq]
word_freq = [
x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, range(len(words)))))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words)
return word_idx
......@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
for i in six.moves.range(n, len(l) + 1):
yield tuple(l[i - n:i])
elif DataType.SEQ == data_type:
l = l.strip().split()
......
......@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset.
"""
import six
import collections
from itertools import chain
......@@ -64,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = list(word_freq_dict.items())
words_sort_list = six.moves.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index))
......
......@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in list(src_dict.items())}
trg_dict = {v: k for k, v in list(trg_dict.items())}
src_dict = {v: k for k, v in six.moves.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)}
return src_dict, trg_dict
......
......@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate(
sorted(
iter(list(word_dict.items())),
six.moves.iteritems(word_dict),
key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
......
......@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for para, args in list(inputs.items()):
for para, args in six.moves.iteritems(inputs):
op_desc.set_input(
para,
list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args)))
for para, args in list(outputs.items()):
for para, args in six.moves.iteritems(outputs):
op_desc.set_output(
para,
list(
......@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in list(attrs.items()):
for name, val in six.moves.iteritems(attrs):
if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc)
else:
......@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
for var_name, inputs in list(renamed_vars.items()):
for var_name, inputs in six.moves.iteritems(renamed_vars):
if len(inputs) > 1:
pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
......@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name)
var_map[name] = new_name
for g, ng in list(var_map.items()):
for g, ng in six.moves.iteritems(var_map):
if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g)
......
......@@ -958,7 +958,7 @@ class Block(object):
return list(self.iter_parameters())
def iter_parameters(self):
    """Lazily yield the Parameter variables registered in this block.

    Returns:
        generator: each Parameter found in ``self.vars`` (non-Parameter
        variables are filtered out).
    """
    # Bug fix: the diff used six.moves.iteritems(), but `six.moves` has no
    # `iteritems` attribute — the compatibility helper lives at the top
    # level as six.iteritems(), which returns a lazy item iterator on both
    # Python 2 and 3 (avoiding materializing the full items list).
    return (item[1] for item in six.iteritems(self.vars)
            if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs):
......
......@@ -106,7 +106,7 @@ class Graph(object):
def _rank_repr(self):
ranks = sorted(
list(self.rank_groups.items()),
six.moves.iteritems(self.rank_groups),
key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority))
repr = []
......@@ -150,8 +150,9 @@ class Node(object):
reprs = '{name} [label={label} {extra} ];'.format(
name=self.name,
label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value))
for key, value in list(self.attrs.items()))
extra=',' + ','.join(
"%s=%s" % (key, crepr(value))
for key, value in six.moves.iteritems(self.attrs))
if self.attrs else "")
return reprs
......@@ -175,7 +176,7 @@ class Edge(object):
target=self.target.name,
extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in list(self.attrs.items())) + "]")
for attr in six.moves.iteritems(self.attrs)) + "]")
return repr
......
......@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu
from .ops import logical_and, logical_not, logical_or
import numpy
import warnings
import six
from functools import reduce
__all__ = [
......@@ -602,7 +603,7 @@ class StaticRNN(object):
boot_memories = []
pre_memories = []
memories = []
for _, mem in list(self.memories.items()):
for _, mem in six.moves.iteritems(self.memories):
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name)
......
......@@ -19,6 +19,7 @@ The metrics are accomplished via Python natively.
import numpy as np
import copy
import warnings
import six
__all__ = [
'MetricBase',
......@@ -79,10 +80,10 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in list(self.__dict__.items())
for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_")
}
for attr, value in list(states.items()):
for attr, value in six.moves.iteritems(states):
if isinstance(value, int):
setattr(self, attr, 0)
elif isinstance(value, float):
......@@ -105,7 +106,7 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in list(self.__dict__.items())
for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_")
}
config = {}
......
......@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self):
inputs = []
for name, value in list(self.inputs.items()):
for name, value in six.moves.iteritems(self.inputs):
if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name)
......@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self):
outputs = []
for var_name, var in list(self.outputs.items()):
for var_name, var in six.moves.iteritems(self.outputs):
if isinstance(var, list):
for sub_var_name, sub_var in var:
outputs.append(sub_var_name)
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import six
import sys
import collections
import math
......@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
for (label, label_pos_num) in list(label_count.items()):
for (label, label_pos_num) in six.moves.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
......
......@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import numpy
import unittest
import six
class TestLoDRankTable(unittest.TestCase):
......@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items()))
self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table))
if __name__ == '__main__':
......
......@@ -15,6 +15,7 @@
import unittest
import itertools
import numpy as np
import six
from op_test import OpTest
......@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics
pos, neg, neu = 0, 0, 0
for _, ranks in list(predictions.items()):
for _, ranks in six.moves.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5
......
......@@ -16,6 +16,7 @@ import contextlib
import os
import errno
import shutil
import six
import time
from . import core
......@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list = sorted(
list(feed_order.items()), key=lambda item: item[1])
six.moves.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list
]
......@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in list(trainer_args.items()):
for name, value in six.moves.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
......
......@@ -218,7 +218,8 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list(self.grad_var_mapping.items())
grad_var_mapping_items = list(
six.moves.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up:
random.seed(self.origin_program.random_seed)
......@@ -279,7 +280,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in list(self.param_var_mapping.items()):
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
......@@ -303,7 +304,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in list(self.param_var_mapping.items()):
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
......@@ -560,7 +561,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars
created_var_map = collections.OrderedDict()
for _, var in list(pserver_vars.items()):
for _, var in six.moves.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar
......@@ -997,7 +998,7 @@ class DistributeTranspiler(object):
block_map[varname] = []
block_map[varname].append((int(offset), int(size)))
for varname, splited in list(block_map.items()):
for varname, splited in six.moves.iteritems(block_map):
orig_var = program.global_block().var(varname)
if len(splited) == 1:
if self.sync_mode and add_trainer_suffix:
......@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict):
grad_block = None
# TODO(minqiyang): replace these items() with six.iteritems() to
# improve memory
for _, g in list(var_dict.items()):
for _, g in six.moves.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
......@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in list(inputs.items()):
for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in list(outputs.items()):
for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in list(inputs.items()):
for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in list(outputs.items()):
for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册