“33709441663604c397afff7a48d63176cf2496a1”上不存在“develop/doc/dev/write_docs_en.html”
未验证 提交 072bedd1 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #50 from qingqing01/ocr

Add OCR attention model
简介
--------
本OCR任务是识别图片单行的字母信息,基于attention的seq2seq结构。 运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。
## 代码结构
```
.
|-- data.py # 数据读取
|-- eval.py # 评估脚本
|-- images # 测试图片
|-- predict.py # 预测脚本
|-- seq2seq_attn.py # 模型
|-- train.py # 训练脚本
`-- utility.py # 公共模块
```
## 训练/评估/预测流程
- 设置GPU环境:
```
export CUDA_VISIBLE_DEVICES=0
```
- 训练
```
python train.py
```
更多参数可以通过`--help`查看。
- 动静切换
```
python train.py --dynamic=True
```
- 评估
```
python eval.py --init_model=checkpoint/final
```
- 预测
目前不支持动态图预测
```
python predict.py --init_model=checkpoint/final --image_path=images/ --dynamic=False --beam_size=3
```
预测结果如下:
```
Image 1: images/112_chubbiness_13557.jpg
0: chubbines
1: chubbiness
2: chubbinesS
Image 2: images/177_Interfiled_40185.jpg
0: Interflied
1: Interfiled
2: InterfIled
Image 3: images/325_dame_19109.jpg
0: da
1: damo
2: dame
Image 4: images/368_fixtures_29232.jpg
0: firtures
1: Firtures
2: fixtures
```
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from os import path
import random
import traceback
import copy
import math
import tarfile
from PIL import Image
import logging
logger = logging.getLogger(__name__)
import paddle
from paddle import fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
CACHE_DIR_NAME = "attention_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"
class Resize(object):
def __init__(self, height=48):
self.interp = Image.NEAREST # Image.ANTIALIAS
self.height = height
def __call__(self, samples):
shape = samples[0][0].size
for i in range(len(samples)):
im = samples[i][0]
im = im.resize((shape[0], self.height), self.interp)
samples[i][0] = im
return samples
class Normalize(object):
def __init__(self,
mean=[127.5],
std=[1.0],
scale=False,
channel_first=True):
self.mean = mean
self.std = std
self.scale = scale
self.channel_first = channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, samples):
for i in range(len(samples)):
im = samples[i][0]
im = np.array(im).astype(np.float32, copy=False)
im = im[np.newaxis, ...]
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.scale:
im = im / 255.0
#im -= mean
im -= 127.5
#im /= std
samples[i][0] = im
return samples
class PadTarget(object):
def __init__(self, SOS=0, EOS=1):
self.SOS = SOS
self.EOS = EOS
def __call__(self, samples):
lens = np.array([len(s[1]) for s in samples], dtype="int64")
max_len = np.max(lens)
for i in range(len(samples)):
label = samples[i][1]
if max_len > len(label):
pad_label = label + [self.EOS] * (max_len - len(label))
else:
pad_label = label
samples[i][1] = np.array([self.SOS] + pad_label, dtype='int64')
# label_out
samples[i].append(np.array(pad_label + [self.EOS], dtype='int64'))
mask = np.zeros((max_len + 1)).astype('float32')
mask[:len(label) + 1] = 1.0
# mask
samples[i].append(np.array(mask, dtype='float32'))
return samples
class BatchSampler(fluid.io.BatchSampler):
def __init__(self,
dataset,
batch_size,
shuffle=False,
drop_last=True,
seed=None):
self._dataset = dataset
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
self._random = np.random
self._random.seed(seed)
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
self._num_samples = int(
math.ceil(len(self._dataset) * 1.0 / self._nranks))
self._total_size = self._num_samples * self._nranks
self._epoch = 0
def __iter__(self):
infos = copy.copy(self._dataset._sample_infos)
skip_num = 0
if self._shuffle:
if self._batch_size == 1:
self._random.RandomState(self._epoch).shuffle(infos)
else: # partial shuffle
infos = sorted(infos, key=lambda x: x.w)
skip_num = random.randint(1, 100)
infos = infos[skip_num:] + infos[:skip_num]
infos += infos[:(self._total_size - len(infos))]
last_size = self._total_size % (self._batch_size * self._nranks)
batches = []
for i in range(self._local_rank * self._batch_size,
len(infos) - last_size,
self._batch_size * self._nranks):
batches.append(infos[i:i + self._batch_size])
if (not self._drop_last) and last_size != 0:
last_local_size = last_size // self._nranks
last_infos = infos[len(infos) - last_size:]
start = self._local_rank * last_local_size
batches.append(last_infos[start:start + last_local_size])
if self._shuffle:
self._random.RandomState(self._epoch).shuffle(batches)
self._epoch += 1
for batch in batches:
batch_indices = [info.idx for info in batch]
yield batch_indices
def __len__(self):
if self._drop_last:
return self._total_size // self._batch_size
else:
return math.ceil(self._total_size / float(self._batch_size))
class SampleInfo(object):
def __init__(self, idx, h, w, im_name, labels):
self.idx = idx
self.h = h
self.w = w
self.im_name = im_name
self.labels = labels
class OCRDataset(paddle.io.Dataset):
def __init__(self, image_dir, anno_file):
self.image_dir = image_dir
self.anno_file = anno_file
self._sample_infos = []
with open(anno_file, 'r') as f:
for i, line in enumerate(f):
w, h, im_name, labels = line.strip().split(' ')
h, w = int(h), int(w)
labels = [int(c) for c in labels.split(',')]
self._sample_infos.append(SampleInfo(i, h, w, im_name, labels))
def __getitem__(self, idx):
info = self._sample_infos[idx]
im_name, labels = info.im_name, info.labels
image = Image.open(path.join(self.image_dir, im_name)).convert('L')
return [image, labels]
def __len__(self):
return len(self._sample_infos)
def train(
root_dir=None,
images_dir=None,
anno_file=None,
shuffle=True, ):
if root_dir is None:
root_dir = download_data()
if images_dir is None:
images_dir = TRAIN_DATA_DIR_NAME
images_dir = path.join(root_dir, TRAIN_DATA_DIR_NAME)
if anno_file is None:
anno_file = TRAIN_LIST_FILE_NAME
anno_file = path.join(root_dir, TRAIN_LIST_FILE_NAME)
return OCRDataset(images_dir, anno_file)
def test(
root_dir=None,
images_dir=None,
anno_file=None,
shuffle=True, ):
if root_dir is None:
root_dir = download_data()
if images_dir is None:
images_dir = TEST_DATA_DIR_NAME
images_dir = path.join(root_dir, TEST_DATA_DIR_NAME)
if anno_file is None:
anno_file = TEST_LIST_FILE_NAME
anno_file = path.join(root_dir, TEST_LIST_FILE_NAME)
return OCRDataset(images_dir, anno_file)
def download_data():
'''Download train and test data.
'''
tar_file = paddle.dataset.common.download(
DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
data_dir = path.join(path.dirname(tar_file), DATA_DIR_NAME)
if not path.isdir(data_dir):
t = tarfile.open(tar_file, "r:gz")
t.extractall(path=path.dirname(tar_file))
t.close()
return data_dir
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import functools
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments
from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy
from utility import postprocess
from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy
import data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('init_model', str, 'checkpoint/final', "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('embedding_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('beam_size', int, 0, "If set beam size, will use beam search.")
add_arg('dynamic', bool, False, "Whether to use dygraph.")
# yapf: enable
def main(FLAGS):
device = set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = Seq2SeqAttModel(
encoder_size=FLAGS.encoder_size,
decoder_size=FLAGS.decoder_size,
emb_dim=FLAGS.embedding_dim,
num_classes=FLAGS.num_classes)
# yapf: disable
inputs = [
Input([None, 1, 48, 384], "float32", name="pixel"),
Input([None, None], "int64", name="label_in")
]
labels = [
Input([None, None], "int64", name="label_out"),
Input([None, None], "float32", name="mask")
]
# yapf: enable
model.prepare(
loss_function=WeightCrossEntropy(),
metrics=SeqAccuracy(),
inputs=inputs,
labels=labels,
device=device)
model.load(FLAGS.init_model)
test_dataset = data.test()
test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.BatchSampler(
test_dataset,
batch_size=FLAGS.batch_size,
drop_last=False,
shuffle=False)
test_loader = fluid.io.DataLoader(
test_dataset,
batch_sampler=test_sampler,
places=device,
num_workers=0,
return_list=True,
collate_fn=test_collate_fn)
model.evaluate(
eval_data=test_loader,
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
def beam_search(FLAGS):
device = set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = Seq2SeqAttInferModel(
encoder_size=FLAGS.encoder_size,
decoder_size=FLAGS.decoder_size,
emb_dim=FLAGS.embedding_dim,
num_classes=FLAGS.num_classes,
beam_size=FLAGS.beam_size)
inputs = [
Input(
[None, 1, 48, 384], "float32", name="pixel"), Input(
[None, None], "int64", name="label_in")
]
labels = [
Input(
[None, None], "int64", name="label_out"), Input(
[None, None], "float32", name="mask")
]
model.prepare(
loss_function=None,
metrics=SeqBeamAccuracy(),
inputs=inputs,
labels=labels,
device=device)
model.load(FLAGS.init_model)
test_dataset = data.test()
test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.BatchSampler(
test_dataset,
batch_size=FLAGS.batch_size,
drop_last=False,
shuffle=False)
test_loader = fluid.io.DataLoader(
test_dataset,
batch_sampler=test_sampler,
places=device,
num_workers=0,
return_list=True,
collate_fn=test_collate_fn)
model.evaluate(
eval_data=test_loader,
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__':
FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.beam_size:
beam_search(FLAGS)
else:
main(FLAGS)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import random
import numpy as np
import argparse
import functools
from PIL import Image
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
from hapi.model import Input, set_device
from hapi.datasets.folder import ImageFolder
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments
from utility import postprocess, index2word
from seq2seq_attn import Seq2SeqAttInferModel, WeightCrossEntropy
import data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('image_path', str, None, "The directory of images to be used for test.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('embedding_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('beam_size', int, 3, "Beam size for beam search.")
add_arg('dynamic', bool, False, "Whether to use dygraph.")
# yapf: enable
def main(FLAGS):
device = set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = Seq2SeqAttInferModel(
encoder_size=FLAGS.encoder_size,
decoder_size=FLAGS.decoder_size,
emb_dim=FLAGS.embedding_dim,
num_classes=FLAGS.num_classes,
beam_size=FLAGS.beam_size)
inputs = [Input([None, 1, 48, 384], "float32", name="pixel"), ]
model.prepare(inputs=inputs, device=device)
model.load(FLAGS.init_model)
fn = lambda p: Image.open(p).convert('L')
test_dataset = ImageFolder(FLAGS.image_path, loader=fn)
test_collate_fn = BatchCompose([data.Resize(), data.Normalize()])
test_loader = fluid.io.DataLoader(
test_dataset,
places=device,
num_workers=0,
return_list=True,
collate_fn=test_collate_fn)
samples = test_dataset.samples
#outputs = model.predict(test_loader)
ins_id = 0
for image, in test_loader:
image = image if FLAGS.dynamic else image[0]
pred = model.test_batch([image])[0]
pred = pred[:, :, np.newaxis] if len(pred.shape) == 2 else pred
pred = np.transpose(pred, [0, 2, 1])
for ins in pred:
impath = samples[ins_id]
ins_id += 1
print('Image {}: {}'.format(ins_id, impath))
for beam_idx, beam in enumerate(ins):
id_list = postprocess(beam)
word_list = index2word(id_list)
sequence = "".join(word_list)
print('{}: {}'.format(beam_idx, sequence))
if __name__ == '__main__':
FLAGS = parser.parse_args()
print_arguments(FLAGS)
main(FLAGS)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers import BeamSearchDecoder
from hapi.text import RNNCell, RNN, DynamicDecode
from hapi.model import Model, Loss
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
in_ch,
out_ch,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.pool = pool
filter_size = 3
std = (2.0 / (filter_size**2 * in_ch))**0.5
param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, std))
std = (2.0 / (filter_size**2 * out_ch))**0.5
param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, std))
self.conv0 = fluid.dygraph.Conv2D(
in_ch,
out_ch,
3,
padding=1,
param_attr=param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn0 = fluid.dygraph.BatchNorm(out_ch, act=act)
self.conv1 = fluid.dygraph.Conv2D(
out_ch,
out_ch,
filter_size=3,
padding=1,
param_attr=param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn1 = fluid.dygraph.BatchNorm(out_ch, act=act)
if self.pool:
self.pool = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
out = self.conv0(inputs)
out = self.bn0(out)
out = self.conv1(out)
out = self.bn1(out)
if self.pool:
out = self.pool(out)
return out
class CNN(fluid.dygraph.Layer):
def __init__(self, in_ch=1, is_test=False):
super(CNN, self).__init__()
self.conv_bn1 = ConvBNPool(in_ch, 16)
self.conv_bn2 = ConvBNPool(16, 32)
self.conv_bn3 = ConvBNPool(32, 64)
self.conv_bn4 = ConvBNPool(64, 128, pool=False)
def forward(self, inputs):
conv = self.conv_bn1(inputs)
conv = self.conv_bn2(conv)
conv = self.conv_bn3(conv)
conv = self.conv_bn4(conv)
return conv
class GRUCell(RNNCell):
def __init__(self,
input_size,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False):
super(GRUCell, self).__init__()
self.hidden_size = hidden_size
self.fc_layer = fluid.dygraph.Linear(
input_size,
hidden_size * 3,
param_attr=param_attr,
bias_attr=False)
self.gru_unit = fluid.dygraph.GRUUnit(
hidden_size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
def forward(self, inputs, states):
# step_outputs, new_states = cell(step_inputs, states)
# for GRUCell, `step_outputs` and `new_states` both are hidden
x = self.fc_layer(inputs)
hidden, _, _ = self.gru_unit(x, states)
return hidden, hidden
@property
def state_shape(self):
return [self.hidden_size]
class Encoder(fluid.dygraph.Layer):
def __init__(
self,
in_channel=1,
rnn_hidden_size=200,
decoder_size=128,
is_test=False, ):
super(Encoder, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
self.backbone = CNN(in_ch=in_channel, is_test=is_test)
para_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
self.gru_fwd = RNN(cell=GRUCell(
input_size=128 * 6,
hidden_size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu'),
is_reverse=False,
time_major=False)
self.gru_bwd = RNN(cell=GRUCell(
input_size=128 * 6,
hidden_size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu'),
is_reverse=True,
time_major=False)
self.encoded_proj_fc = fluid.dygraph.Linear(
rnn_hidden_size * 2, decoder_size, bias_attr=False)
def forward(self, inputs):
conv_features = self.backbone(inputs)
conv_features = fluid.layers.transpose(
conv_features, perm=[0, 3, 1, 2])
n, w, c, h = conv_features.shape
seq_feature = fluid.layers.reshape(conv_features, [0, -1, c * h])
gru_fwd, _ = self.gru_fwd(seq_feature)
gru_bwd, _ = self.gru_bwd(seq_feature)
encoded_vector = fluid.layers.concat(input=[gru_fwd, gru_bwd], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_bwd, encoded_vector, encoded_proj
class Attention(fluid.dygraph.Layer):
"""
Neural Machine Translation by Jointly Learning to Align and Translate.
https://arxiv.org/abs/1409.0473
"""
def __init__(self, decoder_size):
super(Attention, self).__init__()
self.fc1 = fluid.dygraph.Linear(
decoder_size, decoder_size, bias_attr=False)
self.fc2 = fluid.dygraph.Linear(decoder_size, 1, bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
# alignment model, single-layer multilayer perceptron
decoder_state = self.fc1(decoder_state)
decoder_state = fluid.layers.unsqueeze(decoder_state, [1])
e = fluid.layers.elementwise_add(encoder_proj, decoder_state)
e = fluid.layers.tanh(e)
att_scores = self.fc2(e)
att_scores = fluid.layers.squeeze(att_scores, [2])
att_scores = fluid.layers.softmax(att_scores)
context = fluid.layers.elementwise_mul(
x=encoder_vec, y=att_scores, axis=0)
context = fluid.layers.reduce_sum(context, dim=1)
return context
class DecoderCell(RNNCell):
def __init__(self, encoder_size=200, decoder_size=128):
super(DecoderCell, self).__init__()
self.attention = Attention(decoder_size)
self.gru_cell = GRUCell(
input_size=encoder_size * 2 + decoder_size,
hidden_size=decoder_size)
def forward(self, current_word, states, encoder_vec, encoder_proj):
context = self.attention(encoder_vec, encoder_proj, states)
decoder_inputs = fluid.layers.concat([current_word, context], axis=1)
hidden, _ = self.gru_cell(decoder_inputs, states)
return hidden, hidden
class Decoder(fluid.dygraph.Layer):
def __init__(self, num_classes, emb_dim, encoder_size, decoder_size):
super(Decoder, self).__init__()
self.decoder_attention = RNN(DecoderCell(encoder_size, decoder_size))
self.fc = fluid.dygraph.Linear(
decoder_size, num_classes + 2, act='softmax')
def forward(self, target, initial_states, encoder_vec, encoder_proj):
out, _ = self.decoder_attention(
target,
initial_states=initial_states,
encoder_vec=encoder_vec,
encoder_proj=encoder_proj)
pred = self.fc(out)
return pred
class Seq2SeqAttModel(Model):
def __init__(
self,
in_channle=1,
encoder_size=200,
decoder_size=128,
emb_dim=128,
num_classes=None, ):
super(Seq2SeqAttModel, self).__init__()
self.encoder = Encoder(in_channle, encoder_size, decoder_size)
self.fc = fluid.dygraph.Linear(
input_dim=encoder_size,
output_dim=decoder_size,
bias_attr=False,
act='relu')
self.embedding = fluid.dygraph.Embedding(
[num_classes + 2, emb_dim], dtype='float32')
self.decoder = Decoder(num_classes, emb_dim, encoder_size,
decoder_size)
def forward(self, inputs, target):
gru_backward, encoded_vector, encoded_proj = self.encoder(inputs)
decoder_boot = self.fc(gru_backward[:, 0])
trg_embedding = self.embedding(target)
prediction = self.decoder(trg_embedding, decoder_boot, encoded_vector,
encoded_proj)
return prediction
class Seq2SeqAttInferModel(Seq2SeqAttModel):
def __init__(
self,
in_channle=1,
encoder_size=200,
decoder_size=128,
emb_dim=128,
num_classes=None,
beam_size=0,
bos_id=0,
eos_id=1,
max_out_len=20, ):
super(Seq2SeqAttInferModel, self).__init__(
in_channle, encoder_size, decoder_size, emb_dim, num_classes)
self.beam_size = beam_size
# dynamic decoder for inference
decoder = BeamSearchDecoder(
self.decoder.decoder_attention.cell,
start_token=bos_id,
end_token=eos_id,
beam_size=beam_size,
embedding_fn=self.embedding,
output_fn=self.decoder.fc)
self.infer_decoder = DynamicDecode(
decoder, max_step_num=max_out_len, is_test=True)
def forward(self, inputs, *args):
gru_backward, encoded_vector, encoded_proj = self.encoder(inputs)
decoder_boot = self.fc(gru_backward[:, 0])
if self.beam_size:
# Tile the batch dimension with beam_size
encoded_vector = BeamSearchDecoder.tile_beam_merge_with_batch(
encoded_vector, self.beam_size)
encoded_proj = BeamSearchDecoder.tile_beam_merge_with_batch(
encoded_proj, self.beam_size)
# dynamic decoding with beam search
rs, _ = self.infer_decoder(
inits=decoder_boot,
encoder_vec=encoded_vector,
encoder_proj=encoded_proj)
return rs
class WeightCrossEntropy(Loss):
def __init__(self):
super(WeightCrossEntropy, self).__init__(average=False)
def forward(self, outputs, labels):
predict, (label, mask) = outputs[0], labels
loss = layers.cross_entropy(predict, label=label)
loss = layers.elementwise_mul(loss, mask, axis=0)
loss = layers.reduce_sum(loss)
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import random
import numpy as np
import argparse
import functools
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments
from utility import SeqAccuracy, LoggerCallBack
from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy
import data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('epoch', int, 30, "Epoch number.")
add_arg('num_workers', int, 0, "workers number.")
add_arg('lr', float, 0.001, "Learning rate.")
add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
add_arg('checkpoint_path', str, "checkpoint", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('resume_path', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('embedding_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
add_arg('dynamic', bool, False, "Whether to use dygraph.")
# yapf: enable
def main(FLAGS):
device = set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = Seq2SeqAttModel(
encoder_size=FLAGS.encoder_size,
decoder_size=FLAGS.decoder_size,
emb_dim=FLAGS.embedding_dim,
num_classes=FLAGS.num_classes)
lr = FLAGS.lr
if FLAGS.lr_decay_strategy == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay(
[200000, 250000], [lr, lr * 0.1, lr * 0.01])
else:
learning_rate = lr
grad_clip = fluid.clip.GradientClipByGlobalNorm(FLAGS.gradient_clip)
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
parameter_list=model.parameters(),
grad_clip=grad_clip)
# yapf: disable
inputs = [
Input([None,1,48,384], "float32", name="pixel"),
Input([None, None], "int64", name="label_in"),
]
labels = [
Input([None, None], "int64", name="label_out"),
Input([None, None], "float32", name="mask"),
]
# yapf: enable
model.prepare(
optimizer,
WeightCrossEntropy(),
SeqAccuracy(),
inputs=inputs,
labels=labels)
train_dataset = data.train()
train_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()])
train_sampler = data.BatchSampler(
train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
train_loader = fluid.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
places=device,
num_workers=FLAGS.num_workers,
return_list=True,
collate_fn=train_collate_fn)
test_dataset = data.test()
test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.BatchSampler(
test_dataset,
batch_size=FLAGS.batch_size,
drop_last=False,
shuffle=False)
test_loader = fluid.io.DataLoader(
test_dataset,
batch_sampler=test_sampler,
places=device,
num_workers=0,
return_list=True,
collate_fn=test_collate_fn)
model.fit(train_data=train_loader,
eval_data=test_loader,
epochs=FLAGS.epoch,
save_dir=FLAGS.checkpoint_path,
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__':
FLAGS = parser.parse_args()
print_arguments(FLAGS)
main(FLAGS)
"""Contains common utility functions."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
import paddle.fluid as fluid
import six
from hapi.metrics import Metric
from hapi.callbacks import ProgBarLogger
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class SeqAccuracy(Metric):
def __init__(self, name=None, *args, **kwargs):
super(SeqAccuracy, self).__init__(*args, **kwargs)
self._name = 'seq_acc'
self.reset()
def add_metric_op(self, output, label, mask, *args, **kwargs):
pred = fluid.layers.flatten(output, axis=2)
score, topk = fluid.layers.topk(pred, 1)
return topk, label, mask
def update(self, topk, label, mask, *args, **kwargs):
topk = topk.reshape(label.shape[0], -1)
seq_len = np.sum(mask, -1)
acc = 0
for i in range(label.shape[0]):
l = int(seq_len[i] - 1)
pred = topk[i][:l - 1]
ref = label[i][:l - 1]
if np.array_equal(pred, ref):
self.total += 1
acc += 1
self.count += 1
return float(acc) / label.shape[0]
def reset(self):
self.total = 0.
self.count = 0.
def accumulate(self):
return float(self.total) / self.count
def name(self):
return self._name
class LoggerCallBack(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None):
super(LoggerCallBack, self).__init__(log_freq, verbose)
self.train_bs = train_bs
self.eval_bs = eval_bs if eval_bs else train_bs
def on_train_batch_end(self, step, logs=None):
logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(LoggerCallBack, self).on_train_batch_end(step, logs)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(LoggerCallBack, self).on_epoch_end(epoch, logs)
def on_eval_batch_end(self, step, logs=None):
logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(LoggerCallBack, self).on_eval_batch_end(step, logs)
def on_eval_end(self, logs=None):
logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(LoggerCallBack, self).on_eval_end(logs)
def index2word(ids):
return [chr(int(k + 33)) for k in ids]
def postprocess(seq, bos_idx=0, eos_idx=1):
if type(seq) is np.ndarray:
seq = seq.tolist()
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx
]
return seq
class SeqBeamAccuracy(Metric):
def __init__(self, name=None, *args, **kwargs):
super(SeqBeamAccuracy, self).__init__(*args, **kwargs)
self._name = 'seq_acc'
self.reset()
def add_metric_op(self, output, label, mask, *args, **kwargs):
return output, label, mask
def update(self, preds, labels, masks, *args, **kwargs):
preds = preds[:, :, np.newaxis] if len(preds.shape) == 2 else preds
preds = np.transpose(preds, [0, 2, 1])
seq_len = np.sum(masks, -1)
acc = 0
for i in range(labels.shape[0]):
l = int(seq_len[i] - 1)
#ref = labels[i][: l - 1]
ref = np.array(postprocess(labels[i]))
pred = preds[i]
for idx, beam in enumerate(pred):
beam_pred = np.array(postprocess(beam))
if np.array_equal(beam_pred, ref):
self.total += 1
acc += 1
break
self.count += 1
return float(acc) / labels.shape[0]
def reset(self):
self.total = 0.
self.count = 0.
def accumulate(self):
return float(self.total) / self.count
def name(self):
return self._name
...@@ -218,8 +218,6 @@ class ProgBarLogger(Callback): ...@@ -218,8 +218,6 @@ class ProgBarLogger(Callback):
# if steps is not None, last step will update in on_epoch_end # if steps is not None, last step will update in on_epoch_end
if self.steps and self.train_step < self.steps: if self.steps and self.train_step < self.steps:
self._updates(logs, 'train') self._updates(logs, 'train')
else:
self._updates(logs, 'train')
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
logs = logs or {} logs = logs or {}
...@@ -238,7 +236,7 @@ class ProgBarLogger(Callback): ...@@ -238,7 +236,7 @@ class ProgBarLogger(Callback):
def on_eval_batch_end(self, step, logs=None): def on_eval_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
self.eval_step = step self.eval_step += 1
samples = logs.get('batch_size', 1) samples = logs.get('batch_size', 1)
self.evaled_samples += samples self.evaled_samples += samples
......
...@@ -18,7 +18,7 @@ import cv2 ...@@ -18,7 +18,7 @@ import cv2
from paddle.io import Dataset from paddle.io import Dataset
__all__ = ["DatasetFolder"] __all__ = ["DatasetFolder", "ImageFolder"]
def has_valid_extension(filename, extensions): def has_valid_extension(filename, extensions):
...@@ -164,3 +164,80 @@ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', ...@@ -164,3 +164,80 @@ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
def cv2_loader(path): def cv2_loader(path):
return cv2.imread(path) return cv2.imread(path)
class ImageFolder(Dataset):
"""A generic data loader where the samples are arranged in this way:
root/1.ext
root/2.ext
root/sub_dir/3.ext
Args:
root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
samples (list): List of sample path
"""
def __init__(self,
root,
loader=None,
extensions=None,
transform=None,
is_valid_file=None):
self.root = root
if extensions is None:
extensions = IMG_EXTENSIONS
samples = []
path = os.path.expanduser(root)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
return has_valid_extension(x, extensions)
for root, _, fnames in sorted(os.walk(path, followlinks=True)):
for fname in sorted(fnames):
f = os.path.join(root, fname)
if is_valid_file(f):
samples.append(f)
if len(samples) == 0:
raise (RuntimeError(
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.extensions = extensions
self.samples = samples
self.transform = transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
return [sample]
def __len__(self):
return len(self.samples)
...@@ -1161,7 +1161,7 @@ class Model(fluid.dygraph.Layer): ...@@ -1161,7 +1161,7 @@ class Model(fluid.dygraph.Layer):
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
feed_list = None feed_list = None
else: else:
feed_list = [x.forward() for x in self._inputs + self._labels] feed_list = [x.forward() for x in self._inputs]
if test_data is not None and isinstance(test_data, Dataset): if test_data is not None and isinstance(test_data, Dataset):
test_sampler = DistributedBatchSampler( test_sampler = DistributedBatchSampler(
...@@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer): ...@@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer):
if mode == 'train': if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given' assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_end(epoch) callbacks.on_epoch_end(epoch, logs)
return logs return logs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册