提交 27afc286 编写于 作者: G guosheng

Update Additive Attention followed by GRU

上级 e4e393c8
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from text import DynamicDecode, RNN, BasicLSTMCell, RNNCell
from text import DynamicDecode, RNN, RNNCell
from model import Model, Loss
......@@ -91,82 +92,70 @@ class OCRConv(fluid.dygraph.Layer):
return inputs_4
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc1 = Linear(decoder_size, decoder_size, bias_attr=False)
self.fc2 = Linear(decoder_size, 1, bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state = self.fc1(decoder_state)
decoder_state = fluid.layers.unsqueeze(decoder_state, [1])
mix = fluid.layers.elementwise_add(encoder_proj, decoder_state)
mix = fluid.layers.tanh(x=mix)
attn_score = self.fc2(mix)
attn_scores = layers.squeeze(attn_score, [2])
attn_scores = fluid.layers.softmax(attn_scores)
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=attn_scores, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUCell(RNNCell):
def __init__(self,
size,
input_size,
hidden_size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
init_size=None):
origin_mode=False):
super(GRUCell, self).__init__()
self.input_proj = Linear(
768, size * 3, param_attr=param_attr, bias_attr=False)
self.hidden_size = hidden_size
self.fc_layer = Linear(
input_size,
hidden_size * 3,
param_attr=param_attr,
bias_attr=False)
self.gru_unit = GRUUnit(
size * 3,
hidden_size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.is_reverse = is_reverse
def forward(self, inputs, states):
# step_outputs, new_states = cell(step_inputs, states)
# for GRUCell, `step_outputs` and `new_states` both are hidden
x = self.input_proj(inputs)
x = self.fc_layer(inputs)
hidden, _, _ = self.gru_unit(x, states)
return hidden, hidden
class DecoderCell(RNNCell):
def __init__(self, size):
self.gru = GRUCell(size)
self.attention = SimpleAttention(size)
self.fc_1_layer = Linear(
input_dim=size * 2, output_dim=size * 3, bias_attr=False)
self.fc_2_layer = Linear(
input_dim=size, output_dim=size * 3, bias_attr=False)
def forward(self, inputs, states, encoder_vec, encoder_proj):
context = self.attention(encoder_vec, encoder_proj, states)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(inputs)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _ = self.gru(decoder_inputs, states)
return h, h
class Decoder(fluid.dygraph.Layer):
def __init__(self, size, num_classes):
super(Decoder, self).__init__()
self.embedder = Embedding(size=[num_classes, size])
self.gru_attention = RNN(DecoderCell(size),
is_reverse=False,
time_major=False)
self.output_layer = Linear(size, num_classes, bias_attr=False)
def forward(self, target, decoder_initial_states, encoder_vec,
encoder_proj):
inputs = self.embedder(target)
decoder_output, _ = self.gru_attention(
inputs,
initial_states=decoder_initial_states,
encoder_vec=encoder_vec,
encoder_proj=encoder_proj)
predict = self.output_layer(decoder_output)
return predict
@property
def state_shape(self):
return [self.hidden_size]
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
batch_size,
decoder_size,
rnn_hidden_size=200,
is_test=False,
......@@ -179,21 +168,24 @@ class EncoderNet(fluid.dygraph.Layer):
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.fc_2_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
self.gru_forward_layer = RNN(
cell=GRUCell(
input_size=128 * 6, # channel * h
hidden_size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
candidate_activation='relu'),
is_reverse=False,
time_major=False)
self.gru_backward_layer = RNN(
cell=GRUCell(
input_size=128 * 6, # channel * h
hidden_size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
candidate_activation='relu'),
is_reverse=True,
time_major=False)
self.encoded_proj_fc = Linear(
rnn_hidden_size * 2, decoder_size, bias_attr=False)
......@@ -211,13 +203,9 @@ class EncoderNet(fluid.dygraph.Layer):
],
inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward, _ = self.gru_forward_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
gru_backward, _ = self.gru_backward_layer(sliced_feature)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
......@@ -227,88 +215,50 @@ class EncoderNet(fluid.dygraph.Layer):
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear(
decoder_size, decoder_size, act=None, bias_attr=False)
self.fc_2 = Linear(decoder_size, 1, act=None, bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]],
inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax(weights_reshape)
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class DecoderCell(RNNCell):
def __init__(self, encoder_size, decoder_size):
super(DecoderCell, self).__init__()
self.attention = SimpleAttention(decoder_size)
self.gru_cell = GRUCell(
input_size=encoder_size * 2 +
decoder_size, # encoded_vector.shape[-1] + embed_size
hidden_size=decoder_size)
def forward(self, current_word, states, encoder_vec, encoder_proj):
context = self.attention(encoder_vec, encoder_proj, states)
decoder_inputs = layers.concat([current_word, context], axis=1)
hidden, _ = self.gru_cell(decoder_inputs, states)
return hidden, hidden
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, encoder_size, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear(
input_dim=encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear(
input_dim=decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3, param_attr=None, bias_attr=None)
self.gru_attention = RNN(DecoderCell(encoder_size, decoder_size),
is_reverse=False,
time_major=False)
self.out_layer = Linear(
input_dim=decoder_size,
output_dim=num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self,
current_word,
encoder_vec,
encoder_proj,
decoder_boot,
inference=False):
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
decoder_boot)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, decoder_boot)
out = self.out_layer(h)
return out, h
def forward(self, inputs, decoder_initial_states, encoder_vec,
encoder_proj):
out, _ = self.gru_attention(
inputs,
initial_states=decoder_initial_states,
encoder_vec=encoder_vec,
encoder_proj=encoder_proj)
predict = self.out_layer(out)
return predict
class OCRAttention(fluid.dygraph.Layer):
def __init__(self, batch_size, num_classes, encoder_size, decoder_size,
class OCRAttention(Model):
def __init__(self, num_classes, encoder_size, decoder_size,
word_vector_dim):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet(batch_size, decoder_size)
self.encoder_net = EncoderNet(decoder_size)
self.fc = Linear(
input_dim=encoder_size,
output_dim=decoder_size,
......@@ -318,36 +268,26 @@ class OCRAttention(fluid.dygraph.Layer):
[num_classes + 2, word_vector_dim], dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(
encoder_size, decoder_size, num_classes)
self.batch_size = batch_size
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
decoder_boot = self.fc(gru_backward[:, 0])
trg_embedding = self.embedding(label_in)
prediction = self.gru_decoder_with_attention(
trg_embedding, decoder_boot, encoded_vector, encoded_proj)
trg_embedding = fluid.layers.reshape(
trg_embedding, [self.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
return prediction
pred_temp = []
for i in range(trg_embedding.shape[1]):
current_word = fluid.layers.slice(
trg_embedding, axes=[1], starts=[i], ends=[i + 1])
out, decoder_boot = self.gru_decoder_with_attention(
current_word, encoded_vector, encoded_proj, decoder_boot)
pred_temp.append(out)
pred_temp = fluid.layers.concat(pred_temp, axis=1)
batch_size = trg_embedding.shape[0]
seq_len = trg_embedding.shape[1]
prediction = fluid.layers.reshape(
pred_temp, shape=[batch_size, seq_len, -1])
class CrossEntropyCriterion(Loss):
def __init__(self):
super(CrossEntropyCriterion, self).__init__()
return prediction
def forward(self, outputs, labels):
predict, (label, mask) = outputs[0], labels
loss = layers.cross_entropy(predict, label=label, soft_label=False)
loss = layers.elementwise_mul(loss, mask, axis=0)
loss = layers.reduce_sum(loss)
return loss
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
import data_reader
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments, get_attention_feeder_data
from model import Input, set_device
from nets import OCRAttention, CrossEntropyCriterion
from eval import evaluate
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('epoch_num', int, 30, "Epoch number.")
add_arg('lr', float, 0.001, "Learning rate.")
add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
add_arg('log_period', int, 200, "Log period.")
add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test', bool, False, "Whether to skip test phase.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
add_arg('dynamic', bool, False, "Whether to use dygraph.")
def train(args):
device = set_device("gpu" if args.use_gpu else "cpu")
fluid.enable_dygraph(device) if args.dynamic else None
ocr_attention = OCRAttention(encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
LR = args.lr
if args.lr_decay_strategy == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay([200000, 250000], [LR, LR * 0.1, LR * 0.01])
else:
learning_rate = LR
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
# grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)
inputs = [
Input([None, 1, 48, 384], "float32", name="pixel"),
Input([None, None], "int64", name="label_in"),
]
labels = [
Input([None, None], "int64", name="label_out"),
Input([None, None], "float32", name="mask")]
ocr_attention.prepare(optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
train_reader = data_reader.data_reader(
args.batch_size,
shuffle=True,
images_dir=args.train_images,
list_file=args.train_list,
data_type='train')
# test_reader = data_reader.data_reader(
# args.batch_size,
# images_dir=args.test_images,
# list_file=args.test_list,
# data_type="test")
# if not os.path.exists(args.save_model_dir):
# os.makedirs(args.save_model_dir)
total_step = 0
epoch_num = args.epoch_num
for epoch in range(epoch_num):
batch_id = 0
total_loss = 0.0
for data in train_reader():
total_step += 1
data_dict = get_attention_feeder_data(data)
pixel = data_dict["pixel"]
label_in = data_dict["label_in"].reshape([pixel.shape[0], -1])
label_out = data_dict["label_out"].reshape([pixel.shape[0], -1])
mask = data_dict["mask"].reshape(label_out.shape).astype("float32")
avg_loss = ocr_attention.train(inputs=[pixel, label_in], labels=[label_out, mask])[0]
total_loss += avg_loss
if True:#batch_id > 0 and batch_id % args.log_period == 0:
print("epoch: {}, batch_id: {}, loss {}".format(epoch, batch_id,
total_loss / args.batch_size / args.log_period))
total_loss = 0.0
batch_id += 1
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(args)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(args)
else:
train(args)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册