提交 a3539845 编写于 作者: M minqiyang

Polish python code style

上级 e7c7cbaa
...@@ -47,21 +47,25 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' ...@@ -47,21 +47,25 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False): def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch): def read_batch(batch):
data = batch[six.b('data')] data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in six.moves.zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = [each_item.name for each_item in f if sub_name in each_item.name] names = [
each_item.name for each_item in f if sub_name in each_item.name
]
while True: while True:
for name in names: for name in names:
if six.PY2: if six.PY2:
batch = pickle.load(f.extractfile(name)) batch = pickle.load(f.extractfile(name))
else: else:
batch = pickle.load(f.extractfile(name), encoding='bytes') batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
if not cycle: if not cycle:
......
...@@ -215,7 +215,8 @@ def max_job_id(): ...@@ -215,7 +215,8 @@ def max_job_id():
Get the maximum value of job id. Get the maximum value of job id.
""" """
__initialize_meta_info__() __initialize_meta_info__()
return six.moves.reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id return six.moves.reduce(__max_job_id_impl__,
list(USER_INFO.values())).job_id
def movie_categories(): def movie_categories():
......
...@@ -23,6 +23,7 @@ __all__ = [ ...@@ -23,6 +23,7 @@ __all__ = [
'get_exception_message', 'get_exception_message',
] ]
# str and bytes related functions # str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False): def to_literal_str(obj, encoding='utf-8', inplace=False):
""" """
...@@ -181,10 +182,10 @@ def round(x, d=0): ...@@ -181,10 +182,10 @@ def round(x, d=0):
# The official walkaround of round in Python3 is incorrect # The official walkaround of round in Python3 is incorrect
# we implement accroding this answer: https://www.techforgeek.info/round_python.html # we implement accroding this answer: https://www.techforgeek.info/round_python.html
if x > 0.0: if x > 0.0:
p = 10 ** d p = 10**d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p return float(math.floor((x * p) + math.copysign(0.5, x))) / p
elif x < 0.0: elif x < 0.0:
p = 10 ** d p = 10**d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else: else:
return math.copysign(0.0, x) return math.copysign(0.0, x)
...@@ -208,6 +209,7 @@ def floor_division(x, y): ...@@ -208,6 +209,7 @@ def floor_division(x, y):
""" """
return x // y return x // y
# exception related functions # exception related functions
def get_exception_message(exc): def get_exception_message(exc):
""" """
...@@ -225,4 +227,3 @@ def get_exception_message(exc): ...@@ -225,4 +227,3 @@ def get_exception_message(exc):
return exc.message return exc.message
else: else:
return str(exc) return str(exc)
...@@ -320,8 +320,9 @@ class Executor(object): ...@@ -320,8 +320,9 @@ class Executor(object):
# append fetch_operators # append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name): if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, six.text_type), ( assert isinstance(var, Variable) or isinstance(
"Wrong type for fetch_list[%s]: %s" % (i, type(var))) var, six.text_type), ("Wrong type for fetch_list[%s]: %s" %
(i, type(var)))
global_block.append_op( global_block.append_op(
type='fetch', type='fetch',
inputs={'X': [var]}, inputs={'X': [var]},
......
...@@ -1104,9 +1104,8 @@ def multi_box_head(inputs, ...@@ -1104,9 +1104,8 @@ def multi_box_head(inputs,
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
new_shape = [ new_shape = [
mbox_loc.shape[0], mbox_loc.shape[0], mbox_loc.shape[1] * mbox_loc.shape[2] *
mbox_loc.shape[1] * mbox_loc.shape[2] * cpt.floor_division(mbox_loc.shape[3], 4), cpt.floor_division(mbox_loc.shape[3], 4), 4
4
] ]
mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten) mbox_locs.append(mbox_loc_flatten)
...@@ -1121,9 +1120,8 @@ def multi_box_head(inputs, ...@@ -1121,9 +1120,8 @@ def multi_box_head(inputs,
stride=stride) stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
new_shape = [ new_shape = [
conf_loc.shape[0], conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[1] * conf_loc.shape[2] * cpt.floor_division(conf_loc.shape[3], num_classes), cpt.floor_division(conf_loc.shape[3], num_classes), num_classes
num_classes
] ]
conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten) mbox_confs.append(conf_loc_flatten)
......
...@@ -45,15 +45,17 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' ...@@ -45,15 +45,17 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
def reader_creator(filename, sub_name, batch_size=None): def reader_creator(filename, sub_name, batch_size=None):
def read_batch(batch): def read_batch(batch):
data = batch[six.b('data')] data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in six.moves.zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = [each_item.name for each_item in f names = [
if sub_name in each_item.name] each_item.name for each_item in f if sub_name in each_item.name
]
batch_count = 0 batch_count = 0
for name in names: for name in names:
......
...@@ -63,7 +63,6 @@ class TestCompatible(unittest.TestCase): ...@@ -63,7 +63,6 @@ class TestCompatible(unittest.TestCase):
for i in l2: for i in l2:
self.assertTrue(isinstance(i, unicode)) self.assertTrue(isinstance(i, unicode))
# check list types, inplace # check list types, inplace
l = [""] l = [""]
l2 = cpt.to_literal_str(l, inplace=True) l2 = cpt.to_literal_str(l, inplace=True)
...@@ -272,7 +271,6 @@ class TestCompatible(unittest.TestCase): ...@@ -272,7 +271,6 @@ class TestCompatible(unittest.TestCase):
for i in l2: for i in l2:
self.assertTrue(isinstance(i, bytes)) self.assertTrue(isinstance(i, bytes))
# check list types, inplace # check list types, inplace
l = [""] l = [""]
l2 = cpt.to_bytes(l, inplace=True) l2 = cpt.to_bytes(l, inplace=True)
...@@ -461,30 +459,35 @@ class TestCompatible(unittest.TestCase): ...@@ -461,30 +459,35 @@ class TestCompatible(unittest.TestCase):
exception_message = "test_message" exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None) self.assertRaises(AssertionError, cpt.get_exception_message, None)
if six.PY2: if six.PY2:
self.assertRaises(AttributeError, cpt.get_exception_message, exception_message) self.assertRaises(AttributeError, cpt.get_exception_message,
exception_message)
try: try:
raise RuntimeError(exception_message) raise RuntimeError(exception_message)
except Exception as e: except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e)) self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e) self.assertIsNotNone(e)
try: try:
raise Exception(exception_message) raise Exception(exception_message)
except Exception as e: except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e)) self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e) self.assertIsNotNone(e)
if six.PY3: if six.PY3:
try: try:
raise RuntimeError(exception_message) raise RuntimeError(exception_message)
except Exception as e: except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e)) self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e) self.assertIsNotNone(e)
try: try:
raise Exception(exception_message) raise Exception(exception_message)
except Exception as e: except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e)) self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e) self.assertIsNotNone(e)
......
...@@ -39,7 +39,8 @@ class TestGRUOp(OpTest): ...@@ -39,7 +39,8 @@ class TestGRUOp(OpTest):
for i in range(len(seq_lens)): for i in range(len(seq_lens)):
seq_starts.append(seq_starts[-1] + seq_lens[i]) seq_starts.append(seq_starts[-1] + seq_lens[i])
sorted_seqs = sorted( sorted_seqs = sorted(
list(range(len(seq_lens))), key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x])) list(range(len(seq_lens))),
key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
num_batch = seq_lens[sorted_seqs[0]] num_batch = seq_lens[sorted_seqs[0]]
for batch_idx in range(num_batch): for batch_idx in range(num_batch):
idx_in_seq = [] idx_in_seq = []
......
...@@ -36,7 +36,8 @@ class TestOperator(unittest.TestCase): ...@@ -36,7 +36,8 @@ class TestOperator(unittest.TestCase):
block.append_op(type="no_such_op") block.append_op(type="no_such_op")
self.assertFail() self.assertFail()
except ValueError as a_err: except ValueError as a_err:
self.assertEqual(cpt.get_exception_message(a_err), self.assertEqual(
cpt.get_exception_message(a_err),
"Operator \"no_such_op\" has not been registered.") "Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self): def test_op_desc_creation(self):
......
...@@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x, ...@@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x,
if global_pool == 1: if global_pool == 1:
ksize = [D, H, W] ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * ) // strides[0] + 1 if ceil_mode else (
paddings[0]) // strides[0] + 1 H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * ) // strides[1] + 1 if ceil_mode else (
paddings[1]) // strides[1] + 1 W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * ) // strides[2] + 1 if ceil_mode else (
paddings[2]) // strides[2] + 1 W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out)) out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out): for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0)) d_start = np.max((k * strides[0] - paddings[0], 0))
...@@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x, ...@@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x,
if global_pool == 1: if global_pool == 1:
ksize = [D, H, W] ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * ) // strides[0] + 1 if ceil_mode else (
paddings[0]) // strides[0] + 1 H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * ) // strides[1] + 1 if ceil_mode else (
paddings[1]) // strides[1] + 1 W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * ) // strides[2] + 1 if ceil_mode else (
paddings[2]) // strides[2] + 1 W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out)) out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out): for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0)) d_start = np.max((k * strides[0] - paddings[0], 0))
......
...@@ -259,9 +259,9 @@ class ControlFlowGraph(object): ...@@ -259,9 +259,9 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with # Rename the var to the cache var already with
# memory allocated in order to reuse the memory. # memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i) _rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(cpt.to_literal_str( self._program.block(block_desc.id).var(
x)).desc = self._find_var(block_desc, cache_var, cpt.to_literal_str(x)).desc = self._find_var(
is_forward) block_desc, cache_var, is_forward)
self._update_graph(x, cache_var, begin_idx=i) self._update_graph(x, cache_var, begin_idx=i)
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册