diff --git a/examples/ocr/data.py b/examples/ocr/data.py index 00c4b1ac5a1116443ce5a8ecaeeb35e66782e521..cb0e13f608af5abb58ae2f5244e4d089ce663f33 100644 --- a/examples/ocr/data.py +++ b/examples/ocr/data.py @@ -29,40 +29,6 @@ TRAIN_LIST_FILE_NAME = "train.list" TEST_LIST_FILE_NAME = "test.list" -class BatchCompose(object): - def __init__(self, transforms=[]): - self.transforms = transforms - - def __call__(self, data): - for f in self.transforms: - try: - data = f(data) - except Exception as e: - stack_info = traceback.format_exc() - logger.info("fail to perform batch transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e - # sample list to batch data - batch = list(zip(*data)) - return batch - - -class Compose(object): - def __init__(self, transforms=[]): - self.transforms = transforms - - def __call__(self, *data): - for f in self.transforms: - try: - data = f(*data) - except Exception as e: - stack_info = traceback.format_exc() - logger.info("fail to perform transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e - return data - - class Resize(object): def __init__(self, height=48): self.interp = Image.NEAREST # Image.ANTIALIAS diff --git a/examples/ocr/eval.py b/examples/ocr/eval.py index e3c487da64baf41c68904263e840701ac91c8545..2fd22751694f9def9988bb24d9292ad82a442337 100644 --- a/examples/ocr/eval.py +++ b/examples/ocr/eval.py @@ -20,6 +20,7 @@ import paddle.fluid.profiler as profiler import paddle.fluid as fluid from hapi.model import Input, set_device +from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy @@ -73,7 +74,7 @@ def main(FLAGS): model.load(FLAGS.init_model) test_dataset = data.test() - test_collate_fn = data.BatchCompose( + test_collate_fn = BatchCompose( [data.Resize(), data.Normalize(), data.PadTarget()]) test_sampler = data.MyBatchSampler( test_dataset, @@ -122,7 +123,7 @@ def beam_search(FLAGS): model.load(FLAGS.init_model) test_dataset = data.test() - test_collate_fn = data.BatchCompose( + test_collate_fn = BatchCompose( [data.Resize(), data.Normalize(), data.PadTarget()]) test_sampler = data.MyBatchSampler( test_dataset, diff --git a/examples/ocr/predict.py b/examples/ocr/predict.py index d1a66f7b053fcd83bb548eddf7bbcfcf5ca74c8e..df0f21e9ea16e1442159ed049d4867c066d4d4e4 100644 --- a/examples/ocr/predict.py +++ b/examples/ocr/predict.py @@ -27,6 +27,7 @@ import paddle.fluid as fluid from hapi.model import Input, set_device from hapi.datasets.folder import ImageFolder +from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments from utility import postprocess, index2word @@ -67,7 +68,7 @@ def main(FLAGS): fn = lambda p: Image.open(p).convert('L') test_dataset = ImageFolder(FLAGS.image_path, loader=fn) - test_collate_fn = data.BatchCompose([data.Resize(), data.Normalize()]) + test_collate_fn = BatchCompose([data.Resize(), data.Normalize()]) test_loader = fluid.io.DataLoader( test_dataset, places=device, diff --git a/examples/ocr/train.py b/examples/ocr/train.py index 789edcafe5859bf031069d3e23697cd1e56b4399..423e27d6d2171e125d3108257634997393caa0fc 100644 --- a/examples/ocr/train.py +++ b/examples/ocr/train.py @@ -25,6 +25,7 @@ import paddle.fluid.profiler as profiler import paddle.fluid as fluid from hapi.model import Input, set_device +from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments from utility import SeqAccuracy, MyProgBarLogger @@ -97,7 +98,7 @@ def main(FLAGS): labels=labels) train_dataset = data.train() - train_collate_fn = data.BatchCompose( + train_collate_fn = BatchCompose( [data.Resize(), data.Normalize(), data.PadTarget()]) train_sampler = data.MyBatchSampler( train_dataset, batch_size=FLAGS.batch_size, shuffle=True) @@ -109,7 +110,7 @@ def main(FLAGS): return_list=True, collate_fn=train_collate_fn) test_dataset = data.test() - test_collate_fn = data.BatchCompose( + test_collate_fn = BatchCompose( [data.Resize(), data.Normalize(), data.PadTarget()]) test_sampler = data.MyBatchSampler( test_dataset,