提交 a3539845 编写于 作者: M minqiyang

Polish python code style

上级 e7c7cbaa
......@@ -47,21 +47,25 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch):
data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = [each_item.name for each_item in f if sub_name in each_item.name]
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
while True:
for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
if not cycle:
......
......@@ -215,7 +215,8 @@ def max_job_id():
Get the maximum value of job id.
"""
__initialize_meta_info__()
return six.moves.reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id
return six.moves.reduce(__max_job_id_impl__,
list(USER_INFO.values())).job_id
def movie_categories():
......
......@@ -23,6 +23,7 @@ __all__ = [
'get_exception_message',
]
# str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False):
"""
......@@ -181,10 +182,10 @@ def round(x, d=0):
# The official walkaround of round in Python3 is incorrect
# we implement accroding this answer: https://www.techforgeek.info/round_python.html
if x > 0.0:
p = 10 ** d
p = 10**d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p
elif x < 0.0:
p = 10 ** d
p = 10**d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else:
return math.copysign(0.0, x)
......@@ -208,6 +209,7 @@ def floor_division(x, y):
"""
return x // y
# exception related functions
def get_exception_message(exc):
"""
......@@ -225,4 +227,3 @@ def get_exception_message(exc):
return exc.message
else:
return str(exc)
......@@ -320,8 +320,9 @@ class Executor(object):
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, six.text_type), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
assert isinstance(var, Variable) or isinstance(
var, six.text_type), ("Wrong type for fetch_list[%s]: %s" %
(i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
......
......@@ -1104,9 +1104,8 @@ def multi_box_head(inputs,
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
new_shape = [
mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * cpt.floor_division(mbox_loc.shape[3], 4),
4
mbox_loc.shape[0], mbox_loc.shape[1] * mbox_loc.shape[2] *
cpt.floor_division(mbox_loc.shape[3], 4), 4
]
mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten)
......@@ -1121,9 +1120,8 @@ def multi_box_head(inputs,
stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
new_shape = [
conf_loc.shape[0],
conf_loc.shape[1] * conf_loc.shape[2] * cpt.floor_division(conf_loc.shape[3], num_classes),
num_classes
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
cpt.floor_division(conf_loc.shape[3], num_classes), num_classes
]
conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten)
......
......@@ -45,15 +45,17 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
def reader_creator(filename, sub_name, batch_size=None):
def read_batch(batch):
data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = [each_item.name for each_item in f
if sub_name in each_item.name]
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
batch_count = 0
for name in names:
......
......@@ -63,7 +63,6 @@ class TestCompatible(unittest.TestCase):
for i in l2:
self.assertTrue(isinstance(i, unicode))
# check list types, inplace
l = [""]
l2 = cpt.to_literal_str(l, inplace=True)
......@@ -272,7 +271,6 @@ class TestCompatible(unittest.TestCase):
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check list types, inplace
l = [""]
l2 = cpt.to_bytes(l, inplace=True)
......@@ -461,30 +459,35 @@ class TestCompatible(unittest.TestCase):
exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None)
if six.PY2:
self.assertRaises(AttributeError, cpt.get_exception_message, exception_message)
self.assertRaises(AttributeError, cpt.get_exception_message,
exception_message)
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
if six.PY3:
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
......
......@@ -39,7 +39,8 @@ class TestGRUOp(OpTest):
for i in range(len(seq_lens)):
seq_starts.append(seq_starts[-1] + seq_lens[i])
sorted_seqs = sorted(
list(range(len(seq_lens))), key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
list(range(len(seq_lens))),
key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
num_batch = seq_lens[sorted_seqs[0]]
for batch_idx in range(num_batch):
idx_in_seq = []
......
......@@ -36,7 +36,8 @@ class TestOperator(unittest.TestCase):
block.append_op(type="no_such_op")
self.assertFail()
except ValueError as a_err:
self.assertEqual(cpt.get_exception_message(a_err),
self.assertEqual(
cpt.get_exception_message(a_err),
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self):
......
......@@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x,
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) // strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) // strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) // strides[2] + 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
......@@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x,
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) // strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) // strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) // strides[2] + 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
......
......@@ -259,9 +259,9 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(cpt.to_literal_str(
x)).desc = self._find_var(block_desc, cache_var,
is_forward)
self._program.block(block_desc.id).var(
cpt.to_literal_str(x)).desc = self._find_var(
block_desc, cache_var, is_forward)
self._update_graph(x, cache_var, begin_idx=i)
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册