未验证 提交 9b0d8621 编写于 作者: W whs 提交者: GitHub

Merge pull request #789 from wanghaoshuang/refine_ctc

Refine OCR CTC model.
...@@ -187,25 +187,17 @@ def ctc_train_net(images, label, args, num_classes): ...@@ -187,25 +187,17 @@ def ctc_train_net(images, label, args, num_classes):
error_evaluator = fluid.evaluator.EditDistance( error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label) input=decoded_out, label=casted_label)
inference_program = fluid.default_main_program().clone() inference_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(error_evaluator)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum) learning_rate=args.learning_rate, momentum=args.momentum)
_, params_grads = optimizer.minimize(sum_cost) _, params_grads = optimizer.minimize(sum_cost)
model_average = None model_average = fluid.optimizer.ModelAverage(
if args.model_average: params_grads,
model_average = fluid.optimizer.ModelAverage( args.average_window,
params_grads, min_average_window=args.min_average_window,
args.average_window, max_average_window=args.max_average_window)
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
return sum_cost, error_evaluator, inference_program, model_average return sum_cost, error_evaluator, inference_program, model_average
......
import os import os
import cv2 import cv2
import tarfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from os import path
from paddle.v2.image import load_image from paddle.v2.image import load_image
import paddle.v2 as paddle import paddle.v2 as paddle
NUM_CLASSES = 10784 NUM_CLASSES = 10784
DATA_SHAPE = [1, 48, 512] DATA_SHAPE = [1, 48, 512]
DATA_MD5 = "1de60d54d19632022144e4e58c2637b5"
DATA_URL = "http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"
class DataGenerator(object): class DataGenerator(object):
def __init__(self): def __init__(self):
...@@ -102,25 +113,42 @@ class DataGenerator(object): ...@@ -102,25 +113,42 @@ class DataGenerator(object):
def num_classes():
    '''Return the number of classes in this dataset (NUM_CLASSES).
    '''
    return NUM_CLASSES
def data_shape():
    '''Return the image shape of this dataset (DATA_SHAPE).

    The value is only a dummy shape for this dataset.
    '''
    return DATA_SHAPE
def train(batch_size):
    '''Create the training reader.

    The dataset is located through download_data(); the image directory and
    the list file of the training split are then handed to
    DataGenerator.train_reader together with batch_size.
    '''
    generator = DataGenerator()
    dataset_root = download_data()
    image_dir = path.join(dataset_root, TRAIN_DATA_DIR_NAME)
    list_file = path.join(dataset_root, TRAIN_LIST_FILE_NAME)
    return generator.train_reader(image_dir, list_file, batch_size)
def test(batch_size=1):
    '''Create the batched test reader.

    The dataset is located through download_data(); the image directory and
    the list file of the test split are handed to DataGenerator.test_reader,
    and the resulting reader is wrapped with paddle.batch.

    Args:
        batch_size: number of samples per batch, defaults to 1.
    '''
    generator = DataGenerator()
    data_dir = download_data()
    # Use the TEST_* names: pointing at the TRAIN_* directory and list would
    # silently evaluate the model on the training split.
    return paddle.batch(
        generator.test_reader(
            path.join(data_dir, TEST_DATA_DIR_NAME),
            path.join(data_dir, TEST_LIST_FILE_NAME)), batch_size)
def download_data():
    '''Download the train and test data and unpack it on first use.

    The archive is fetched with paddle.dataset.common.download using
    DATA_URL, CACHE_DIR_NAME and DATA_MD5, and is extracted into the directory
    that holds the downloaded file. Extraction is skipped when the data
    directory already exists.

    Returns:
        The path of DATA_DIR_NAME inside the directory of the downloaded
        archive.
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The context manager closes the archive even when extraction raises,
        # so the file handle is not leaked on a corrupt or partial download.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir
...@@ -8,6 +8,7 @@ import functools ...@@ -8,6 +8,7 @@ import functools
import sys import sys
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net from crnn_ctc_model import ctc_train_net
import time
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
...@@ -23,11 +24,10 @@ add_arg('momentum', float, 0.9, "Momentum.") ...@@ -23,11 +24,10 @@ add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.") add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
add_arg('device', int, 0, "Device id.'-1' means running on CPU" add_arg('device', int, 0, "Device id.'-1' means running on CPU"
"while '0' means GPU-0.") "while '0' means GPU-0.")
add_arg('model_average', bool, True, "Whether to aevrage model for evaluation.")
add_arg('min_average_window', int, 10000, "Min average window.") add_arg('min_average_window', int, 10000, "Min average window.")
add_arg('max_average_window', int, 15625, "Max average window.") add_arg('max_average_window', int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.") add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, True, "Whether use parallel training.") add_arg('parallel', bool, False, "Whether use parallel training.")
# yapf: disable # yapf: disable
def load_parameter(place): def load_parameter(place):
...@@ -70,11 +70,12 @@ def train(args, data_reader=dummy_reader): ...@@ -70,11 +70,12 @@ def train(args, data_reader=dummy_reader):
fetch_list=[sum_cost] + error_evaluator.metrics) fetch_list=[sum_cost] + error_evaluator.metrics)
total_loss += batch_loss[0] total_loss += batch_loss[0]
total_seq_error += batch_seq_error[0] total_seq_error += batch_seq_error[0]
if batch_id % 10 == 1: if batch_id % 100 == 1:
print '.', print '.',
sys.stdout.flush() sys.stdout.flush()
if batch_id % args.log_period == 1: if batch_id % args.log_period == 1:
print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % ( print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
time.time(),
pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size)) pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size))
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
...@@ -84,8 +85,6 @@ def train(args, data_reader=dummy_reader): ...@@ -84,8 +85,6 @@ def train(args, data_reader=dummy_reader):
for data in test_reader(): for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place)) exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe) _, test_seq_error = error_evaluator.eval(exe)
if model_average != None:
model_average.restore(exe)
print "\nEnd pass[%d]; Test seq error: %s.\n" % ( print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0])) pass_id, str(test_seq_error[0]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册