提交 0675f9cd 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/models into ssd

......@@ -218,8 +218,6 @@ class AsyncDataReader(object):
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -424,6 +422,9 @@ class AsyncDataReader(object):
sample_queue = self._start_async_processing()
batch_queue = self._manager.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)
assembling_proc = DaemonProcessGroup(
proc_num=1,
target=batch_assembling_task,
......@@ -439,3 +440,6 @@ class AsyncDataReader(object):
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
# clean the shared memory
del self._pool_manager
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import sys, time
from six import reraise
from tblib import Traceback
from multiprocessing import Manager, Process
......@@ -161,9 +161,10 @@ class SharedMemoryPoolManager(object):
def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
self._names = []
self._dict = manager.dict()
self._time_prefix = time.strftime('%Y%m%d%H%M%S')
for i in xrange(pool_size):
name = name_prefix + '_' + str(i)
name = name_prefix + '_' + self._time_prefix + '_' + str(i)
self._dict[name] = SharedNDArray(name)
self._names.append(name)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "decoder.h"
std::string decode(std::vector<std::vector<float>> probs_mat) {
// Add decoding logic here
return "example decoding result";
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "post_decode_faster.h"
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename) {
const char* usage =
"Decode, reading log-likelihoods (of transition-ids or whatever symbol "
"is on the graph) as matrices.";
kaldi::ParseOptions po(usage);
binary = true;
acoustic_scale = 1.5;
allow_partial = true;
kaldi::FasterDecoderOptions decoder_opts;
decoder_opts.Register(&po, true); // true == include obscure settings.
po.Register("binary", &binary, "Write output in binary mode");
po.Register("allow-partial",
&allow_partial,
"Produce output even when final state was not reached");
po.Register("acoustic-scale",
&acoustic_scale,
"Scaling factor for acoustic likelihoods");
word_syms = NULL;
if (word_syms_filename != "") {
word_syms = fst::SymbolTable::ReadText(word_syms_filename);
if (!word_syms)
KALDI_ERR << "Could not read symbol table from file "
<< word_syms_filename;
}
std::ifstream is_logprior(logprior_rxfilename);
logprior.Read(is_logprior, false);
// It's important that we initialize decode_fst after loglikes_reader, as it
// can prevent crashes on systems installed without enough virtual memory.
// It has to do with what happens on UNIX systems if you call fork() on a
// large process: the page-table entries are duplicated, which requires a
// lot of virtual memory.
decode_fst = fst::ReadFstKaldi(fst_in_filename);
decoder = new kaldi::FasterDecoder(*decode_fst, decoder_opts);
}
Decoder::~Decoder() {
if (!word_syms) delete word_syms;
delete decode_fst;
delete decoder;
}
std::string Decoder::decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>>& log_probs) {
size_t num_frames = log_probs.size();
size_t dim_label = log_probs[0].size();
kaldi::Matrix<kaldi::BaseFloat> loglikes(
num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
for (size_t i = 0; i < num_frames; ++i) {
memcpy(loglikes.Data() + i * dim_label,
log_probs[i].data(),
sizeof(kaldi::BaseFloat) * dim_label);
}
return decode(key, loglikes);
}
std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier);
std::vector<std::string> decoding_results;
for (; !posterior_reader.Done(); posterior_reader.Next()) {
std::string key = posterior_reader.Key();
kaldi::Matrix<kaldi::BaseFloat> loglikes(posterior_reader.Value());
decoding_results.push_back(decode(key, loglikes));
}
return decoding_results;
}
std::string Decoder::decode(std::string key,
kaldi::Matrix<kaldi::BaseFloat>& loglikes) {
std::string decoding_result;
if (loglikes.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << key;
}
KALDI_ASSERT(loglikes.NumCols() == logprior.Dim());
loglikes.ApplyLog();
loglikes.AddVecToRows(-1.0, logprior);
kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale);
decoder->Decode(&decodable);
VectorFst<kaldi::LatticeArc> decoded; // linear FST.
if ((allow_partial || decoder->ReachedFinal()) &&
decoder->GetBestPath(&decoded)) {
if (!decoder->ReachedFinal())
KALDI_WARN << "Decoder did not reach end-state, outputting partial "
"traceback.";
std::vector<int32> alignment;
std::vector<int32> words;
kaldi::LatticeWeight weight;
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
if (word_syms != NULL) {
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]);
decoding_result += s;
if (s == "")
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
}
}
}
return decoding_result;
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -14,5 +14,44 @@ limitations under the License. */
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "base/timer.h"
#include "decoder/decodable-matrix.h"
#include "decoder/faster-decoder.h"
#include "fstext/fstext-lib.h"
#include "hmm/transition-model.h"
#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc
#include "tree/context-dep.h"
#include "util/common-utils.h"
std::string decode(std::vector<std::vector<float>> probs_mat);
class Decoder {
public:
Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename);
~Decoder();
// Interface to accept the scores read from specifier and return
// the batch decoding results
std::vector<std::string> decode(std::string posterior_rspecifier);
// Accept the scores of one utterance and return the decoding result
std::string decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs);
private:
// For decoding one utterance
std::string decode(std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes);
fst::SymbolTable *word_syms;
fst::VectorFst<fst::StdArc> *decode_fst;
kaldi::FasterDecoder *decoder;
kaldi::Vector<kaldi::BaseFloat> logprior;
bool binary;
kaldi::BaseFloat acoustic_scale;
bool allow_partial;
};
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,15 +15,25 @@ limitations under the License. */
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "decoder.h"
#include "post_decode_faster.h"
namespace py = pybind11;
PYBIND11_MODULE(decoder, m) {
m.doc() = "Decode function for Deep ASR model";
m.def("decode",
&decode,
PYBIND11_MODULE(post_decode_faster, m) {
m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string>())
.def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode,
"Decode for the probability matrices in specifier "
"and return the transcriptions.")
.def(
"decode",
(std::string (Decoder::*)(
std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) &
Decoder::decode,
"Decode one input probability matrix "
"and return the transcription");
"and return the transcription.");
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,27 +13,57 @@
# limitations under the License.
import os
import glob
from distutils.core import setup, Extension
from distutils.sysconfig import get_config_vars
args = ['-std=c++11']
try:
kaldi_root = os.environ['KALDI_ROOT']
except:
raise ValueError("Enviroment variable 'KALDI_ROOT' is not defined. Please "
"install kaldi and export KALDI_ROOT=<kaldi's root dir> .")
args = [
'-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable',
'-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable',
'-Wno-deprecated-declarations', '-Wno-unused-function'
]
# remove warning about -Wstrict-prototypes
(opt, ) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(flag for flag in opt.split()
if flag != '-Wstrict-prototypes')
os.environ['CC'] = 'g++'
LIBS = [
'fst', 'kaldi-base', 'kaldi-util', 'kaldi-matrix', 'kaldi-tree',
'kaldi-hmm', 'kaldi-fstext', 'kaldi-decoder', 'kaldi-lat'
]
LIB_DIRS = [
'tools/openfst/lib', 'src/base', 'src/matrix', 'src/util', 'src/tree',
'src/hmm', 'src/fstext', 'src/decoder', 'src/lat'
]
LIB_DIRS = [os.path.join(kaldi_root, path) for path in LIB_DIRS]
LIB_DIRS = [os.path.abspath(path) for path in LIB_DIRS]
ext_modules = [
Extension(
'decoder',
['pybind.cc', 'decoder.cc'],
include_dirs=['pybind11/include', '.'],
'post_decode_faster',
['pybind.cc', 'post_decode_faster.cc'],
include_dirs=[
'pybind11/include', '.', os.path.join(kaldi_root, 'src'),
os.path.join(kaldi_root, 'tools/openfst/src/include')
],
language='c++',
libraries=LIBS,
library_dirs=LIB_DIRS,
runtime_library_dirs=LIB_DIRS,
extra_compile_args=args, ),
]
setup(
name='decoder',
name='post_decode_faster',
version='0.0.1',
author='Paddle',
author_email='',
......
set -e
if [ ! -d pybind11 ]; then
git clone https://github.com/pybind/pybind11.git
......
......@@ -13,7 +13,7 @@ import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.async_data_reader as reader
import decoder.decoder as decoder
from decoder.post_decode_faster import Decoder
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model
from data_utils.util import split_infer_result
......@@ -32,6 +32,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
......@@ -47,6 +52,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
......@@ -81,6 +91,21 @@ def parse_args():
type=str,
default='./checkpoint',
help="The checkpoint path to init model. (default: %(default)s)")
parser.add_argument(
'--vocabulary',
type=str,
default='./decoder/graph/words.txt',
help="The path to vocabulary. (default: %(default)s)")
parser.add_argument(
'--graphs',
type=str,
default='./decoder/graph/TLG.fst',
help="The path to TLG graphs for decoding. (default: %(default)s)")
parser.add_argument(
'--log_prior',
type=str,
default="./decoder/logprior",
help="The log prior probs for training data. (default: %(default)s)")
args = parser.parse_args()
return args
......@@ -99,10 +124,11 @@ def infer_from_ckpt(args):
raise IOError("Invalid checkpoint!")
prediction, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)
infer_program = fluid.default_main_program().clone()
......@@ -154,8 +180,8 @@ def infer_from_ckpt(args):
probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch):
print("Decoding %d: " % (batch_id * args.batch_size + index),
decoder.decode(sample))
key = "utter#%d" % (batch_id * args.batch_size + index)
print(key, ": ", decoder.decode(key, sample), "\n")
print(np.mean(infer_costs), np.mean(infer_accs))
......
......@@ -6,7 +6,8 @@ import paddle.v2 as paddle
import paddle.fluid as fluid
def stacked_lstmp_model(hidden_dim,
def stacked_lstmp_model(frame_dim,
hidden_dim,
proj_dim,
stacked_num,
class_num,
......@@ -20,6 +21,7 @@ def stacked_lstmp_model(hidden_dim,
label data respectively. And in inference, only `feature` is needed.
Args:
frame_dim(int): The frame dimension of feature data.
hidden_dim(int): The hidden state's dimension of the LSTMP layer.
proj_dim(int): The projection size of the LSTMP layer.
stacked_num(int): The number of stacked LSTMP layers.
......@@ -78,7 +80,7 @@ def stacked_lstmp_model(hidden_dim,
# data feeder
feature = fluid.layers.data(
name="feature", shape=[-1, 120 * 11], dtype="float32", lod_level=1)
name="feature", shape=[-1, frame_dim], dtype="float32", lod_level=1)
label = fluid.layers.data(
name="label", shape=[-1, 1], dtype="int64", lod_level=1)
......@@ -92,11 +94,12 @@ def stacked_lstmp_model(hidden_dim,
feat_ = pd.read_input(feature)
label_ = pd.read_input(label)
prediction, avg_cost, acc = _net_conf(feat_, label_)
for out in [avg_cost, acc]:
for out in [prediction, avg_cost, acc]:
pd.write_output(out)
# get mean loss and acc through every devices.
avg_cost, acc = pd()
prediction, avg_cost, acc = pd()
prediction.stop_gradient = True
avg_cost = fluid.layers.mean(x=avg_cost)
acc = fluid.layers.mean(x=acc)
else:
......
......@@ -31,6 +31,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
......@@ -46,6 +51,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
......@@ -119,10 +129,11 @@ def profile(args):
"arg 'first_batches_to_skip' must not be smaller than 0.")
_, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
......
......@@ -30,6 +30,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
......@@ -45,6 +50,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
......@@ -137,10 +147,11 @@ def train(args):
os.mkdir(args.infer_models)
prediction, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)
# program for test
......
......@@ -3,18 +3,37 @@ class TrainTaskConfig(object):
# the epoch number to train.
pass_num = 2
# number of sequences contained in a mini-batch.
# the number of sequences contained in a mini-batch.
batch_size = 64
# the hyper params for Adam optimizer.
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# the params for learning rate scheduling
# the parameters for learning rate scheduling.
warmup_steps = 4000
# the directory for saving trained models.
model_dir = "trained_models"
class InferTaskConfig(object):
use_gpu = False
# the number of examples in one run for sequence generation.
# currently the batch size can only be set to 1.
batch_size = 1
# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses
......@@ -33,6 +52,11 @@ class ModelHyperParams(object):
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# position value corresponding to the <pad> token.
pos_pad_idx = 0
......@@ -64,14 +88,21 @@ pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# Names of all data layers listed in order.
input_data_names = (
# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
"trg_word",
"trg_pos",
"src_slf_attn_bias",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
# Names of label related data layers listed in order.
label_data_names = (
"lbl_word",
"lbl_weight", )
import numpy as np
import paddle.v2 as paddle
import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data
def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
return_pos=True,
return_attn_bias=True,
return_max_len=True)
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[]] * batch_size
next_ids = [[]] * batch_size
# Use beam_map to map the instance idx in batch to beam idx, since the
# size of feeded batch is changing.
beam_map = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
seqs.append(seq)
return seqs
def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-1], enc_in_data[-2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
enc_output = np.tile(enc_output, [beam_size, 1, 1])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-1, :])
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
beam_map = active_beams
if len(beam_map) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
return seqs, scores[:, :n_best].tolist()
def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# The current program desc is coupled with batch_size and the only
# supported batch size is 1 currently.
encoder_program = fluid.Program()
model.batch_size = InferTaskConfig.batch_size
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to set dropout to the test mode.
encoder_program = fluid.io.get_inference_program(
target_vars=[enc_output], main_program=encoder_program)
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)
test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
if __name__ == "__main__":
main()
......@@ -4,7 +4,8 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from config import TrainTaskConfig, input_data_names, pos_enc_param_names
from config import TrainTaskConfig, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
......@@ -127,7 +128,9 @@ def multi_head_attention(queries,
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
weights = __softmax(
layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product)
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
......@@ -280,8 +283,15 @@ def encoder(enc_input,
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout_rate)
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
enc_input = enc_output
return enc_output
......@@ -373,75 +383,142 @@ def decoder(dec_input,
return dec_output
def transformer(
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
def make_inputs(input_data_names,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
batch_size,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False):
"""
Define the input data layers for the transformer model.
"""
input_layers = []
# The shapes here act as placeholder.
# The shapes set here is to pass the infer-shape in compile time. The actual
# shape of src_word in run time is:
# [batch_size * max_src_length_in_a_batch, 1].
src_word = layers.data(
name=input_data_names[0],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of src_pos in runtime is:
# [batch_size * max_src_length_in_a_batch, 1].
src_pos = layers.data(
name=input_data_names[1],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of trg_word is in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
trg_word = layers.data(
name=input_data_names[2],
# The shapes set here is to pass the infer-shape in compile time.
word = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of trg_pos in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
trg_pos = layers.data(
name=input_data_names[3],
input_layers += [word]
# This is used for position data or label weight.
pos = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64",
dtype="int64" if is_pos else "float32",
append_batch_size=False)
# The actual shape of src_slf_attn_bias in runtime is:
# [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
# This input is used to remove attention weights on paddings.
src_slf_attn_bias = layers.data(
name=input_data_names[4],
input_layers += [pos]
if slf_attn_bias_flag:
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
slf_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
# The actual shape of trg_slf_attn_bias in runtime is:
# [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
# This is used to remove attention weights on paddings and subsequent words.
trg_slf_attn_bias = layers.data(
name=input_data_names[5],
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings.
src_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
# The actual shape of trg_src_attn_bias in runtime is:
# [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
# This is used to remove attention weights on paddings.
trg_src_attn_bias = layers.data(
name=input_data_names[6],
shape=[batch_size, n_head, max_length, max_length],
input_layers += [src_attn_bias]
if enc_output_flag:
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
dtype="float32",
append_batch_size=False)
input_layers += [enc_output]
return input_layers
def transformer(
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, False)
enc_output = wrap_encoder(
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers, )
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, True)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers,
enc_output, )
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
max_length, False, False, False)
cost = layers.cross_entropy(input=predict, label=gold)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost), predict
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers=None):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_input_layers is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_inputs(
encoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, False)
else:
src_word, src_pos, src_slf_attn_bias = enc_input_layers
enc_input = prepare_encoder(
src_word,
src_pos,
......@@ -460,6 +537,32 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate, )
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers=None,
enc_output=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_input_layers is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
decoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, True, True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
dec_input = prepare_decoder(
trg_word,
......@@ -482,32 +585,11 @@ def transformer(
d_inner_hid,
dropout_rate, )
# TODO(guosheng): Share the weight matrix between the embedding layers and
# the pre-softmax linear transformation.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
param_attr=fluid.initializer.Xavier(uniform=False),
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax")
# The actual shape of gold in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
gold = layers.data(
name=input_data_names[7],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
cost = layers.cross_entropy(input=predict, label=gold)
# The actual shape of weights in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
# Padding index do not contribute to the total loss. This Weight is used to
# cancel padding index in calculating the loss.
weights = layers.data(
name=input_data_names[8],
shape=[batch_size * max_length, 1],
dtype="float32",
append_batch_size=False)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost)
return predict
import os
import numpy as np
import paddle.v2 as paddle
......@@ -5,21 +6,13 @@ import paddle.fluid as fluid
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, \
pos_enc_param_names, input_data_names
from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head, place):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias. Then, convert the numpy
data to tensors and return a dict mapping names to tensors.
"""
input_dict = {}
def __pad_batch_data(insts,
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
return_pos=True,
return_attn_bias=True,
......@@ -35,8 +28,7 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos:
inst_pos = np.array([[
pos_i + 1 if w_i != pad_idx else 0
for pos_i, w_i in enumerate(inst)
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
] for inst in inst_data])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
......@@ -44,8 +36,7 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
max_len))
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
[-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
......@@ -63,28 +54,26 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
return_list += [max_len]
return return_list if len(return_list) > 1 else return_list[0]
def data_to_tensor(data_list, name_list, input_dict, place):
assert len(data_list) == len(name_list)
for i in range(len(name_list)):
tensor = fluid.LoDTensor()
tensor.set(data_list[i], place)
input_dict[name_list[i]] = tensor
src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, is_target=True)
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
False, False, False)
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
data_to_tensor([
src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
], input_data_names, input_dict, place)
]))
return input_dict
......@@ -92,7 +81,7 @@ def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
cost = transformer(
cost, predict = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
......@@ -118,6 +107,31 @@ def main():
buf_size=100000),
batch_size=TrainTaskConfig.batch_size)
# Program to do validation.
test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([cost])
val_data = paddle.batch(
paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)
def test(exe):
test_costs = []
for batch_id, data in enumerate(val_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
test_cost = exe.run(test_program,
feed=data_input,
fetch_list=[cost])[0]
test_costs.append(test_cost)
return np.mean(test_costs)
# Initialize the parameters.
exe.run(fluid.framework.default_startup_program())
for pos_enc_param_name in pos_enc_param_names:
......@@ -134,9 +148,10 @@ def main():
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
ModelHyperParams.n_head)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
......@@ -145,6 +160,14 @@ def main():
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" cost = " + str(cost_val))
# Validate and save the model for inference.
val_cost = test(exe)
print("pass_id = " + str(pass_id) + " val_cost = " + str(val_cost))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
encoder_input_data_names + decoder_input_data_names[:-1],
[predict], exe)
if __name__ == "__main__":
......
from PIL import Image
from PIL import Image, ImageEnhance
import numpy as np
import random
import math
......@@ -159,3 +159,77 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
sample_img = img[ymin:ymax, xmin:xmax]
sample_labels = transform_labels(bbox_labels, sample_bbox)
return sample_img, sample_labels
def random_brightness(img, settings):
prob = random.uniform(0, 1)
if prob < settings._brightness_prob:
delta = random.uniform(-settings._brightness_delta,
settings._brightness_delta) + 1
img = ImageEnhance.Brightness(img).enhance(delta)
return img
def random_contrast(img, settings):
prob = random.uniform(0, 1)
if prob < settings._contrast_prob:
delta = random.uniform(-settings._contrast_delta,
settings._contrast_delta) + 1
img = ImageEnhance.Contrast(img).enhance(delta)
return img
def random_saturation(img, settings):
prob = random.uniform(0, 1)
if prob < settings._saturation_prob:
delta = random.uniform(-settings._saturation_delta,
settings._saturation_delta) + 1
img = ImageEnhance.Color(img).enhance(delta)
return img
def random_hue(img, settings):
prob = random.uniform(0, 1)
if prob < settings._hue_prob:
delta = random.uniform(-settings._hue_delta, settings._hue_delta)
img_hsv = np.array(img.convert('HSV'))
img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta
img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
return img
def distort_image(img, settings):
prob = random.uniform(0, 1)
# Apply different distort order
if prob > 0.5:
img = random_brightness(img, settings)
img = random_contrast(img, settings)
img = random_saturation(img, settings)
img = random_hue(img, settings)
else:
img = random_brightness(img, settings)
img = random_saturation(img, settings)
img = random_hue(img, settings)
img = random_contrast(img, settings)
return img
def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1)
if prob < settings._hue_prob:
expand_ratio = random.uniform(1, settings._expand_max_ratio)
if expand_ratio - 1 >= 0.01:
height = int(img_height * expand_ratio)
width = int(img_width * expand_ratio)
h_off = math.floor(random.uniform(0, height - img_height))
w_off = math.floor(random.uniform(0, width - img_width))
expand_bbox = bbox(-w_off / img_width, -h_off / img_height,
(width - w_off) / img_width,
(height - h_off) / img_height)
expand_img = np.ones((height, width, 3))
expand_img = np.uint8(expand_img * np.squeeze(settings._img_mean))
expand_img = Image.fromarray(expand_img)
expand_img.paste(img, (int(w_off), int(h_off)))
bbox_labels = transform_labels(bbox_labels, expand_bbox)
return expand_img, bbox_labels
return img, bbox_labels
......@@ -22,17 +22,38 @@ import os
class Settings(object):
def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value):
def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value,
apply_distort, apply_expand):
self._data_dir = data_dir
self._label_list = []
label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath):
self._label_list.append(line.strip())
self._apply_distort = apply_distort
self._apply_expand = apply_expand
self._resize_height = resize_h
self._resize_width = resize_w
self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
'float32')
self._expand_prob = 0.5
self._expand_max_ratio = 4
self._hue_prob = 0.5
self._hue_delta = 18
self._contrast_prob = 0.5
self._contrast_delta = 0.5
self._saturation_prob = 0.5
self._saturation_delta = 0.5
self._brightness_prob = 0.5
self._brightness_delta = 0.125
@property
def apply_distort(self):
return self._apply_expand
@property
def apply_distort(self):
return self._apply_distort
@property
def data_dir(self):
......@@ -71,7 +92,6 @@ def _reader_creator(settings, file_list, mode, shuffle):
img = Image.open(img_path)
img_width, img_height = img.size
img = np.array(img)
# layout: label | xmin | ymin | xmax | ymax | difficult
if mode == 'train' or mode == 'test':
......@@ -99,6 +119,12 @@ def _reader_creator(settings, file_list, mode, shuffle):
sample_labels = bbox_labels
if mode == 'train':
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels = image_util.expand_image(
img, bbox_labels, img_width, img_height,
settings)
batch_sampler = []
# hard-code here
batch_sampler.append(
......@@ -126,6 +152,7 @@ def _reader_creator(settings, file_list, mode, shuffle):
sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.crop_image(
......
......@@ -75,13 +75,10 @@ def train(args,
evaluate_difficult=False,
ap_version='11point')
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.exponential_decay(
learning_rate=learning_rate,
decay_steps=40000,
decay_rate=0.1,
staircase=True),
momentum=0.9,
boundaries = [40000, 60000]
values = [0.001, 0.0005, 0.00025]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), )
optimizer.minimize(loss)
......@@ -90,7 +87,8 @@ def train(args,
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
load_model.load_paddlev1_vars(place)
load_model.load_and_set_vars(place)
#load_model.load_paddlev1_vars(place)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
test_reader = paddle.batch(
......@@ -113,6 +111,7 @@ def train(args,
loss_v = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
if batch_id % 20 == 0:
print("Pass {0}, batch {1}, loss {2}"
.format(pass_id, batch_id, loss_v[0]))
test(pass_id)
......@@ -130,6 +129,8 @@ if __name__ == '__main__':
data_args = reader.Settings(
data_dir='./data',
label_file='label_list',
apply_distort=True,
apply_expand=True,
resize_h=300,
resize_w=300,
mean_value=[127.5, 127.5, 127.5])
......
......@@ -26,7 +26,12 @@ def conv_bn_pool(input,
bias_attr=bias,
is_test=is_test)
tmp = fluid.layers.pool2d(
input=tmp, pool_size=2, pool_type='max', pool_stride=2, use_cudnn=True)
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
ceil_mode=True)
return tmp
......@@ -136,26 +141,61 @@ def encoder_net(images,
def ctc_train_net(images, label, args, num_classes):
regularizer = fluid.regularizer.L2Decay(args.l2)
gradient_clip = None
if args.parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
images_ = pd.read_input(images)
label_ = pd.read_input(label)
fc_out = encoder_net(
images,
images_,
num_classes,
regularizer=regularizer,
gradient_clip=gradient_clip)
cost = fluid.layers.warpctc(
input=fc_out, label=label, blank=num_classes, norm_by_times=True)
input=fc_out,
label=label_,
blank=num_classes,
norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum)
optimizer.minimize(sum_cost)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
pd.write_output(sum_cost)
pd.write_output(decoded_out)
sum_cost, decoded_out = pd()
sum_cost = fluid.layers.reduce_sum(sum_cost)
else:
fc_out = encoder_net(
images,
num_classes,
regularizer=regularizer,
gradient_clip=gradient_clip)
cost = fluid.layers.warpctc(
input=fc_out, label=label, blank=num_classes, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
return sum_cost, error_evaluator
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(error_evaluator)
optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum)
_, params_grads = optimizer.minimize(sum_cost)
return sum_cost, error_evaluator, inference_program
def ctc_infer(images, num_classes):
......
"""Trainer for OCR CTC model."""
import paddle.v2 as paddle
import paddle.fluid as fluid
import dummy_reader
import ctc_reader
......@@ -24,12 +23,12 @@ add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
add_arg('device', int, 0, "Device id.'-1' means running on CPU"
"while '0' means GPU-0.")
add_arg('parallel', bool, True, "Whether use parallel training.")
# yapf: disable
def load_parameter(place):
params = load_param('./name.map', './data/model/results_without_avg_window/pass-00000/')
for name in params:
# print "param: %s" % name
t = fluid.global_scope().find_var(name).get_tensor()
t.set(params[name], place)
......@@ -41,7 +40,8 @@ def train(args, data_reader=dummy_reader):
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
sum_cost, error_evaluator, inference_program = ctc_train_net(images, label, args, num_classes)
# data reader
train_reader = data_reader.train(args.batch_size)
test_reader = data_reader.test()
......@@ -51,11 +51,8 @@ def train(args, data_reader=dummy_reader):
place = fluid.CUDAPlace(args.device)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#load_parameter(place)
inference_program = fluid.io.get_inference_program(error_evaluator)
for pass_id in range(args.pass_num):
error_evaluator.reset(exe)
batch_id = 1
......@@ -78,7 +75,6 @@ def train(args, data_reader=dummy_reader):
sys.stdout.flush()
batch_id += 1
# evaluate model on test data
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
......
# 命名实体识别
以下是本例的简要目录结构及说明:
```text
.
├── data # 存储运行本例所依赖的数据,从外部获取
├── network_conf.py # 模型定义
├── reader.py # 数据读取接口, 从外部获取
├── README.md # 文档
├── train.py # 训练脚本
├── infer.py # 预测脚本
├── utils.py # 定义通用的函数, 从外部获取
└── utils_extend.py # 对utils.py的拓展
```
## 简介,模型详解
在PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md)中对于命名实体识别任务有较详细的介绍,在本例中不再重复介绍。
在模型上,我们沿用了v2版本的模型结构,唯一区别是我们使用LSTM代替原始的RNN。
## 数据获取
请参考PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) 一节中数据获取方式,将该例中的data文件夹拷贝至本例目录下,运行其中的download.sh脚本获取训练和测试数据。
## 通用脚本获取
请将PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md)中提供的用于数据读取的文件[reader.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/reader.py)以及包含字典导入等通用功能的文件[utils.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/utils.py)复制到本目录下。本例将会使用到这两个脚本。
## 训练
1. 运行 `sh data/download.sh`
2. 修改 `train.py``main` 函数,指定数据路径
```python
main(
train_data_file="data/train",
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
emb_file="data/wordVectors.txt",
model_save_dir="models",
num_passes=1000,
use_gpu=False,
parallel=False)
```
3. 运行命令 `python train.py`**需要注意:直接运行使用的是示例数据,请替换真实的标记数据。**
```text
Pass 127, Batch 9525, Cost 4.0867705, Precision 0.3954984, Recall 0.37846154, F1_score0.38679245
Pass 127, Batch 9530, Cost 3.137265, Precision 0.42971888, Recall 0.38351256, F1_score0.405303
Pass 127, Batch 9535, Cost 3.6240938, Precision 0.4272152, Recall 0.41795665, F1_score0.4225352
Pass 127, Batch 9540, Cost 3.5352352, Precision 0.48464164, Recall 0.4536741, F1_score0.46864685
Pass 127, Batch 9545, Cost 4.1130385, Precision 0.40131578, Recall 0.3836478, F1_score0.39228293
Pass 127, Batch 9550, Cost 3.6826708, Precision 0.43333334, Recall 0.43730888, F1_score0.43531203
Pass 127, Batch 9555, Cost 3.6363933, Precision 0.42424244, Recall 0.3962264, F1_score0.4097561
Pass 127, Batch 9560, Cost 3.6101768, Precision 0.51363635, Recall 0.353125, F1_score0.41851854
Pass 127, Batch 9565, Cost 3.5935276, Precision 0.5152439, Recall 0.5, F1_score0.5075075
Pass 127, Batch 9570, Cost 3.4987144, Precision 0.5, Recall 0.4330218, F1_score0.46410686
Pass 127, Batch 9575, Cost 3.4659843, Precision 0.39864865, Recall 0.38064516, F1_score0.38943896
Pass 127, Batch 9580, Cost 3.1702557, Precision 0.5, Recall 0.4490446, F1_score0.47315437
Pass 127, Batch 9585, Cost 3.1587276, Precision 0.49377593, Recall 0.4089347, F1_score0.4473684
Pass 127, Batch 9590, Cost 3.5043538, Precision 0.4556962, Recall 0.4600639, F1_score0.45786962
Pass 127, Batch 9595, Cost 2.981989, Precision 0.44981414, Recall 0.45149255, F1_score0.4506518
[TrainSet] pass_id:127 pass_precision:[0.46023396] pass_recall:[0.43197003] pass_f1_score:[0.44565433]
[TestSet] pass_id:127 pass_precision:[0.4708409] pass_recall:[0.47971722] pass_f1_score:[0.4752376]
```
## 预测
1. 修改 [infer.py](./infer.py)`infer` 函数,指定:需要测试的模型的路径、测试数据、字典文件,预测标记文件的路径,默认参数如下:
```python
infer(
model_path="models/params_pass_0",
batch_size=6,
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
use_gpu=False
)
```
2. 在终端运行 `python infer.py`,开始测试,会看到如下预测结果(以下为训练70个pass所得模型的部分预测结果):
```text
leicestershire B-ORG B-LOC
extended O O
their O O
first O O
innings O O
by O O
DGDG O O
runs O O
before O O
being O O
bowled O O
out O O
for O O
296 O O
with O O
england B-LOC B-LOC
discard O O
andy B-PER B-PER
caddick I-PER I-PER
taking O O
three O O
for O O
DGDG O O
. O O
```
输出分为三列,以“\t” 分隔,第一列是输入的词语,第二列是标准结果,第三列为生成的标记结果。多条输入序列之间以空行分隔。
## 结果示例
<p align="center">
<img src="imgs/convergence_curve.png" width="80%" align="center"/><br/>
图1. 学习曲线, 横轴表示训练轮数,纵轴表示F1值
</p>
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
from network_conf import ner_net
import reader
from utils import load_dict, load_reverse_dict
from utils_extend import to_lodtensor
def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
use_gpu):
"""
use the model under model_path to predict the test data, the result will be printed on the screen
return nothing
"""
word_dict = load_dict(vocab_file)
word_reverse_dict = load_reverse_dict(vocab_file)
label_dict = load_dict(target_file)
label_reverse_dict = load_reverse_dict(target_file)
test_data = paddle.batch(
reader.data_reader(test_data_file, word_dict, label_dict),
batch_size=batch_size)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
for data in test_data():
word = to_lodtensor(map(lambda x: x[0], data), place)
mark = to_lodtensor(map(lambda x: x[1], data), place)
target = to_lodtensor(map(lambda x: x[2], data), place)
crf_decode = exe.run(
inference_program,
feed={"word": word,
"mark": mark,
"target": target},
fetch_list=fetch_targets,
return_numpy=False)
lod_info = (crf_decode[0].lod())[0]
np_data = np.array(crf_decode[0])
assert len(data) == len(lod_info) - 1
for sen_index in xrange(len(data)):
assert len(data[sen_index][0]) == lod_info[
sen_index + 1] - lod_info[sen_index]
word_index = 0
for tag_index in xrange(lod_info[sen_index],
lod_info[sen_index + 1]):
word = word_reverse_dict[data[sen_index][0][word_index]]
gold_tag = label_reverse_dict[data[sen_index][2][
word_index]]
tag = label_reverse_dict[np_data[tag_index][0]]
print word + "\t" + gold_tag + "\t" + tag
word_index += 1
print ""
if __name__ == "__main__":
infer(
model_path="models/params_pass_0",
batch_size=6,
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
use_gpu=False)
import math
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer
from utils import logger, load_dict, get_embedding
def ner_net(word_dict_len, label_dict_len, parallel, stack_num=2):
mark_dict_len = 2
word_dim = 50
mark_dim = 5
hidden_dim = 300
IS_SPARSE = True
embedding_name = 'emb'
def _net_conf(word, mark, target):
word_embedding = fluid.layers.embedding(
input=word,
size=[word_dict_len, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
name=embedding_name, trainable=False))
mark_embedding = fluid.layers.embedding(
input=mark,
size=[mark_dict_len, mark_dim],
dtype='float32',
is_sparse=IS_SPARSE)
word_caps_vector = fluid.layers.concat(
input=[word_embedding, mark_embedding], axis=1)
mix_hidden_lr = 1
rnn_para_attr = fluid.ParamAttr(
initializer=NormalInitializer(
loc=0.0, scale=0.0),
learning_rate=mix_hidden_lr)
hidden_para_attr = fluid.ParamAttr(
initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
learning_rate=mix_hidden_lr)
hidden = fluid.layers.fc(
input=word_caps_vector,
name="__hidden00__",
size=hidden_dim,
act="tanh",
bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
param_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))
fea = []
for direction in ["fwd", "bwd"]:
for i in range(stack_num):
if i != 0:
hidden = fluid.layers.fc(
name="__hidden%02d_%s__" % (i, direction),
size=hidden_dim,
act="stanh",
bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=1.0)),
input=[hidden, rnn[0], rnn[1]],
param_attr=[
hidden_para_attr, rnn_para_attr, rnn_para_attr
])
rnn = fluid.layers.dynamic_lstm(
name="__rnn%02d_%s__" % (i, direction),
input=hidden,
size=hidden_dim,
candidate_activation='relu',
gate_activation='sigmoid',
cell_activation='sigmoid',
bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=1.0)),
is_reverse=(i % 2) if direction == "fwd" else not i % 2,
param_attr=rnn_para_attr)
fea += [hidden, rnn[0], rnn[1]]
rnn_fea = fluid.layers.fc(
size=hidden_dim,
bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
act="stanh",
input=fea,
param_attr=[hidden_para_attr, rnn_para_attr, rnn_para_attr] * 2)
emission = fluid.layers.fc(
size=label_dict_len,
input=rnn_fea,
param_attr=fluid.ParamAttr(initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))
crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=target,
param_attr=fluid.ParamAttr(
name='crfw',
initializer=NormalInitializer(
loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
learning_rate=mix_hidden_lr))
avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, emission
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
mark = fluid.layers.data(name='mark', shape=[1], dtype='int64', lod_level=1)
target = fluid.layers.data(
name="target", shape=[1], dtype='int64', lod_level=1)
if parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
word_ = pd.read_input(word)
mark_ = pd.read_input(mark)
target_ = pd.read_input(target)
avg_cost, emission_base = _net_conf(word_, mark_, target_)
pd.write_output(avg_cost)
pd.write_output(emission_base)
avg_cost_list, emission = pd()
avg_cost = fluid.layers.mean(x=avg_cost_list)
emission.stop_gradient = True
else:
avg_cost, emission = _net_conf(word, mark, target)
return avg_cost, emission, word, mark, target
import os
import math
import numpy as np
import paddle.v2 as paddle
import paddle.fluid as fluid
import reader
from network_conf import ner_net
from utils import logger, load_dict
from utils_extend import to_lodtensor, get_embedding
def test(exe, chunk_evaluator, inference_program, test_data, place):
chunk_evaluator.reset(exe)
for data in test_data():
word = to_lodtensor(map(lambda x: x[0], data), place)
mark = to_lodtensor(map(lambda x: x[1], data), place)
target = to_lodtensor(map(lambda x: x[2], data), place)
acc = exe.run(inference_program,
feed={"word": word,
"mark": mark,
"target": target})
return chunk_evaluator.eval(exe)
def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
model_save_dir, num_passes, use_gpu, parallel):
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
BATCH_SIZE = 200
word_dict = load_dict(vocab_file)
label_dict = load_dict(target_file)
word_vector_values = get_embedding(emb_file)
word_dict_len = len(word_dict)
label_dict_len = len(label_dict)
avg_cost, feature_out, word, mark, target = ner_net(
word_dict_len, label_dict_len, parallel)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
test_target = chunk_evaluator.metrics + chunk_evaluator.states
inference_program = fluid.io.get_inference_program(test_target)
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.data_reader(train_data_file, word_dict, label_dict),
buf_size=20000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.data_reader(test_data_file, word_dict, label_dict),
buf_size=20000),
batch_size=BATCH_SIZE)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[word, mark, target], place=place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
embedding_name = 'emb'
embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
embedding_param.set(word_vector_values, place)
batch_id = 0
for pass_id in xrange(num_passes):
chunk_evaluator.reset(exe)
for data in train_reader():
cost, batch_precision, batch_recall, batch_f1_score = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics)
if batch_id % 5 == 0:
print("Pass " + str(pass_id) + ", Batch " + str(
batch_id) + ", Cost " + str(cost[0]) + ", Precision " + str(
batch_precision[0]) + ", Recall " + str(batch_recall[0])
+ ", F1_score" + str(batch_f1_score[0]))
batch_id = batch_id + 1
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe)
print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall) +
" pass_f1_score:" + str(pass_f1_score))
pass_precision, pass_recall, pass_f1_score = test(
exe, chunk_evaluator, inference_program, test_reader, place)
print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall) +
" pass_f1_score:" + str(pass_f1_score))
save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id)
fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
[crf_decode], exe)
if __name__ == "__main__":
main(
train_data_file="data/train",
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
emb_file="data/wordVectors.txt",
model_save_dir="models",
num_passes=1000,
use_gpu=False,
parallel=False)
import numpy as np
import paddle.fluid as fluid
def get_embedding(emb_file='data/wordVectors.txt'):
"""
Get the trained word vector.
"""
return np.loadtxt(emb_file, dtype='float32')
def to_lodtensor(data, place):
"""
convert data to lodtensor
"""
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
此目录中代码示例PaddlePaddle所需版本至少为v0.11.0。如果您使用的PaddlePaddle版本早于v0.11.0, [请更新](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# 全球标准化阅读器
该模型实现以下功能:
Jonathan Raiman and John Miller. Globally Normalized Reader. Empirical Methods in Natural Language Processing (EMNLP), 2017
如果您在研究中使用数据集/代码,请引用上述论文:
```text
@inproceedings{raiman2015gnr,
author={Raiman, Jonathan and Miller, John},
booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
title={Globally Normalized Reader},
year={2017},
}
```
您也可以访问 https://github.com/baidu-research/GloballyNormalizedReader 以获取更多信息。
# 安装
1. 请使用 [docker image](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) 安装最新的PaddlePaddle,运行方法:
```bash
docker pull paddledev/paddle
```
2. 下载所有必要的数据,运行方法:
```bash
cd data && ./download.sh && cd ..
```
3. 预处理并特征化数据:
```bash
python featurize.py --datadir data --outdir data/featurized --glove-path data/glove.840B.300d.txt
```
# 模型训练
- 根据需要修改config.py来配置模型,然后运行:
```bash
python train.py 2>&1 | tee train.log
```
# 使用训练过的模型推断
- 运行以下训练模型来推断:
```bash
python infer.py \
--model_path models/pass_00000.tar.gz \
--data_dir data/featurized/ \
--batch_size 2 \
--use_gpu 0 \
--trainer_count 1 \
2>&1 | tee infer.log
```
运行本目录下的程序示例需要使用PaddlePaddle v0.10.0 版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
运行本目录下的程序示例需要使用PaddlePaddle v0.10.0 版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
---
......@@ -25,16 +25,16 @@
命名实体识别(Named Entity Recognition,NER)又称作“专名识别”,是指识别文本中具有特定意义的实体,主要包括人名、地名、机构名、专有名词等,是自然语言处理研究的一个基础问题。NER任务通常包括实体边界识别、确定实体类别两部分,可以将其作为序列标注问题解决。
序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://book.paddlepaddle.org/07.label_semantic_roles/)定义的标签集,如下是一个NER的标注结果示例:
序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html)定义的标签集,如下是一个NER的标注结果示例:
<p align="center">
<img src="images/ner_label_ins.png" width="80%" align="center"/><br/>
图1. BIO标注方法示例
</p>
根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。
根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。
由于序列标注问题的广泛性,产生了[CRF](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等经典的序列模型,这些模型大多只能使用局部信息或需要人工设计特征。随着深度学习研究的发展,循环神经网络(Recurrent Neural Network,RNN等 序列模型能够处理序列元素之间前后关联问题,能够从原始输入文本中学习特征表示,而更加适合序列标注任务,更多相关知识可参考PaddleBook中[语义角色标注](https://github.com/PaddlePaddle/book/blob/develop/07.label_semantic_roles/README.cn.md)一课。
由于序列标注问题的广泛性,产生了[CRF](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html)等经典的序列模型,这些模型大多只能使用局部信息或需要人工设计特征。随着深度学习研究的发展,循环神经网络(Recurrent Neural Network,RNN等 序列模型能够处理序列元素之间前后关联问题,能够从原始输入文本中学习特征表示,而更加适合序列标注任务,更多相关知识可参考PaddleBook中[语义角色标注](https://github.com/PaddlePaddle/book/blob/develop/07.label_semantic_roles/README.cn.md)一课。
## 模型详解
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册