提交 2ad718f7 编写于 作者: Q qingqing01

Clean code

上级 ee2054da
...@@ -29,40 +29,6 @@ TRAIN_LIST_FILE_NAME = "train.list" ...@@ -29,40 +29,6 @@ TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list" TEST_LIST_FILE_NAME = "test.list"
class BatchCompose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform batch transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
# sample list to batch data
batch = list(zip(*data))
return batch
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
class Resize(object): class Resize(object):
def __init__(self, height=48): def __init__(self, height=48):
self.interp = Image.NEAREST # Image.ANTIALIAS self.interp = Image.NEAREST # Image.ANTIALIAS
......
...@@ -20,6 +20,7 @@ import paddle.fluid.profiler as profiler ...@@ -20,6 +20,7 @@ import paddle.fluid.profiler as profiler
import paddle.fluid as fluid import paddle.fluid as fluid
from hapi.model import Input, set_device from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy
...@@ -73,7 +74,7 @@ def main(FLAGS): ...@@ -73,7 +74,7 @@ def main(FLAGS):
model.load(FLAGS.init_model) model.load(FLAGS.init_model)
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = data.BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.MyBatchSampler(
test_dataset, test_dataset,
...@@ -122,7 +123,7 @@ def beam_search(FLAGS): ...@@ -122,7 +123,7 @@ def beam_search(FLAGS):
model.load(FLAGS.init_model) model.load(FLAGS.init_model)
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = data.BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.MyBatchSampler(
test_dataset, test_dataset,
......
...@@ -27,6 +27,7 @@ import paddle.fluid as fluid ...@@ -27,6 +27,7 @@ import paddle.fluid as fluid
from hapi.model import Input, set_device from hapi.model import Input, set_device
from hapi.datasets.folder import ImageFolder from hapi.datasets.folder import ImageFolder
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import postprocess, index2word from utility import postprocess, index2word
...@@ -67,7 +68,7 @@ def main(FLAGS): ...@@ -67,7 +68,7 @@ def main(FLAGS):
fn = lambda p: Image.open(p).convert('L') fn = lambda p: Image.open(p).convert('L')
test_dataset = ImageFolder(FLAGS.image_path, loader=fn) test_dataset = ImageFolder(FLAGS.image_path, loader=fn)
test_collate_fn = data.BatchCompose([data.Resize(), data.Normalize()]) test_collate_fn = BatchCompose([data.Resize(), data.Normalize()])
test_loader = fluid.io.DataLoader( test_loader = fluid.io.DataLoader(
test_dataset, test_dataset,
places=device, places=device,
......
...@@ -25,6 +25,7 @@ import paddle.fluid.profiler as profiler ...@@ -25,6 +25,7 @@ import paddle.fluid.profiler as profiler
import paddle.fluid as fluid import paddle.fluid as fluid
from hapi.model import Input, set_device from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger from utility import SeqAccuracy, MyProgBarLogger
...@@ -97,7 +98,7 @@ def main(FLAGS): ...@@ -97,7 +98,7 @@ def main(FLAGS):
labels=labels) labels=labels)
train_dataset = data.train() train_dataset = data.train()
train_collate_fn = data.BatchCompose( train_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
train_sampler = data.MyBatchSampler( train_sampler = data.MyBatchSampler(
train_dataset, batch_size=FLAGS.batch_size, shuffle=True) train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
...@@ -109,7 +110,7 @@ def main(FLAGS): ...@@ -109,7 +110,7 @@ def main(FLAGS):
return_list=True, return_list=True,
collate_fn=train_collate_fn) collate_fn=train_collate_fn)
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = data.BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.MyBatchSampler(
test_dataset, test_dataset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册