Unverified · Commit 9b0d8621 authored by W whs, committed by GitHub

Merge pull request #789 from wanghaoshuang/refine_ctc

Refine OCR CTC model.
@@ -187,25 +187,17 @@ def ctc_train_net(images, label, args, num_classes):
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(error_evaluator)
inference_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum)
_, params_grads = optimizer.minimize(sum_cost)
model_average = None
if args.model_average:
model_average = fluid.optimizer.ModelAverage(
params_grads,
args.average_window,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
model_average = fluid.optimizer.ModelAverage(
params_grads,
args.average_window,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
return sum_cost, error_evaluator, inference_program, model_average
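For context only, here is a minimal sketch of how the four values returned by ctc_train_net are typically consumed by a training script such as ctc_train.py later in this diff. The fluid import alias, the place selection, and the startup-program run are the usual Fluid boilerplate and are assumptions, not part of this commit.
# Sketch (not part of this commit): wiring up the outputs of ctc_train_net.
# `images`, `label`, `args` and `num_classes` are assumed to be defined as in
# ctc_train.py elsewhere in this repository.
import paddle.fluid as fluid
sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
    images, label, args, num_classes)
place = fluid.CUDAPlace(args.device) if args.device >= 0 else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# Training fetches sum_cost plus error_evaluator.metrics; evaluation reuses
# inference_program, which was cloned with for_test=True above.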
......
import os
import cv2
import tarfile
import numpy as np
from PIL import Image
from os import path
from paddle.v2.image import load_image
import paddle.v2 as paddle
NUM_CLASSES = 10784
DATA_SHAPE = [1, 48, 512]
DATA_MD5 = "1de60d54d19632022144e4e58c2637b5"
DATA_URL = "http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"
class DataGenerator(object):
def __init__(self):
@@ -102,25 +113,42 @@ class DataGenerator(object):
def num_classes():
'''Get the number of classes of this dataset.
'''
return NUM_CLASSES
def data_shape():
'''Get the image shape of this dataset. It is a dummy shape for this dataset.
'''
return DATA_SHAPE
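As an illustration of how these two helpers are meant to be used, the sketch below declares the network inputs from the dataset metadata on the Fluid side. The layer names 'pixel' and 'label' are assumptions; the actual model definition is not shown in this hunk.
# Sketch (assumption): declaring the network inputs from the dataset metadata.
import paddle.fluid as fluid
images = fluid.layers.data(
    name='pixel', shape=data_shape(), dtype='float32')  # [1, 48, 512]
# CTC labels are variable-length sequences, hence lod_level=1.
label = fluid.layers.data(
    name='label', shape=[1], dtype='int32', lod_level=1)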
def train(batch_size):
generator = DataGenerator()
data_dir = download_data()
return generator.train_reader(
"/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train_images/",
"/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train.list",
batch_size)
path.join(data_dir, TRAIN_DATA_DIR_NAME),
path.join(data_dir, TRAIN_LIST_FILE_NAME), batch_size)
def test(batch_size=1):
generator = DataGenerator()
data_dir = download_data()
return paddle.batch(
generator.test_reader(
"/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test_images/",
"/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test.list"
), batch_size)
path.join(data_dir, TEST_DATA_DIR_NAME),
path.join(data_dir, TEST_LIST_FILE_NAME)), batch_size)
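A hedged sketch of how these two entry points are consumed follows; the exact sample layout (image array plus label id list) is inferred from the surrounding code rather than stated in this diff.
# Sketch (assumption): iterating the readers returned by train()/test().
train_batches = train(batch_size=32)
test_batches = test(batch_size=1)
for batch in train_batches():
    # Each batch is assumed to be a list of (image ndarray, label id list)
    # samples, which ctc_train.py turns into feed data via get_feeder_data().
    first_image, first_label = batch[0]
    break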
def download_data():
'''Download train and test data.
'''
tar_file = paddle.dataset.common.download(
DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
data_dir = path.join(path.dirname(tar_file), DATA_DIR_NAME)
if not path.isdir(data_dir):
t = tarfile.open(tar_file, "r:gz")
t.extractall(path=path.dirname(tar_file))
t.close()
return data_dir
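For orientation, a rough sketch of the resulting on-disk layout, assuming the default paddle.dataset.common cache directory; the exact location is an assumption and is not specified in this diff.
# Sketch (assumption about paddle.dataset.common defaults): after download_data()
# the cache is expected to look roughly like
#   ~/.cache/paddle/dataset/ctc_data/data.tar.gz
#   ~/.cache/paddle/dataset/ctc_data/data/{train_images,test_images,train.list,test.list}
data_dir = download_data()  # idempotent: skips download/extraction if already present
print(path.join(data_dir, TRAIN_LIST_FILE_NAME))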
@@ -8,6 +8,7 @@ import functools
import sys
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
import time
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
@@ -23,11 +24,10 @@ add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
add_arg('device', int, 0, "Device id. '-1' means running on CPU "
"while '0' means GPU-0.")
add_arg('model_average', bool, True, "Whether to average the model for evaluation.")
add_arg('min_average_window', int, 10000, "Min average window.")
add_arg('max_average_window', int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, True, "Whether to use parallel training.")
add_arg('parallel', bool, False, "Whether to use parallel training.")
# yapf: disable
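The add_arg calls above rely on add_arguments from utility.py, which is not part of this diff; a minimal version consistent with these calls might look like the following. This is an assumption for illustration, not the repository's actual helper.
# Sketch (assumption): a minimal add_arguments compatible with the calls above.
import distutils.util
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register one --argname flag on the given ArgumentParser."""
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)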
def load_parameter(place):
@@ -70,11 +70,12 @@ def train(args, data_reader=dummy_reader):
fetch_list=[sum_cost] + error_evaluator.metrics)
total_loss += batch_loss[0]
total_seq_error += batch_seq_error[0]
if batch_id % 10 == 1:
if batch_id % 100 == 1:
print '.',
sys.stdout.flush()
if batch_id % args.log_period == 1:
print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
time.time(),
pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size))
sys.stdout.flush()
batch_id += 1
@@ -84,8 +85,6 @@ def train(args, data_reader=dummy_reader):
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
if model_average != None:
model_average.restore(exe)
print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))
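To make the restore() call above easier to follow: parameter averaging is a paired operation, with apply() switching to the averaged weights before evaluation and restore() switching back afterwards. Only restore() is visible in this hunk; the sketch below assumes the plain apply()/restore() methods of the Fluid version this commit targets and reuses the names defined in ctc_train.py.
# Sketch (assumption): the usual apply()/restore() pairing around evaluation.
if model_average is not None:
    model_average.apply(exe)      # evaluate with the averaged parameters
for data in test_reader():
    exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
if model_average is not None:
    model_average.restore(exe)    # switch back to the raw training parameters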
......