Commit c6e23183 authored by peterzhang2029

add infer.py for ctc

Parent 8b5c739a

# data_provider.py
from __future__ import absolute_import
from __future__ import division

import os

import cv2
from paddle.v2.image import load_image

class AsciiDic(object):
    UNK = 0

    def __init__(self):
        self.dic = {
            '<unk>': self.UNK,
        }
        self.chars = [chr(i) for i in range(40, 171)]
        for id, c in enumerate(self.chars):
            self.dic[c] = id + 1

    def lookup(self, w):
        return self.dic.get(w, self.UNK)

    def id2word(self):
        '''
        Build the reverse mapping from ids to characters.
        '''
        id2word = {}
        for key, value in self.dic.items():
            id2word[value] = key
        return id2word

    def word2ids(self, sent):
        '''
        Transform a word into a list of ids.
        @sent: str
        '''
        return [self.lookup(c) for c in list(sent)]

    def size(self):
        return len(self.dic)
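
# A minimal usage sketch (this helper is hypothetical, added for illustration):
# round-trip a word through word2ids/id2word and show the UNK fallback.
def _demo_ascii_dic():
    dic = AsciiDic()
    ids = dic.word2ids('PADDLE')
    rev = dic.id2word()
    assert ''.join(rev[i] for i in ids) == 'PADDLE'
    # space (chr 32) is outside range(40, 171), so it maps to UNK
    assert dic.lookup(' ') == AsciiDic.UNK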

class ImageDataset(object):
    def __init__(self,
                 train_image_paths_generator,
                 test_image_paths_generator,
                 infer_image_paths_generator,
                 fixed_shape=None,
                 is_infer=False):
        '''
        @*_image_paths_generator: generator
            yields (image path, label) pairs, consumed like:
                for path, label in image_paths_generator:
                    load_image(path)
        @fixed_shape: tuple of 2 ints, optional
            resize every image to this shape
        @is_infer: bool
            when True, only the inference file list is materialized
        '''
        if not is_infer:
            self.train_filelist = [p for p in train_image_paths_generator]
            self.test_filelist = [p for p in test_image_paths_generator]
        else:
            self.infer_filelist = [p for p in infer_image_paths_generator]

        self.fixed_shape = fixed_shape
        self.ascii_dic = AsciiDic()

    def train(self):
        for i, (image, label) in enumerate(self.train_filelist):
            yield self.load_image(image), self.ascii_dic.word2ids(label)

    def test(self):
        for i, (image, label) in enumerate(self.test_filelist):
            yield self.load_image(image), self.ascii_dic.word2ids(label)

    def infer(self):
        for i, (image, label) in enumerate(self.infer_filelist):
            yield self.load_image(image), label

    def load_image(self, path):
        '''
        Load an image and transform it into a 1-dimensional vector.
        '''
        image = load_image(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # resize all images to a fixed shape; note cv2.resize expects
        # (width, height)
        if self.fixed_shape:
            image = cv2.resize(
                image, self.fixed_shape, interpolation=cv2.INTER_CUBIC)
        image = image.flatten() / 255.
        return image

def get_file_list(image_file_list):
    '''
    Yield (image path, label) pairs parsed from a list file that pairs each
    image file name with a quoted label.
    '''
    pwd = os.path.dirname(image_file_list)
    with open(image_file_list) as f:
        for line in f:
            fs = line.strip().split(',')
            fname = fs[0].strip()
            path = os.path.join(pwd, fname)
            # strip the leading space-plus-quote and the trailing quote
            yield path, fs[1][2:-1]
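
# Sketch of the list-file line format this parser assumes (inferred from the
# slicing above; the quoted-label shape matches the ICDAR 2013 GT files
# referenced in infer.py):
#
#     word_1.png, "Tiredness"
#     word_2.png, "kills"
#
# fs[1][2:-1] strips the leading space-plus-quote and the trailing quote, so
# each yielded pair looks like ('<dir>/word_1.png', 'Tiredness').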
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import groupby
import numpy as np

def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The path consisting of the most probable tokens is post-processed to
    remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: str
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
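
# A worked example of the collapse rule (hypothetical helper with made-up
# probabilities): the per-step argmaxes are [a, a, blank, b, b, b], which
# merge to [a, blank, b]; dropping the blank yields 'ab'.
def _demo_ctc_greedy_decoder():
    probs_seq = [
        [0.9, 0.05, 0.05],  # -> 'a'
        [0.8, 0.10, 0.10],  # -> 'a' (repeat, merged)
        [0.1, 0.10, 0.80],  # -> blank (dropped)
        [0.1, 0.80, 0.10],  # -> 'b'
        [0.1, 0.70, 0.20],  # -> 'b' (repeat, merged)
        [0.2, 0.60, 0.20],  # -> 'b' (repeat, merged)
    ]
    assert ctc_greedy_decoder(probs_seq, ['a', 'b']) == 'ab'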

# infer.py
import gzip

import paddle.v2 as paddle

from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset
from decoder import ctc_greedy_decoder

def infer(inferer, test_batch, labels):
    # the inference output is flattened over the batch: split it back into one
    # probability sequence per sample before decoding
    infer_results = inferer.infer(input=test_batch)
    num_steps = len(infer_results) // len(test_batch)
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in xrange(0, len(test_batch))
    ]
    results = []
    # best path decode
    for i, probs in enumerate(probs_split):
        output_transcription = ctc_greedy_decoder(
            probs_seq=probs, vocabulary=AsciiDic().id2word())
        results.append(output_transcription)
    for result, label in zip(results, labels):
        print("\nOutput Transcription: %s\nTarget Transcription: %s" %
              (result, label))
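
# Illustration of the reshaping above with made-up values (this helper is
# hypothetical and not part of the original file): inferer.infer returns
# batch_size * num_steps rows flattened over the batch, and the slicing
# recovers one sequence per sample.
def _demo_probs_split():
    flat = ['img%d_step%d' % (i, s) for i in range(2) for s in range(3)]
    num_steps = len(flat) // 2
    split = [flat[i * num_steps:(i + 1) * num_steps] for i in range(2)]
    assert split[0] == ['img0_step0', 'img0_step1', 'img0_step2']
    assert split[1] == ['img1_step0', 'img1_step1', 'img1_step2']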

if __name__ == "__main__":
    model_path = "model.ctc-pass-1-batch-150-test-10.2607016472.tar.gz"
    image_shape = "173,46"
    batch_size = 50
    infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'

    image_shape = tuple(map(int, image_shape.split(',')))
    infer_generator = get_file_list(infer_file_list)
    dataset = ImageDataset(None, None, infer_generator, image_shape, True)

    paddle.init(use_gpu=True, trainer_count=4)
    parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
    model = Model(AsciiDic().size(), image_shape, is_infer=True)
    inferer = paddle.inference.Inference(
        output_layer=model.log_probs, parameters=parameters)

    # collect images into mini-batches; decode each full batch, then the
    # final partial batch (if any) after the loop
    test_batch = []
    labels = []
    for i, (image, label) in enumerate(dataset.infer()):
        test_batch.append([image])
        labels.append(label)
        if len(test_batch) == batch_size:
            infer(inferer, test_batch, labels)
            test_batch = []
            labels = []
    if test_batch:
        infer(inferer, test_batch, labels)

# model.py
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru

def conv_groups(input_image, num, with_bn):
    '''
    A deep CNN: four img_conv_group blocks with 16/32/64/128 filters,
    each followed by 2x2 max pooling.
    @input_image: input image
    @num: total number of CONV layers, spread evenly over the four groups
    @with_bn: whether to use batch normalization
    '''
    assert num % 4 == 0

    tmp = img_conv_group(
        input=input_image,
        num_channels=1,
        conv_padding=1,
        conv_num_filter=[16] * (num / 4),
        conv_filter_size=3,
        conv_act=Relu(),
        conv_with_batchnorm=with_bn,
        pool_size=2,
        pool_stride=2, )

    tmp = img_conv_group(
        input=tmp,
        conv_padding=1,
        conv_num_filter=[32] * (num / 4),
        conv_filter_size=3,
        conv_act=Relu(),
        conv_with_batchnorm=with_bn,
        pool_size=2,
        pool_stride=2, )

    tmp = img_conv_group(
        input=tmp,
        conv_padding=1,
        conv_num_filter=[64] * (num / 4),
        conv_filter_size=3,
        conv_act=Relu(),
        conv_with_batchnorm=with_bn,
        pool_size=2,
        pool_stride=2, )

    tmp = img_conv_group(
        input=tmp,
        conv_padding=1,
        conv_num_filter=[128] * (num / 4),
        conv_filter_size=3,
        conv_act=Relu(),
        conv_with_batchnorm=with_bn,
        pool_size=2,
        pool_stride=2, )
    return tmp
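
# Note: with num=8 as Model uses below, each of the four groups stacks two 3x3
# conv layers, and the last group leaves 128 channels, matching the
# num_channels=128 that layer.block_expand expects in Model.__build_nn__.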

class Model(object):
    def __init__(self, num_classes, shape, is_infer=False):
        '''
        @num_classes: int
            size of the character dict
        @shape: tuple of 2 ints
            size of the input images
        '''
        self.num_classes = num_classes
        self.shape = shape
        self.is_infer = is_infer
        self.image_vector_size = shape[0] * shape[1]

        self.__declare_input_layers__()
        self.__build_nn__()

    def __declare_input_layers__(self):
        # image input as a float vector
        self.image = layer.data(
            name='image',
            type=paddle.data_type.dense_vector(self.image_vector_size),
            height=self.shape[0],
            width=self.shape[1])

        # label input as an ID list
        if not self.is_infer:
            self.label = layer.data(
                name='label',
                type=paddle.data_type.integer_value_sequence(self.num_classes))

    def __build_nn__(self):
        # CNN output image features: 128 float feature maps
        conv_features = conv_groups(self.image, 8, True)

        # cut the CNN output into a sequence of feature vectors, each
        # 1 pixel wide and 11 pixels high
        sliced_feature = layer.block_expand(
            input=conv_features,
            num_channels=128,
            stride_x=1,
            stride_y=1,
            block_x=1,
            block_y=11)

        # RNNs to capture sequence information forwards and backwards
        gru_forward = simple_gru(input=sliced_feature, size=128, act=Relu())
        gru_backward = simple_gru(
            input=sliced_feature, size=128, act=Relu(), reverse=True)

        # map each step of the RNN to a character distribution
        # (num_classes + 1 outputs: all characters plus the CTC blank)
        self.output = layer.fc(
            input=[gru_forward, gru_backward],
            size=self.num_classes + 1,
            act=Linear())
        self.log_probs = paddle.layer.mixed(
            input=paddle.layer.identity_projection(input=self.output),
            act=paddle.activation.Softmax())

        # warp CTC to calculate the cost for the CTC task
        if not self.is_infer:
            self.cost = layer.warp_ctc(
                input=self.output,
                label=self.label,
                size=self.num_classes + 1,
                norm_by_times=True,
                blank=self.num_classes)
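
# Back-of-envelope check of the block_y=11 choice above (a hypothetical helper;
# it assumes ceil-mode 2x2/stride-2 pooling, so an input height of 173 shrinks
# to 11 after the four pooling layers in conv_groups).
def _demo_feature_map_height():
    import math
    h = 173
    for _ in range(4):
        h = int(math.ceil(h / 2.0))
    assert h == 11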

# train.py
import argparse
import gzip

import paddle.v2 as paddle

from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset

parser = argparse.ArgumentParser(description="PaddlePaddle CTC example")
parser.add_argument(
    '--image_shape',
    type=str,
    required=True,
    help="the image's shape, in a format like '173,46'")
parser.add_argument(
    '--train_file_list',
    type=str,
    required=True,
    help='path of the file which contains the path list of train image files')
parser.add_argument(
    '--test_file_list',
    type=str,
    required=True,
    help='path of the file which contains the path list of test image files')
parser.add_argument(
    '--batch_size', type=int, default=5, help='size of a mini-batch')
parser.add_argument(
    '--model_output_prefix',
    type=str,
    default='model.ctc',
    help='prefix of the path where models are stored (default: ./model.ctc)')
parser.add_argument(
    '--trainer_count', type=int, default=4, help='number of training threads')
parser.add_argument(
    '--save_period_by_batch',
    type=int,
    default=50,
    help='save the model to disk every N batches')
parser.add_argument(
    '--num_passes',
    type=int,
    default=1,
    help='number of passes to train (default: 1)')
args = parser.parse_args()

image_shape = tuple(map(int, args.image_shape.split(',')))
print 'image_shape', image_shape
print 'batch_size', args.batch_size
print 'train_file_list', args.train_file_list
print 'test_file_list', args.test_file_list

train_generator = get_file_list(args.train_file_list)
test_generator = get_file_list(args.test_file_list)
infer_generator = None
dataset = ImageDataset(
    train_generator,
    test_generator,
    infer_generator,
    fixed_shape=image_shape,
    is_infer=False)

paddle.init(use_gpu=True, trainer_count=args.trainer_count)
model = Model(AsciiDic().size(), image_shape, is_infer=False)
params = paddle.parameters.create(model.cost)
optimizer = paddle.optimizer.Momentum(momentum=0)
trainer = paddle.trainer.SGD(
    cost=model.cost, parameters=params, update_equation=optimizer)

def event_handler(event):
    if isinstance(event, paddle.event.EndIteration):
        if event.batch_id % 100 == 0:
            print "Pass %d, batch %d, Samples %d, Cost %f" % (
                event.pass_id, event.batch_id,
                event.batch_id * args.batch_size, event.cost)
        if event.batch_id > 0 and event.batch_id % args.save_period_by_batch == 0:
            # evaluate on the test set, then snapshot the parameters with the
            # test cost embedded in the file name
            result = trainer.test(
                reader=paddle.batch(dataset.test, batch_size=10),
                feeding={'image': 0,
                         'label': 1})
            print "Test %d-%d, Cost %f" % (event.pass_id, event.batch_id,
                                           result.cost)
            path = "{}-pass-{}-batch-{}-test-{}.tar.gz".format(
                args.model_output_prefix, event.pass_id, event.batch_id,
                result.cost)
            with gzip.open(path, 'w') as f:
                params.to_tar(f)

trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=500),
        batch_size=args.batch_size),
    feeding={'image': 0,
             'label': 1},
    event_handler=event_handler,
    num_passes=args.num_passes)
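
# Example invocation (illustrative; the file-list paths are assumptions based
# on the ICDAR 2013 data layout referenced in infer.py):
#
#     python train.py --image_shape '173,46' \
#         --train_file_list data/train_data/gt.txt \
#         --test_file_list data/test_data/Challenge2_Test_Task3_GT.txt \
#         --batch_size 50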