提交 ae39709e 编写于 作者: M minqiyang

Polish code

上级 55d7f55c
......@@ -208,11 +208,8 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
boost::get<std::vector<int>>(v).size() == 0u) {
proto::OpProto proto = OpInfoMap::Instance().Get(Type()).Proto();
// Find current attr via attr name and set the correct attribute value
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
const proto::OpProto::Attr& attr = GetProtoAttr(name);
switch (attr.type()) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
......@@ -250,8 +247,6 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true;
return;
}
}
}
this->attrs_[name] = v;
need_update_ = true;
......@@ -280,6 +275,18 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second;
}
const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) {
proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return attr;
}
}
PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
......
......@@ -81,6 +81,8 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const;
const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
......
......@@ -55,9 +55,8 @@ def reader_creator(filename, sub_name, cycle=False):
def reader():
with tarfile.open(filename, mode='r') as f:
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
names = (each_item.name for each_item in f
if sub_name in each_item.name)
while True:
for name in names:
......
......@@ -20,7 +20,6 @@ import shutil
import sys
import importlib
import paddle.dataset
import paddle.fluid.compat as cpt
import six.moves.cPickle as pickle
import glob
......
......@@ -90,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name):
labels = []
one_seg = []
for word, label in zip(words_file, props_file):
word = cpt.to_literal_str(word.strip())
label = cpt.to_literal_str(label.strip().split())
word = cpt.to_text(word.strip())
label = cpt.to_text(label.strip().split())
if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])):
......
......@@ -114,7 +114,7 @@ def __initialize_meta_info__():
categories_set = set()
with package.open('ml-1m/movies.dat') as movie_file:
for i, line in enumerate(movie_file):
line = cpt.to_literal_str(line, encoding='latin')
line = cpt.to_text(line, encoding='latin')
movie_id, title, categories = line.strip().split('::')
categories = categories.split('|')
for c in categories:
......@@ -139,7 +139,7 @@ def __initialize_meta_info__():
USER_INFO = dict()
with package.open('ml-1m/users.dat') as user_file:
for line in user_file:
line = cpt.to_literal_str(line, encoding='latin')
line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::")
USER_INFO[int(uid)] = UserInfo(
index=uid, gender=gender, age=age, job_id=job)
......@@ -152,7 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
with zipfile.ZipFile(file=fn) as package:
with package.open('ml-1m/ratings.dat') as rating:
for line in rating:
line = cpt.to_literal_str(line, encoding='latin')
line = cpt.to_text(line, encoding='latin')
if (rand.random() < test_ratio) == is_test:
uid, mov_id, rating, _ = line.strip().split("::")
uid = int(uid)
......
......@@ -55,7 +55,7 @@ def __read_to_dict(tar_file, dict_size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
out_dict[cpt.to_literal_str(line.strip())] = line_count
out_dict[cpt.to_text(line.strip())] = line_count
else:
break
return out_dict
......
......@@ -89,9 +89,9 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = cpt.to_literal_str(line.strip())
word_dict[idx] = cpt.to_text(line.strip())
else:
word_dict[cpt.to_literal_str(line.strip())] = idx
word_dict[cpt.to_text(line.strip())] = idx
return word_dict
......
......@@ -103,8 +103,8 @@ def _some_in_set_(cands, s):
"""
if len(cands) == 0:
return False
literal_set = cpt.to_literal_str(s)
literal_cands = cpt.to_literal_str(cands)
literal_set = cpt.to_text(s)
literal_cands = cpt.to_text(cands)
for c in literal_cands:
if c in literal_set:
return True
......@@ -117,7 +117,7 @@ def _strip_grad_suffix_(name):
e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y
"""
name = cpt.to_literal_str(name)
name = cpt.to_text(name)
pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name
......@@ -127,7 +127,7 @@ def _append_grad_suffix_(name):
Append grad suffix to the given variable name
e.g. x ==> x@GRAD
"""
return cpt.to_literal_str(name) + core.grad_var_suffix()
return cpt.to_text(name) + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs):
......@@ -365,7 +365,7 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc,
cpt.to_literal_str(no_grad_dict[block.idx]), grad_sub_block_list)
cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
......@@ -600,7 +600,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
params_and_grads = []
for param in parameters:
if cpt.to_literal_str(param) not in grad_info_map:
if cpt.to_text(param) not in grad_info_map:
continue
grad_info = grad_info_map[param]
grad_block = grad_info[1]
......
......@@ -17,7 +17,7 @@ import math
__all__ = [
'long_type',
'to_literal_str',
'to_text',
'to_bytes',
'round',
'floor_division',
......@@ -33,7 +33,7 @@ else:
# str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False):
def to_text(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a literal string without any encoding.
......@@ -60,23 +60,23 @@ def to_literal_str(obj, encoding='utf-8', inplace=False):
if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_literal_str(obj[i], encoding)
obj[i] = _to_text(obj[i], encoding)
return obj
else:
return [_to_literal_str(item, encoding) for item in obj]
return [_to_text(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_literal_str(item, encoding))
obj.add(_to_text(item, encoding))
return obj
else:
return set([_to_literal_str(item, encoding) for item in obj])
return set([_to_text(item, encoding) for item in obj])
else:
return _to_literal_str(obj, encoding)
return _to_text(obj, encoding)
def _to_literal_str(obj, encoding):
def _to_text(obj, encoding):
"""
In Python3:
Decode the bytes type object to str type with specific encoding
......
......@@ -199,7 +199,7 @@ class Variable(object):
if name is None:
name = unique_name.generate('_generated_var')
is_new_var = False
name = cpt.to_literal_str(name)
name = cpt.to_text(name)
self.desc = self.block.desc.find_var(cpt.to_bytes(name))
if self.desc is None:
......@@ -326,7 +326,7 @@ class Variable(object):
@property
def name(self):
return cpt.to_literal_str(self.desc.name())
return cpt.to_text(self.desc.name())
@name.setter
def name(self, new_name):
......@@ -530,7 +530,7 @@ class Operator(object):
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
in_arg_names.append(cpt.to_literal_str(arg.name))
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......@@ -559,7 +559,7 @@ class Operator(object):
(out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
out_arg_names.append(cpt.to_literal_str(arg.name))
out_arg_names.append(cpt.to_text(arg.name))
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
......@@ -986,8 +986,8 @@ class Block(object):
Returns:
Variable: the Variable with the giving name.
"""
name = cpt.to_literal_str(name)
new_name = cpt.to_literal_str(new_name)
name = cpt.to_text(name)
new_name = cpt.to_text(new_name)
if not self.has_var(name):
raise ValueError("var %s is not in current block" % name)
......
......@@ -155,13 +155,13 @@ class ParallelExecutor(object):
self.executor = core.ParallelExecutor(
self._places,
set([
cpt.to_literal_str(p.name)
cpt.to_text(p.name)
for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
set(cpt.to_literal_str(var)
set(cpt.to_text(var)
for var in self.persistable_vars), main.desc,
cpt.to_literal_str(loss_name)
cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id)
self.scope = scope
......@@ -275,7 +275,7 @@ class ParallelExecutor(object):
fetch_var_name = '@FETCHED_VAR_NAME@'
self.executor.run(
cpt.to_literal_str(fetch_list), cpt.to_literal_str(fetch_var_name))
cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
if self.is_dist:
......
......@@ -26,44 +26,44 @@ class TestCompatible(unittest.TestCase):
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_literal_str(self):
def test_to_text(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
if six.PY2:
# check None
self.assertIsNone(cpt.to_literal_str(None))
self.assertIsNone(cpt.to_text(None))
# check all string related types
self.assertTrue(isinstance(cpt.to_literal_str(str("")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode))
self.assertEqual(u"", cpt.to_literal_str(str("")))
self.assertEqual(u"123", cpt.to_literal_str(str("123")))
self.assertEqual(u"", cpt.to_literal_str(b""))
self.assertEqual(u"123", cpt.to_literal_str(b"123"))
self.assertEqual(u"", cpt.to_literal_str(u""))
self.assertEqual(u"123", cpt.to_literal_str(u"123"))
self.assertTrue(isinstance(cpt.to_text(str("")), unicode))
self.assertTrue(isinstance(cpt.to_text(str("123")), unicode))
self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertEqual(u"", cpt.to_text(str("")))
self.assertEqual(u"123", cpt.to_text(str("123")))
self.assertEqual(u"", cpt.to_text(b""))
self.assertEqual(u"123", cpt.to_text(b"123"))
self.assertEqual(u"", cpt.to_text(u""))
self.assertEqual(u"123", cpt.to_text(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
......@@ -73,19 +73,19 @@ class TestCompatible(unittest.TestCase):
# check list types, inplace
l = [""]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
......@@ -93,19 +93,19 @@ class TestCompatible(unittest.TestCase):
# check set types, not inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
......@@ -115,56 +115,56 @@ class TestCompatible(unittest.TestCase):
# check set types, inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2)
elif six.PY3:
self.assertIsNone(cpt.to_literal_str(None))
self.assertTrue(isinstance(cpt.to_literal_str(str("")), str))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str))
self.assertEqual("", cpt.to_literal_str(str("")))
self.assertEqual("123", cpt.to_literal_str(str("123")))
self.assertEqual("", cpt.to_literal_str(b""))
self.assertEqual("123", cpt.to_literal_str(b"123"))
self.assertEqual("", cpt.to_literal_str(u""))
self.assertEqual("123", cpt.to_literal_str(u"123"))
self.assertIsNone(cpt.to_text(None))
self.assertTrue(isinstance(cpt.to_text(str("")), str))
self.assertTrue(isinstance(cpt.to_text(str("123")), str))
self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertEqual("", cpt.to_text(str("")))
self.assertEqual("123", cpt.to_text(str("123")))
self.assertEqual("", cpt.to_text(b""))
self.assertEqual("123", cpt.to_text(b"123"))
self.assertEqual("", cpt.to_text(u""))
self.assertEqual("123", cpt.to_text(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l)
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
......@@ -172,19 +172,19 @@ class TestCompatible(unittest.TestCase):
# check list types, inplace
l = [""]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", b"123"]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
......@@ -194,19 +194,19 @@ class TestCompatible(unittest.TestCase):
# check set types, not inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False)
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
......@@ -214,19 +214,19 @@ class TestCompatible(unittest.TestCase):
# check set types, inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True)
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
......
......@@ -186,7 +186,7 @@ class TestDistBase(unittest.TestCase):
env=env_local)
local_proc.wait()
out, err = local_proc.communicate()
local_ret = cpt.to_literal_str(out)
local_ret = cpt.to_text(out)
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
......@@ -224,7 +224,7 @@ class TestDistBase(unittest.TestCase):
tr1_proc.wait()
out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = cpt.to_literal_str(out)
loss_data0 = cpt.to_text(out)
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
......
......@@ -260,7 +260,7 @@ class ControlFlowGraph(object):
# memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(
cpt.to_literal_str(x)).desc = self._find_var(
cpt.to_text(x)).desc = self._find_var(
block_desc, cache_var, is_forward)
self._update_graph(x, cache_var, begin_idx=i)
break
......
......@@ -391,9 +391,9 @@ class PipeReader:
buff = self.process.stdout.read(self.bufsize)
if buff:
if self.file_type == "gzip":
decomp_buff = cpt.to_literal_str(self.dec.decompress(buff))
decomp_buff = cpt.to_text(self.dec.decompress(buff))
elif self.file_type == "plain":
decomp_buff = cpt.to_literal_str(buff)
decomp_buff = cpt.to_text(buff)
else:
raise TypeError("file_type %s is not allowed" %
self.file_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册