提交 ae39709e 编写于 作者: M minqiyang

Polish code

上级 55d7f55c
...@@ -208,11 +208,8 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { ...@@ -208,11 +208,8 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1); proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS && if (attr_type == proto::AttrType::INTS &&
boost::get<std::vector<int>>(v).size() == 0u) { boost::get<std::vector<int>>(v).size() == 0u) {
proto::OpProto proto = OpInfoMap::Instance().Get(Type()).Proto();
// Find current attr via attr name and set the correct attribute value // Find current attr via attr name and set the correct attribute value
for (int i = 0; i != proto.attrs_size(); ++i) { const proto::OpProto::Attr& attr = GetProtoAttr(name);
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
switch (attr.type()) { switch (attr.type()) {
case proto::AttrType::BOOLEANS: { case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name VLOG(11) << "SetAttr: " << Type() << ", " << name
...@@ -250,8 +247,6 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { ...@@ -250,8 +247,6 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true; need_update_ = true;
return; return;
} }
}
}
this->attrs_[name] = v; this->attrs_[name] = v;
need_update_ = true; need_update_ = true;
...@@ -280,6 +275,18 @@ Attribute OpDesc::GetAttr(const std::string &name) const { ...@@ -280,6 +275,18 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second; return it->second;
} }
const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) {
proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return attr;
}
}
PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}
Attribute OpDesc::GetNullableAttr(const std::string &name) const { Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
if (it != attrs_.end()) { if (it != attrs_.end()) {
......
...@@ -81,6 +81,8 @@ class OpDesc { ...@@ -81,6 +81,8 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const;
......
...@@ -55,9 +55,8 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -55,9 +55,8 @@ def reader_creator(filename, sub_name, cycle=False):
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = [ names = (each_item.name for each_item in f
each_item.name for each_item in f if sub_name in each_item.name if sub_name in each_item.name)
]
while True: while True:
for name in names: for name in names:
......
...@@ -20,7 +20,6 @@ import shutil ...@@ -20,7 +20,6 @@ import shutil
import sys import sys
import importlib import importlib
import paddle.dataset import paddle.dataset
import paddle.fluid.compat as cpt
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import glob import glob
......
...@@ -90,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name): ...@@ -90,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name):
labels = [] labels = []
one_seg = [] one_seg = []
for word, label in zip(words_file, props_file): for word, label in zip(words_file, props_file):
word = cpt.to_literal_str(word.strip()) word = cpt.to_text(word.strip())
label = cpt.to_literal_str(label.strip().split()) label = cpt.to_text(label.strip().split())
if len(label) == 0: # end of sentence if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])): for i in range(len(one_seg[0])):
......
...@@ -114,7 +114,7 @@ def __initialize_meta_info__(): ...@@ -114,7 +114,7 @@ def __initialize_meta_info__():
categories_set = set() categories_set = set()
with package.open('ml-1m/movies.dat') as movie_file: with package.open('ml-1m/movies.dat') as movie_file:
for i, line in enumerate(movie_file): for i, line in enumerate(movie_file):
line = cpt.to_literal_str(line, encoding='latin') line = cpt.to_text(line, encoding='latin')
movie_id, title, categories = line.strip().split('::') movie_id, title, categories = line.strip().split('::')
categories = categories.split('|') categories = categories.split('|')
for c in categories: for c in categories:
...@@ -139,7 +139,7 @@ def __initialize_meta_info__(): ...@@ -139,7 +139,7 @@ def __initialize_meta_info__():
USER_INFO = dict() USER_INFO = dict()
with package.open('ml-1m/users.dat') as user_file: with package.open('ml-1m/users.dat') as user_file:
for line in user_file: for line in user_file:
line = cpt.to_literal_str(line, encoding='latin') line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::") uid, gender, age, job, _ = line.strip().split("::")
USER_INFO[int(uid)] = UserInfo( USER_INFO[int(uid)] = UserInfo(
index=uid, gender=gender, age=age, job_id=job) index=uid, gender=gender, age=age, job_id=job)
...@@ -152,7 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): ...@@ -152,7 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
with zipfile.ZipFile(file=fn) as package: with zipfile.ZipFile(file=fn) as package:
with package.open('ml-1m/ratings.dat') as rating: with package.open('ml-1m/ratings.dat') as rating:
for line in rating: for line in rating:
line = cpt.to_literal_str(line, encoding='latin') line = cpt.to_text(line, encoding='latin')
if (rand.random() < test_ratio) == is_test: if (rand.random() < test_ratio) == is_test:
uid, mov_id, rating, _ = line.strip().split("::") uid, mov_id, rating, _ = line.strip().split("::")
uid = int(uid) uid = int(uid)
......
...@@ -55,7 +55,7 @@ def __read_to_dict(tar_file, dict_size): ...@@ -55,7 +55,7 @@ def __read_to_dict(tar_file, dict_size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
if line_count < size: if line_count < size:
out_dict[cpt.to_literal_str(line.strip())] = line_count out_dict[cpt.to_text(line.strip())] = line_count
else: else:
break break
return out_dict return out_dict
......
...@@ -89,9 +89,9 @@ def __load_dict(tar_file, dict_size, lang, reverse=False): ...@@ -89,9 +89,9 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
with open(dict_path, "rb") as fdict: with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = cpt.to_literal_str(line.strip()) word_dict[idx] = cpt.to_text(line.strip())
else: else:
word_dict[cpt.to_literal_str(line.strip())] = idx word_dict[cpt.to_text(line.strip())] = idx
return word_dict return word_dict
......
...@@ -103,8 +103,8 @@ def _some_in_set_(cands, s): ...@@ -103,8 +103,8 @@ def _some_in_set_(cands, s):
""" """
if len(cands) == 0: if len(cands) == 0:
return False return False
literal_set = cpt.to_literal_str(s) literal_set = cpt.to_text(s)
literal_cands = cpt.to_literal_str(cands) literal_cands = cpt.to_text(cands)
for c in literal_cands: for c in literal_cands:
if c in literal_set: if c in literal_set:
return True return True
...@@ -117,7 +117,7 @@ def _strip_grad_suffix_(name): ...@@ -117,7 +117,7 @@ def _strip_grad_suffix_(name):
e.g. x@GRAD ==> x e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y y@GRAD@RENAME@1 ==> y
""" """
name = cpt.to_literal_str(name) name = cpt.to_text(name)
pos = name.find(core.grad_var_suffix()) pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name return name[:pos] if pos != -1 else name
...@@ -127,7 +127,7 @@ def _append_grad_suffix_(name): ...@@ -127,7 +127,7 @@ def _append_grad_suffix_(name):
Append grad suffix to the given variable name Append grad suffix to the given variable name
e.g. x ==> x@GRAD e.g. x ==> x@GRAD
""" """
return cpt.to_literal_str(name) + core.grad_var_suffix() return cpt.to_text(name) + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs): def _addup_repetitive_outputs_(op_descs):
...@@ -365,7 +365,7 @@ def _append_backward_ops_(block, ...@@ -365,7 +365,7 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, op.desc,
cpt.to_literal_str(no_grad_dict[block.idx]), grad_sub_block_list) cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
grad_op_descs.extend(grad_op_desc) grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
...@@ -600,7 +600,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -600,7 +600,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
params_and_grads = [] params_and_grads = []
for param in parameters: for param in parameters:
if cpt.to_literal_str(param) not in grad_info_map: if cpt.to_text(param) not in grad_info_map:
continue continue
grad_info = grad_info_map[param] grad_info = grad_info_map[param]
grad_block = grad_info[1] grad_block = grad_info[1]
......
...@@ -17,7 +17,7 @@ import math ...@@ -17,7 +17,7 @@ import math
__all__ = [ __all__ = [
'long_type', 'long_type',
'to_literal_str', 'to_text',
'to_bytes', 'to_bytes',
'round', 'round',
'floor_division', 'floor_division',
...@@ -33,7 +33,7 @@ else: ...@@ -33,7 +33,7 @@ else:
# str and bytes related functions # str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False): def to_text(obj, encoding='utf-8', inplace=False):
""" """
All string in PaddlePaddle should be represented as a literal string. All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a literal string without any encoding. This function will convert object to a literal string without any encoding.
...@@ -60,23 +60,23 @@ def to_literal_str(obj, encoding='utf-8', inplace=False): ...@@ -60,23 +60,23 @@ def to_literal_str(obj, encoding='utf-8', inplace=False):
if isinstance(obj, list): if isinstance(obj, list):
if inplace: if inplace:
for i in six.moves.xrange(len(obj)): for i in six.moves.xrange(len(obj)):
obj[i] = _to_literal_str(obj[i], encoding) obj[i] = _to_text(obj[i], encoding)
return obj return obj
else: else:
return [_to_literal_str(item, encoding) for item in obj] return [_to_text(item, encoding) for item in obj]
elif isinstance(obj, set): elif isinstance(obj, set):
if inplace: if inplace:
for item in obj: for item in obj:
obj.remove(item) obj.remove(item)
obj.add(_to_literal_str(item, encoding)) obj.add(_to_text(item, encoding))
return obj return obj
else: else:
return set([_to_literal_str(item, encoding) for item in obj]) return set([_to_text(item, encoding) for item in obj])
else: else:
return _to_literal_str(obj, encoding) return _to_text(obj, encoding)
def _to_literal_str(obj, encoding): def _to_text(obj, encoding):
""" """
In Python3: In Python3:
Decode the bytes type object to str type with specific encoding Decode the bytes type object to str type with specific encoding
......
...@@ -199,7 +199,7 @@ class Variable(object): ...@@ -199,7 +199,7 @@ class Variable(object):
if name is None: if name is None:
name = unique_name.generate('_generated_var') name = unique_name.generate('_generated_var')
is_new_var = False is_new_var = False
name = cpt.to_literal_str(name) name = cpt.to_text(name)
self.desc = self.block.desc.find_var(cpt.to_bytes(name)) self.desc = self.block.desc.find_var(cpt.to_bytes(name))
if self.desc is None: if self.desc is None:
...@@ -326,7 +326,7 @@ class Variable(object): ...@@ -326,7 +326,7 @@ class Variable(object):
@property @property
def name(self): def name(self):
return cpt.to_literal_str(self.desc.name()) return cpt.to_text(self.desc.name())
@name.setter @name.setter
def name(self, new_name): def name(self, new_name):
...@@ -530,7 +530,7 @@ class Operator(object): ...@@ -530,7 +530,7 @@ class Operator(object):
elif isinstance(arg, six.binary_type): elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode()) in_arg_names.append(arg.decode())
else: else:
in_arg_names.append(cpt.to_literal_str(arg.name)) in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names) self.desc.set_input(in_proto.name, in_arg_names)
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
...@@ -559,7 +559,7 @@ class Operator(object): ...@@ -559,7 +559,7 @@ class Operator(object):
(out_proto.name, len(out_args))) (out_proto.name, len(out_args)))
out_arg_names = [] out_arg_names = []
for arg in out_args: for arg in out_args:
out_arg_names.append(cpt.to_literal_str(arg.name)) out_arg_names.append(cpt.to_text(arg.name))
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
...@@ -986,8 +986,8 @@ class Block(object): ...@@ -986,8 +986,8 @@ class Block(object):
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
name = cpt.to_literal_str(name) name = cpt.to_text(name)
new_name = cpt.to_literal_str(new_name) new_name = cpt.to_text(new_name)
if not self.has_var(name): if not self.has_var(name):
raise ValueError("var %s is not in current block" % name) raise ValueError("var %s is not in current block" % name)
......
...@@ -155,13 +155,13 @@ class ParallelExecutor(object): ...@@ -155,13 +155,13 @@ class ParallelExecutor(object):
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
self._places, self._places,
set([ set([
cpt.to_literal_str(p.name) cpt.to_text(p.name)
for p in main.global_block().iter_parameters() for p in main.global_block().iter_parameters()
if not p.stop_gradient if not p.stop_gradient
]), ]),
set(cpt.to_literal_str(var) set(cpt.to_text(var)
for var in self.persistable_vars), main.desc, for var in self.persistable_vars), main.desc,
cpt.to_literal_str(loss_name) cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy, if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
self.scope = scope self.scope = scope
...@@ -275,7 +275,7 @@ class ParallelExecutor(object): ...@@ -275,7 +275,7 @@ class ParallelExecutor(object):
fetch_var_name = '@FETCHED_VAR_NAME@' fetch_var_name = '@FETCHED_VAR_NAME@'
self.executor.run( self.executor.run(
cpt.to_literal_str(fetch_list), cpt.to_literal_str(fetch_var_name)) cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
if self.is_dist: if self.is_dist:
......
...@@ -26,44 +26,44 @@ class TestCompatible(unittest.TestCase): ...@@ -26,44 +26,44 @@ class TestCompatible(unittest.TestCase):
self.assertEqual(cpt.int_type, int) self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int) self.assertEqual(cpt.long_type, int)
def test_to_literal_str(self): def test_to_text(self):
# Only support python2.x and python3.x now # Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3) self.assertTrue(six.PY2 | six.PY3)
if six.PY2: if six.PY2:
# check None # check None
self.assertIsNone(cpt.to_literal_str(None)) self.assertIsNone(cpt.to_text(None))
# check all string related types # check all string related types
self.assertTrue(isinstance(cpt.to_literal_str(str("")), unicode)) self.assertTrue(isinstance(cpt.to_text(str("")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), unicode)) self.assertTrue(isinstance(cpt.to_text(str("123")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertEqual(u"", cpt.to_literal_str(str(""))) self.assertEqual(u"", cpt.to_text(str("")))
self.assertEqual(u"123", cpt.to_literal_str(str("123"))) self.assertEqual(u"123", cpt.to_text(str("123")))
self.assertEqual(u"", cpt.to_literal_str(b"")) self.assertEqual(u"", cpt.to_text(b""))
self.assertEqual(u"123", cpt.to_literal_str(b"123")) self.assertEqual(u"123", cpt.to_text(b"123"))
self.assertEqual(u"", cpt.to_literal_str(u"")) self.assertEqual(u"", cpt.to_text(u""))
self.assertEqual(u"123", cpt.to_literal_str(u"123")) self.assertEqual(u"123", cpt.to_text(u"123"))
# check list types, not inplace # check list types, not inplace
l = [""] l = [""]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([u""], l2) self.assertEqual([u""], l2)
l = ["", "123"] l = ["", "123"]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2) self.assertEqual([u"", u"123"], l2)
l = ["", b'123', u"321"] l = ["", b'123', u"321"]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
...@@ -73,19 +73,19 @@ class TestCompatible(unittest.TestCase): ...@@ -73,19 +73,19 @@ class TestCompatible(unittest.TestCase):
# check list types, inplace # check list types, inplace
l = [""] l = [""]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([u""], l2) self.assertEqual([u""], l2)
l = ["", "123"] l = ["", "123"]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2) self.assertEqual([u"", u"123"], l2)
l = ["", b"123", u"321"] l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
...@@ -93,19 +93,19 @@ class TestCompatible(unittest.TestCase): ...@@ -93,19 +93,19 @@ class TestCompatible(unittest.TestCase):
# check set types, not inplace # check set types, not inplace
l = set("") l = set("")
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set(u""), l2) self.assertEqual(set(u""), l2)
l = set([b"", b"123"]) l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2) self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"]) l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
...@@ -115,56 +115,56 @@ class TestCompatible(unittest.TestCase): ...@@ -115,56 +115,56 @@ class TestCompatible(unittest.TestCase):
# check set types, inplace # check set types, inplace
l = set("") l = set("")
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set(u""), l2) self.assertEqual(set(u""), l2)
l = set([b"", b"123"]) l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2) self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"]) l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2) self.assertEqual(set([u"", u"123", u"321"]), l2)
elif six.PY3: elif six.PY3:
self.assertIsNone(cpt.to_literal_str(None)) self.assertIsNone(cpt.to_text(None))
self.assertTrue(isinstance(cpt.to_literal_str(str("")), str)) self.assertTrue(isinstance(cpt.to_text(str("")), str))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), str)) self.assertTrue(isinstance(cpt.to_text(str("123")), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertEqual("", cpt.to_literal_str(str(""))) self.assertEqual("", cpt.to_text(str("")))
self.assertEqual("123", cpt.to_literal_str(str("123"))) self.assertEqual("123", cpt.to_text(str("123")))
self.assertEqual("", cpt.to_literal_str(b"")) self.assertEqual("", cpt.to_text(b""))
self.assertEqual("123", cpt.to_literal_str(b"123")) self.assertEqual("123", cpt.to_text(b"123"))
self.assertEqual("", cpt.to_literal_str(u"")) self.assertEqual("", cpt.to_text(u""))
self.assertEqual("123", cpt.to_literal_str(u"123")) self.assertEqual("123", cpt.to_text(u"123"))
# check list types, not inplace # check list types, not inplace
l = [""] l = [""]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([""], l2) self.assertEqual([""], l2)
l = ["", "123"] l = ["", "123"]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2) self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"] l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertNotEqual(l, l2) self.assertNotEqual(l, l2)
...@@ -172,19 +172,19 @@ class TestCompatible(unittest.TestCase): ...@@ -172,19 +172,19 @@ class TestCompatible(unittest.TestCase):
# check list types, inplace # check list types, inplace
l = [""] l = [""]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([""], l2) self.assertEqual([""], l2)
l = ["", b"123"] l = ["", b"123"]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2) self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"] l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
...@@ -194,19 +194,19 @@ class TestCompatible(unittest.TestCase): ...@@ -194,19 +194,19 @@ class TestCompatible(unittest.TestCase):
# check set types, not inplace # check set types, not inplace
l = set("") l = set("")
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set(""), l2) self.assertEqual(set(""), l2)
l = set([b"", b"123"]) l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertNotEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123"]), l2) self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"]) l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertNotEqual(l, l2) self.assertNotEqual(l, l2)
...@@ -214,19 +214,19 @@ class TestCompatible(unittest.TestCase): ...@@ -214,19 +214,19 @@ class TestCompatible(unittest.TestCase):
# check set types, inplace # check set types, inplace
l = set("") l = set("")
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set(""), l2) self.assertEqual(set(""), l2)
l = set([b"", b"123"]) l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual(set(["", "123"]), l2) self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"]) l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2) self.assertTrue(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
......
...@@ -186,7 +186,7 @@ class TestDistBase(unittest.TestCase): ...@@ -186,7 +186,7 @@ class TestDistBase(unittest.TestCase):
env=env_local) env=env_local)
local_proc.wait() local_proc.wait()
out, err = local_proc.communicate() out, err = local_proc.communicate()
local_ret = cpt.to_literal_str(out) local_ret = cpt.to_text(out)
sys.stderr.write('local_loss: %s\n' % local_ret) sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err) sys.stderr.write('local_stderr: %s\n' % err)
...@@ -224,7 +224,7 @@ class TestDistBase(unittest.TestCase): ...@@ -224,7 +224,7 @@ class TestDistBase(unittest.TestCase):
tr1_proc.wait() tr1_proc.wait()
out, err = tr0_proc.communicate() out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err) sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = cpt.to_literal_str(out) loss_data0 = cpt.to_text(out)
sys.stderr.write('dist_loss: %s\n' % loss_data0) sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n") lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0] dist_first_loss = eval(lines[0].replace(" ", ","))[0]
......
...@@ -260,7 +260,7 @@ class ControlFlowGraph(object): ...@@ -260,7 +260,7 @@ class ControlFlowGraph(object):
# memory allocated in order to reuse the memory. # memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i) _rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var( self._program.block(block_desc.id).var(
cpt.to_literal_str(x)).desc = self._find_var( cpt.to_text(x)).desc = self._find_var(
block_desc, cache_var, is_forward) block_desc, cache_var, is_forward)
self._update_graph(x, cache_var, begin_idx=i) self._update_graph(x, cache_var, begin_idx=i)
break break
......
...@@ -391,9 +391,9 @@ class PipeReader: ...@@ -391,9 +391,9 @@ class PipeReader:
buff = self.process.stdout.read(self.bufsize) buff = self.process.stdout.read(self.bufsize)
if buff: if buff:
if self.file_type == "gzip": if self.file_type == "gzip":
decomp_buff = cpt.to_literal_str(self.dec.decompress(buff)) decomp_buff = cpt.to_text(self.dec.decompress(buff))
elif self.file_type == "plain": elif self.file_type == "plain":
decomp_buff = cpt.to_literal_str(buff) decomp_buff = cpt.to_text(buff)
else: else:
raise TypeError("file_type %s is not allowed" % raise TypeError("file_type %s is not allowed" %
self.file_type) self.file_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册