diff --git a/fluid/DeepASR/data_utils/async_data_reader.py b/fluid/DeepASR/data_utils/async_data_reader.py
index 1515b299d4357eac16892dceda6e4f05bf1fc045..03448fadccfbcfb67ab28cdf2071fc4b743ef6e5 100644
--- a/fluid/DeepASR/data_utils/async_data_reader.py
+++ b/fluid/DeepASR/data_utils/async_data_reader.py
@@ -218,8 +218,6 @@ class AsyncDataReader(object):
         self._sample_proc_num = self._proc_num - 2
         self._verbose = verbose
         self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
-        self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
-                                                     3, self._manager)

     def generate_bucket_list(self, is_shuffle):
         if self._block_info_list is None:
@@ -424,6 +422,9 @@ class AsyncDataReader(object):
         sample_queue = self._start_async_processing()
         batch_queue = self._manager.Queue(self._batch_buffer_size)

+        self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
+                                                     3, self._manager)
+
         assembling_proc = DaemonProcessGroup(
             proc_num=1,
             target=batch_assembling_task,
@@ -439,3 +440,6 @@ class AsyncDataReader(object):
             if isinstance(batch_data, EpochEndSignal):
                 break
             yield batch_data
+
+        # clean the shared memory
+        del self._pool_manager
diff --git a/fluid/DeepASR/data_utils/util.py b/fluid/DeepASR/data_utils/util.py
index e8ccbadc0bf2106ccabd73a449eb5e53983ccf95..5d519c0ac30cc63c967f25503ca9dff1def59a8e 100644
--- a/fluid/DeepASR/data_utils/util.py
+++ b/fluid/DeepASR/data_utils/util.py
@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import sys
+import sys, time
 from six import reraise
 from tblib import Traceback
 from multiprocessing import Manager, Process
@@ -161,9 +161,10 @@ class SharedMemoryPoolManager(object):
     def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
         self._names = []
         self._dict = manager.dict()
+        self._time_prefix = time.strftime('%Y%m%d%H%M%S')

         for i in xrange(pool_size):
-            name = name_prefix + '_' + str(i)
+            name = name_prefix + '_' + self._time_prefix + '_' + str(i)
             self._dict[name] = SharedNDArray(name)
             self._names.append(name)
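The timestamped names above guard against collisions with shared-memory blocks left behind by an earlier run that exited without cleanup. A minimal standalone sketch of the naming scheme (the helper function below is hypothetical, not part of the patch):

```python
import time

def make_pool_names(pool_size, name_prefix='/deep_asr'):
    # One prefix per run, e.g. /deep_asr_20180301120000_0, _1, ...
    time_prefix = time.strftime('%Y%m%d%H%M%S')
    return ['%s_%s_%d' % (name_prefix, time_prefix, i)
            for i in range(pool_size)]

print(make_pool_names(3))
```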
diff --git a/fluid/DeepASR/decoder/decoder.cc b/fluid/DeepASR/decoder/decoder.cc
deleted file mode 100644
index a99f972e2fc2341247eb2a6aa564d8d6b5e2905d..0000000000000000000000000000000000000000
--- a/fluid/DeepASR/decoder/decoder.cc
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "decoder.h"
-
-std::string decode(std::vector<std::vector<float>> probs_mat) {
-  // Add decoding logic here
-
-  return "example decoding result";
-}
diff --git a/fluid/DeepASR/decoder/decoder.h b/fluid/DeepASR/decoder/decoder.h
deleted file mode 100644
index 4a67fa366ae31dd393d0e3b2f04d27e360596fe5..0000000000000000000000000000000000000000
--- a/fluid/DeepASR/decoder/decoder.h
+++ /dev/null
@@ -1,18 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <string>
-#include <vector>
-
-std::string decode(std::vector<std::vector<float>> probs_mat);
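The deleted stub is superseded by a Kaldi-backed `Decoder` class, added below in post_decode_faster.cc. A hedged sketch of how the resulting pybind11 module is meant to be driven from Python — the constructor signature and the two `decode` overloads come from this patch, while the file paths are the defaults that infer_by_ckpt.py wires in later:

```python
from post_decode_faster import Decoder  # module built by decoder/setup.py

decoder = Decoder("decoder/graph/words.txt",  # word symbol table
                  "decoder/graph/TLG.fst",    # decoding graph
                  "decoder/logprior")         # log-priors of training labels

# Batch interface: posterior matrices read through a Kaldi rspecifier.
texts = decoder.decode("ark:posteriors.ark")

# Per-utterance interface: one [num_frames x num_classes] score matrix.
probs = [[1.0 / 1749] * 1749] * 10  # placeholder posteriors
text = decoder.decode("utter#0", probs)
```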
diff --git a/fluid/DeepASR/decoder/post_decode_faster.cc b/fluid/DeepASR/decoder/post_decode_faster.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d7f1d1ab34a18285d1d96b9ff6a67cff42d519b3
--- /dev/null
+++ b/fluid/DeepASR/decoder/post_decode_faster.cc
@@ -0,0 +1,144 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "post_decode_faster.h"
+
+typedef kaldi::int32 int32;
+using fst::SymbolTable;
+using fst::VectorFst;
+using fst::StdArc;
+
+Decoder::Decoder(std::string word_syms_filename,
+                 std::string fst_in_filename,
+                 std::string logprior_rxfilename) {
+  const char* usage =
+      "Decode, reading log-likelihoods (of transition-ids or whatever symbol "
+      "is on the graph) as matrices.";
+
+  kaldi::ParseOptions po(usage);
+  binary = true;
+  acoustic_scale = 1.5;
+  allow_partial = true;
+  kaldi::FasterDecoderOptions decoder_opts;
+  decoder_opts.Register(&po, true);  // true == include obscure settings.
+  po.Register("binary", &binary, "Write output in binary mode");
+  po.Register("allow-partial",
+              &allow_partial,
+              "Produce output even when final state was not reached");
+  po.Register("acoustic-scale",
+              &acoustic_scale,
+              "Scaling factor for acoustic likelihoods");
+
+  word_syms = NULL;
+  if (word_syms_filename != "") {
+    word_syms = fst::SymbolTable::ReadText(word_syms_filename);
+    if (!word_syms)
+      KALDI_ERR << "Could not read symbol table from file "
+                << word_syms_filename;
+  }
+
+  std::ifstream is_logprior(logprior_rxfilename);
+  logprior.Read(is_logprior, false);
+
+  // It's important that we initialize decode_fst after loglikes_reader, as it
+  // can prevent crashes on systems installed without enough virtual memory.
+  // It has to do with what happens on UNIX systems if you call fork() on a
+  // large process: the page-table entries are duplicated, which requires a
+  // lot of virtual memory.
+  decode_fst = fst::ReadFstKaldi(fst_in_filename);
+
+  decoder = new kaldi::FasterDecoder(*decode_fst, decoder_opts);
+}
+
+
+Decoder::~Decoder() {
+  if (word_syms != NULL) delete word_syms;
+  delete decode_fst;
+  delete decoder;
+}
+
+std::string Decoder::decode(
+    std::string key,
+    const std::vector<std::vector<kaldi::BaseFloat>>& log_probs) {
+  size_t num_frames = log_probs.size();
+  size_t dim_label = log_probs[0].size();
+
+  kaldi::Matrix<kaldi::BaseFloat> loglikes(
+      num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
+  for (size_t i = 0; i < num_frames; ++i) {
+    memcpy(loglikes.Data() + i * dim_label,
+           log_probs[i].data(),
+           sizeof(kaldi::BaseFloat) * dim_label);
+  }
+
+  return decode(key, loglikes);
+}
+
+
+std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
+  kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier);
+  std::vector<std::string> decoding_results;
+
+  for (; !posterior_reader.Done(); posterior_reader.Next()) {
+    std::string key = posterior_reader.Key();
+    kaldi::Matrix<kaldi::BaseFloat> loglikes(posterior_reader.Value());
+
+    decoding_results.push_back(decode(key, loglikes));
+  }
+
+  return decoding_results;
+}
+
+
+std::string Decoder::decode(std::string key,
+                            kaldi::Matrix<kaldi::BaseFloat>& loglikes) {
+  std::string decoding_result;
+
+  if (loglikes.NumRows() == 0) {
+    KALDI_WARN << "Zero-length utterance: " << key;
+  }
+  KALDI_ASSERT(loglikes.NumCols() == logprior.Dim());
+
+  loglikes.ApplyLog();
+  loglikes.AddVecToRows(-1.0, logprior);
+
+  kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale);
+  decoder->Decode(&decodable);
+
+  VectorFst<kaldi::LatticeArc> decoded;  // linear FST.
+
+  if ((allow_partial || decoder->ReachedFinal()) &&
+      decoder->GetBestPath(&decoded)) {
+    if (!decoder->ReachedFinal())
+      KALDI_WARN << "Decoder did not reach end-state, outputting partial "
+                    "traceback.";
+
+    std::vector<int32> alignment;
+    std::vector<int32> words;
+    kaldi::LatticeWeight weight;
+
+    GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
+
+    if (word_syms != NULL) {
+      for (size_t i = 0; i < words.size(); i++) {
+        std::string s = word_syms->Find(words[i]);
+        decoding_result += s;
+        if (s == "")
+          KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
+      }
+    }
+  }
+
+  return decoding_result;
+}
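Before decoding, the matrix is converted from posteriors to pseudo log-likelihoods: `ApplyLog()` followed by `AddVecToRows(-1.0, logprior)` divides by the class priors in the log domain, while `acoustic_scale` is applied later inside `DecodableMatrixScaled`. The same transform in numpy terms (an illustrative sketch, not code from the patch):

```python
import numpy as np

def posteriors_to_pseudo_loglikes(probs, log_prior):
    # probs: [num_frames, num_classes] softmax outputs; log_prior: [num_classes].
    # log p(x|s) is taken proportional to log p(s|x) - log p(s).
    return np.log(probs) - log_prior

probs = np.full((2, 4), 0.25)
print(posteriors_to_pseudo_loglikes(probs, np.log(np.full(4, 0.25))))
```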
diff --git a/fluid/DeepASR/decoder/post_decode_faster.h b/fluid/DeepASR/decoder/post_decode_faster.h
new file mode 100644
index 0000000000000000000000000000000000000000..2e31a1c19e40bd879a1c76f1542b94eaa853be12
--- /dev/null
+++ b/fluid/DeepASR/decoder/post_decode_faster.h
@@ -0,0 +1,57 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+#include "base/kaldi-common.h"
+#include "base/timer.h"
+#include "decoder/decodable-matrix.h"
+#include "decoder/faster-decoder.h"
+#include "fstext/fstext-lib.h"
+#include "hmm/transition-model.h"
+#include "lat/kaldi-lattice.h"  // for {Compact}LatticeArc
+#include "tree/context-dep.h"
+#include "util/common-utils.h"
+
+
+class Decoder {
+public:
+  Decoder(std::string word_syms_filename,
+          std::string fst_in_filename,
+          std::string logprior_rxfilename);
+  ~Decoder();
+
+  // Interface to accept the scores read from specifier and return
+  // the batch decoding results
+  std::vector<std::string> decode(std::string posterior_rspecifier);
+
+  // Accept the scores of one utterance and return the decoding result
+  std::string decode(
+      std::string key,
+      const std::vector<std::vector<kaldi::BaseFloat>> &log_probs);
+
+private:
+  // For decoding one utterance
+  std::string decode(std::string key,
+                     kaldi::Matrix<kaldi::BaseFloat> &loglikes);
+
+  fst::SymbolTable *word_syms;
+  fst::VectorFst<fst::StdArc> *decode_fst;
+  kaldi::FasterDecoder *decoder;
+  kaldi::Vector<kaldi::BaseFloat> logprior;
+
+  bool binary;
+  kaldi::BaseFloat acoustic_scale;
+  bool allow_partial;
+};
diff --git a/fluid/DeepASR/decoder/pybind.cc b/fluid/DeepASR/decoder/pybind.cc
index 8cd65903eae63268d9f4412bd737162c639d8910..56439d180263b4d753eccd82826d1b39c9d2fa85 100644
--- a/fluid/DeepASR/decoder/pybind.cc
+++ b/fluid/DeepASR/decoder/pybind.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,15 +15,25 @@ limitations under the License. */

 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include "decoder.h"
+#include "post_decode_faster.h"

 namespace py = pybind11;

-PYBIND11_MODULE(decoder, m) {
-  m.doc() = "Decode function for Deep ASR model";
-
-  m.def("decode",
-        &decode,
-        "Decode one input probability matrix "
-        "and return the transcription");
+PYBIND11_MODULE(post_decode_faster, m) {
+  m.doc() = "Decoder for Deep ASR model";
+
+  py::class_<Decoder>(m, "Decoder")
+      .def(py::init<std::string, std::string, std::string>())
+      .def("decode",
+           (std::vector<std::string> (Decoder::*)(std::string)) &
+               Decoder::decode,
+           "Decode for the probability matrices in specifier "
+           "and return the transcriptions.")
+      .def(
+          "decode",
+          (std::string (Decoder::*)(
+              std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) &
+              Decoder::decode,
+          "Decode one input probability matrix "
+          "and return the transcription.");
 }
diff --git a/fluid/DeepASR/decoder/setup.py b/fluid/DeepASR/decoder/setup.py
index cedd5d644e0dc1ca8855ab7e75ee1b7a30f8fcb1..a98c0b4cc17717a6769b8322e4f5afe3de6ab2de 100644
--- a/fluid/DeepASR/decoder/setup.py
+++ b/fluid/DeepASR/decoder/setup.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,27 +13,57 @@
 # limitations under the License.

 import os
+import glob
 from distutils.core import setup, Extension
 from distutils.sysconfig import get_config_vars

-args = ['-std=c++11']
+try:
+    kaldi_root = os.environ['KALDI_ROOT']
+except:
+    raise ValueError("Environment variable 'KALDI_ROOT' is not defined. Please "
+                     "install kaldi and export KALDI_ROOT=<kaldi root> .")
+
+args = [
+    '-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable',
+    '-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable',
+    '-Wno-deprecated-declarations', '-Wno-unused-function'
+]

 # remove warning about -Wstrict-prototypes
 (opt, ) = get_config_vars('OPT')
 os.environ['OPT'] = " ".join(flag for flag in opt.split()
                              if flag != '-Wstrict-prototypes')
+os.environ['CC'] = 'g++'
+
+LIBS = [
+    'fst', 'kaldi-base', 'kaldi-util', 'kaldi-matrix', 'kaldi-tree',
+    'kaldi-hmm', 'kaldi-fstext', 'kaldi-decoder', 'kaldi-lat'
+]
+
+LIB_DIRS = [
+    'tools/openfst/lib', 'src/base', 'src/matrix', 'src/util', 'src/tree',
+    'src/hmm', 'src/fstext', 'src/decoder', 'src/lat'
+]
+LIB_DIRS = [os.path.join(kaldi_root, path) for path in LIB_DIRS]
+LIB_DIRS = [os.path.abspath(path) for path in LIB_DIRS]

 ext_modules = [
     Extension(
-        'decoder',
-        ['pybind.cc', 'decoder.cc'],
-        include_dirs=['pybind11/include', '.'],
+        'post_decode_faster',
+        ['pybind.cc', 'post_decode_faster.cc'],
+        include_dirs=[
+            'pybind11/include', '.', os.path.join(kaldi_root, 'src'),
+            os.path.join(kaldi_root, 'tools/openfst/src/include')
+        ],
         language='c++',
+        libraries=LIBS,
+        library_dirs=LIB_DIRS,
+        runtime_library_dirs=LIB_DIRS,
         extra_compile_args=args, ),
 ]

 setup(
-    name='decoder',
+    name='post_decode_faster',
     version='0.0.1',
     author='Paddle',
     author_email='',
diff --git a/fluid/DeepASR/decoder/setup.sh b/fluid/DeepASR/decoder/setup.sh
index 71fd6626efe1b7cf72a1e678ab7b74000ebfb8c3..1471f85f414ae8dd5230f04cf08da282adc3b0b7 100644
--- a/fluid/DeepASR/decoder/setup.sh
+++ b/fluid/DeepASR/decoder/setup.sh
@@ -1,4 +1,4 @@
-
+set -e

 if [ ! -d pybind11 ]; then
     git clone https://github.com/pybind/pybind11.git
diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py
index 68dd573647d498704fd22f70a7df2255e7ac66cd..f267f674986a87d552bb1a2a277c21c27cca148a 100644
--- a/fluid/DeepASR/infer_by_ckpt.py
+++ b/fluid/DeepASR/infer_by_ckpt.py
@@ -13,7 +13,7 @@ import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
 import data_utils.augmentor.trans_add_delta as trans_add_delta
 import data_utils.augmentor.trans_splice as trans_splice
 import data_utils.async_data_reader as reader
-import decoder.decoder as decoder
+from decoder.post_decode_faster import Decoder
 from data_utils.util import lodtensor_to_ndarray
 from model_utils.model import stacked_lstmp_model
 from data_utils.util import split_infer_result
@@ -32,6 +32,11 @@ def parse_args():
         default=1,
         help='The minimum sequence number of a batch data. '
         '(default: %(default)d)')
+    parser.add_argument(
+        '--frame_dim',
+        type=int,
+        default=120 * 11,
+        help='Frame dimension of feature data. (default: %(default)d)')
     parser.add_argument(
         '--stacked_num',
         type=int,
@@ -47,6 +52,11 @@ def parse_args():
         type=int,
         default=1024,
         help='Hidden size of lstmp unit. (default: %(default)d)')
+    parser.add_argument(
+        '--class_num',
+        type=int,
+        default=1749,
+        help='Number of classes in label. (default: %(default)d)')
     parser.add_argument(
         '--learning_rate',
         type=float,
@@ -81,6 +91,21 @@ def parse_args():
         type=str,
         default='./checkpoint',
         help="The checkpoint path to init model. (default: %(default)s)")
+    parser.add_argument(
+        '--vocabulary',
+        type=str,
+        default='./decoder/graph/words.txt',
+        help="The path to vocabulary. (default: %(default)s)")
+    parser.add_argument(
+        '--graphs',
+        type=str,
+        default='./decoder/graph/TLG.fst',
+        help="The path to TLG graphs for decoding. 
(default: %(default)s)") + parser.add_argument( + '--log_prior', + type=str, + default="./decoder/logprior", + help="The log prior probs for training data. (default: %(default)s)") args = parser.parse_args() return args @@ -99,10 +124,11 @@ def infer_from_ckpt(args): raise IOError("Invalid checkpoint!") prediction, avg_cost, accuracy = stacked_lstmp_model( + frame_dim=args.frame_dim, hidden_dim=args.hidden_dim, proj_dim=args.proj_dim, stacked_num=args.stacked_num, - class_num=1749, + class_num=args.class_num, parallel=args.parallel) infer_program = fluid.default_main_program().clone() @@ -154,8 +180,8 @@ def infer_from_ckpt(args): probs, lod = lodtensor_to_ndarray(results[0]) infer_batch = split_infer_result(probs, lod) for index, sample in enumerate(infer_batch): - print("Decoding %d: " % (batch_id * args.batch_size + index), - decoder.decode(sample)) + key = "utter#%d" % (batch_id * args.batch_size + index) + print(key, ": ", decoder.decode(key, sample), "\n") print(np.mean(infer_costs), np.mean(infer_accs)) diff --git a/fluid/DeepASR/model_utils/model.py b/fluid/DeepASR/model_utils/model.py index 541f869c7224e620c519c97472dbe79ca73bd84b..8fb7596e122447979cf392d6610ad2b7281d195b 100644 --- a/fluid/DeepASR/model_utils/model.py +++ b/fluid/DeepASR/model_utils/model.py @@ -6,7 +6,8 @@ import paddle.v2 as paddle import paddle.fluid as fluid -def stacked_lstmp_model(hidden_dim, +def stacked_lstmp_model(frame_dim, + hidden_dim, proj_dim, stacked_num, class_num, @@ -20,12 +21,13 @@ def stacked_lstmp_model(hidden_dim, label data respectively. And in inference, only `feature` is needed. Args: - hidden_dim(int): The hidden state's dimension of the LSTMP layer. - proj_dim(int): The projection size of the LSTMP layer. - stacked_num(int): The number of stacked LSTMP layers. - parallel(bool): Run in parallel or not, default `False`. - is_train(bool): Run in training phase or not, default `True`. - class_dim(int): The number of output classes. + frame_dim(int): The frame dimension of feature data. + hidden_dim(int): The hidden state's dimension of the LSTMP layer. + proj_dim(int): The projection size of the LSTMP layer. + stacked_num(int): The number of stacked LSTMP layers. + parallel(bool): Run in parallel or not, default `False`. + is_train(bool): Run in training phase or not, default `True`. + class_dim(int): The number of output classes. """ # network configuration @@ -78,7 +80,7 @@ def stacked_lstmp_model(hidden_dim, # data feeder feature = fluid.layers.data( - name="feature", shape=[-1, 120 * 11], dtype="float32", lod_level=1) + name="feature", shape=[-1, frame_dim], dtype="float32", lod_level=1) label = fluid.layers.data( name="label", shape=[-1, 1], dtype="int64", lod_level=1) @@ -92,11 +94,12 @@ def stacked_lstmp_model(hidden_dim, feat_ = pd.read_input(feature) label_ = pd.read_input(label) prediction, avg_cost, acc = _net_conf(feat_, label_) - for out in [avg_cost, acc]: + for out in [prediction, avg_cost, acc]: pd.write_output(out) # get mean loss and acc through every devices. - avg_cost, acc = pd() + prediction, avg_cost, acc = pd() + prediction.stop_gradient = True avg_cost = fluid.layers.mean(x=avg_cost) acc = fluid.layers.mean(x=acc) else: diff --git a/fluid/DeepASR/tools/profile.py b/fluid/DeepASR/tools/profile.py index 77dff3cb371659ce672e10735174b846827c9d6b..cf7329445393a3e767f35cd23939dc6777e06633 100644 --- a/fluid/DeepASR/tools/profile.py +++ b/fluid/DeepASR/tools/profile.py @@ -31,6 +31,11 @@ def parse_args(): default=1, help='The minimum sequence number of a batch data. 
' '(default: %(default)d)')
+    parser.add_argument(
+        '--frame_dim',
+        type=int,
+        default=120 * 11,
+        help='Frame dimension of feature data. (default: %(default)d)')
     parser.add_argument(
         '--stacked_num',
         type=int,
@@ -46,6 +51,11 @@ def parse_args():
         type=int,
         default=1024,
         help='Hidden size of lstmp unit. (default: %(default)d)')
+    parser.add_argument(
+        '--class_num',
+        type=int,
+        default=1749,
+        help='Number of classes in label. (default: %(default)d)')
     parser.add_argument(
         '--learning_rate',
         type=float,
@@ -119,10 +129,11 @@ def profile(args):
             "arg 'first_batches_to_skip' must not be smaller than 0.")

     _, avg_cost, accuracy = stacked_lstmp_model(
+        frame_dim=args.frame_dim,
         hidden_dim=args.hidden_dim,
         proj_dim=args.proj_dim,
         stacked_num=args.stacked_num,
-        class_num=1749,
+        class_num=args.class_num,
         parallel=args.parallel)

     optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py
index 8297ab7403775125b6da372c4d916e5e3111b669..446e9e0ab16b1d1ee98738ca8cc1510e0e96636e 100644
--- a/fluid/DeepASR/train.py
+++ b/fluid/DeepASR/train.py
@@ -30,6 +30,11 @@ def parse_args():
         default=1,
         help='The minimum sequence number of a batch data. '
         '(default: %(default)d)')
+    parser.add_argument(
+        '--frame_dim',
+        type=int,
+        default=120 * 11,
+        help='Frame dimension of feature data. (default: %(default)d)')
     parser.add_argument(
         '--stacked_num',
         type=int,
@@ -45,6 +50,11 @@ def parse_args():
         type=int,
         default=1024,
         help='Hidden size of lstmp unit. (default: %(default)d)')
+    parser.add_argument(
+        '--class_num',
+        type=int,
+        default=1749,
+        help='Number of classes in label. (default: %(default)d)')
     parser.add_argument(
         '--pass_num',
         type=int,
@@ -137,10 +147,11 @@ def train(args):
         os.mkdir(args.infer_models)

     prediction, avg_cost, accuracy = stacked_lstmp_model(
+        frame_dim=args.frame_dim,
         hidden_dim=args.hidden_dim,
         proj_dim=args.proj_dim,
         stacked_num=args.stacked_num,
-        class_num=1749,
+        class_num=args.class_num,
         parallel=args.parallel)

     # program for test
diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 091ea175291c56d63e1d8b42a874516d9733f1cf..71e4314953383b8f89b40fdfd8cc4274f954fed1 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -3,18 +3,37 @@ class TrainTaskConfig(object):
     # the epoch number to train.
     pass_num = 2

-    # number of sequences contained in a mini-batch.
+    # the number of sequences contained in a mini-batch.
     batch_size = 64

-    # the hyper params for Adam optimizer.
+    # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9

-    # the params for learning rate scheduling
+    # the parameters for learning rate scheduling.
     warmup_steps = 4000

+    # the directory for saving trained models.
+    model_dir = "trained_models"
+
+
+class InferTaskConfig(object):
+    use_gpu = False
+    # the number of examples in one run for sequence generation.
+    # currently the batch size can only be set to 1.
+    batch_size = 1
+
+    # the parameters for beam search.
+    beam_size = 5
+    max_length = 30
+    # the number of decoded sentences to output.
+    n_best = 1
+
+    # the directory for loading the trained model.
+    model_path = "trained_models/pass_1.infer.model"
+
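`warmup_steps` above feeds the schedule from the Transformer paper, lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5): linear warmup followed by inverse-square-root decay. A standalone sketch, assuming the `LearningRateScheduler` in optim.py follows the paper (d_model = 512 matches ModelHyperParams):

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Grows linearly for warmup_steps steps, then decays as step**-0.5.
    return d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

print(transformer_lr(1), transformer_lr(4000), transformer_lr(100000))
```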
 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses
@@ -33,6 +52,11 @@ class ModelHyperParams(object):
     # index for <pad> token in target language.
     trg_pad_idx = trg_vocab_size

+    # index for <bos> token
+    bos_idx = 0
+    # index for <eos> token
+    eos_idx = 1
+
     # position value corresponding to the <pad> token.
     pos_pad_idx = 0

@@ -64,14 +88,21 @@ pos_enc_param_names = (
     "src_pos_enc_table",
     "trg_pos_enc_table", )

-# Names of all data layers listed in order.
-input_data_names = (
+# Names of all data layers in encoder listed in order.
+encoder_input_data_names = (
     "src_word",
     "src_pos",
+    "src_slf_attn_bias", )
+
+# Names of all data layers in decoder listed in order.
+decoder_input_data_names = (
     "trg_word",
     "trg_pos",
-    "src_slf_attn_bias",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "enc_output", )
+
+# Names of label related data layers listed in order.
+label_data_names = (
     "lbl_word",
     "lbl_weight", )
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4dee220cedf856633ee626b762804e49a10cfe8
--- /dev/null
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -0,0 +1,234 @@
+import numpy as np
+
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+
+import model
+from model import wrap_encoder as encoder
+from model import wrap_decoder as decoder
+from config import InferTaskConfig, ModelHyperParams, \
+        encoder_input_data_names, decoder_input_data_names
+from train import pad_batch_data
+
+
+def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
+                    decoder, dec_in_names, dec_out_names, beam_size,
+                    max_length, n_best, batch_size, n_head, src_pad_idx,
+                    trg_pad_idx, bos_idx, eos_idx):
+    """
+    Run the encoder program once and run the decoder program multiple times to
+    implement beam search externally.
+    """
+    # Prepare data for encoder and run the encoder.
+    enc_in_data = pad_batch_data(
+        src_words,
+        src_pad_idx,
+        n_head,
+        is_target=False,
+        return_pos=True,
+        return_attn_bias=True,
+        return_max_len=True)
+    enc_output = exe.run(encoder,
+                         feed=dict(zip(enc_in_names, enc_in_data)),
+                         fetch_list=enc_out_names)[0]
+
+    # Beam Search.
+    # To store the beam info.
+    scores = np.zeros((batch_size, beam_size), dtype="float32")
+    prev_branchs = [[]] * batch_size
+    next_ids = [[]] * batch_size
+    # Use beam_map to map the instance idx in batch to beam idx, since the
+    # size of the fed batch is changing.
+    beam_map = range(batch_size)
+
+    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
+        """
+        Decode and select n_best sequences for one instance by backtrace.
+        """
+        seqs = []
+        for i in range(n_best):
+            k = i
+            seq = []
+            for j in range(len(prev_branchs) - 1, -1, -1):
+                seq.append(next_ids[j][k])
+                k = prev_branchs[j][k]
+            seq = seq[::-1]
+            seq = [bos_idx] + seq if add_bos else seq
+            seqs.append(seq)
+        return seqs
+
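`beam_backtrace` recovers a sequence by walking the parent pointers backwards: `prev_branchs[t][k]` records which beam at step t-1 expanded into beam k at step t, and `next_ids[t][k]` the token that beam emitted. A toy standalone run of the same loop, with invented values for illustration:

```python
prev_branchs = [[0, 0], [1, 0]]  # two decoding steps, beam width 2
next_ids = [[7, 3], [5, 9]]
k, seq = 0, []
for t in range(len(prev_branchs) - 1, -1, -1):
    seq.append(next_ids[t][k])
    k = prev_branchs[t][k]
print(seq[::-1])  # [3, 5]: token 3 at step 0 (from beam 1), then token 5
```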
+    def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
+        """
+        Initialize the input data for decoder.
+        """
+        trg_words = np.array(
+            [[bos_idx]] * batch_size * beam_size, dtype="int64")
+        trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
+        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
+            -1], enc_in_data[-2], 1
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
+                                     trg_max_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_max_len, trg_max_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
+        # This is used to remove attention on the paddings of source sequences.
+        trg_src_attn_bias = np.tile(
+            src_slf_attn_bias[:, :, ::src_max_length, :],
+            [beam_size, 1, trg_max_len, 1])
+        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+
+    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+        """
+        Update the input data of decoder mainly by slicing from the previous
+        input data and dropping the finished instance beams.
+        """
+        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
+        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+        trg_words = np.array(
+            [
+                beam_backtrace(
+                    prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
+                for beam_idx in active_beams
+            ],
+            dtype="int64")
+        trg_words = trg_words.reshape([-1, 1])
+        trg_pos = np.array(
+            [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
+            dtype="int64").reshape([-1, 1])
+        active_beams_indice = (
+            (np.array(active_beams) * beam_size)[:, np.newaxis] +
+            np.array(range(beam_size))[np.newaxis, :]).flatten()
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((len(active_beams) * beam_size,
+                                     trg_cur_len, trg_cur_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_cur_len, trg_cur_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
+        # This is used to remove attention on the paddings of source sequences.
+        trg_src_attn_bias = np.tile(trg_src_attn_bias[
+            active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
+                                    [1, 1, trg_cur_len, 1])
+        enc_output = enc_output[active_beams_indice, :, :]
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+
+    dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
+                                   enc_output)
+    for i in range(max_length):
+        predict_all = exe.run(decoder,
+                              feed=dict(zip(dec_in_names, dec_in_data)),
+                              fetch_list=dec_out_names)[0]
+        predict_all = np.log(
+            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
+                                                                        -1, :])
+        predict_all = (predict_all + scores[beam_map].reshape(
+            [len(beam_map) * beam_size, -1])).reshape(
+                [len(beam_map), beam_size, -1])
+        active_beams = []
+        for inst_idx, beam_idx in enumerate(beam_map):
+            predict = (predict_all[inst_idx, :, :]
+                       if i != 0 else predict_all[inst_idx, 0, :]).flatten()
+            top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
+            top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
+                                                                            -1]]
+            top_scores = predict[top_scores_ids]
+            scores[beam_idx] = top_scores
+            prev_branchs[beam_idx].append(top_scores_ids /
+                                          predict_all.shape[-1])
+            next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
+            if next_ids[beam_idx][-1][0] != eos_idx:
+                active_beams.append(beam_idx)
+        beam_map = active_beams
+        if len(beam_map) == 0:
+            break
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+
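Both bias builders above rely on the same upper-triangular construction: adding -1e9 to the attention logits at positions j > i stops position i from attending to later words. In isolation (a sketch, not patch code):

```python
import numpy as np

trg_len = 4
mask = np.triu(np.ones((trg_len, trg_len)), k=1) * -1e9
# mask[i, j] == -1e9 for j > i and 0 elsewhere; after the softmax those
# positions receive (near-)zero attention weight.
print(mask)
```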
+    # Decode beams and select n_best sequences for each instance by backtrace.
+    seqs = [
+        beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
+        for beam_idx in range(batch_size)
+    ]
+
+    return seqs, scores[:, :n_best].tolist()
+
+
+def main():
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    # The current program desc is coupled with batch_size and the only
+    # supported batch size is 1 currently.
+    encoder_program = fluid.Program()
+    model.batch_size = InferTaskConfig.batch_size
+    with fluid.program_guard(main_program=encoder_program):
+        enc_output = encoder(
+            ModelHyperParams.src_vocab_size + 1,
+            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+            ModelHyperParams.n_head, ModelHyperParams.d_key,
+            ModelHyperParams.d_value, ModelHyperParams.d_model,
+            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+            ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
+
+    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
+    decoder_program = fluid.Program()
+    with fluid.program_guard(main_program=decoder_program):
+        predict = decoder(
+            ModelHyperParams.trg_vocab_size + 1,
+            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+            ModelHyperParams.n_head, ModelHyperParams.d_key,
+            ModelHyperParams.d_value, ModelHyperParams.d_model,
+            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+            ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+
+    # Load model parameters of encoder and decoder separately from the saved
+    # transformer model.
+    encoder_var_names = []
+    for op in encoder_program.block(0).ops:
+        encoder_var_names += op.input_arg_names
+    encoder_param_names = filter(
+        lambda var_name: isinstance(encoder_program.block(0).var(var_name),
+                                    fluid.framework.Parameter),
+        encoder_var_names)
+    encoder_params = map(encoder_program.block(0).var, encoder_param_names)
+    decoder_var_names = []
+    for op in decoder_program.block(0).ops:
+        decoder_var_names += op.input_arg_names
+    decoder_param_names = filter(
+        lambda var_name: isinstance(decoder_program.block(0).var(var_name),
+                                    fluid.framework.Parameter),
+        decoder_var_names)
+    decoder_params = map(decoder_program.block(0).var, decoder_param_names)
+    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
+    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
+
+    # This is used here to set dropout to the test mode.
+ encoder_program = fluid.io.get_inference_program( + target_vars=[enc_output], main_program=encoder_program) + decoder_program = fluid.io.get_inference_program( + target_vars=[predict], main_program=decoder_program) + + test_data = paddle.batch( + paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size), + batch_size=InferTaskConfig.batch_size) + + trg_idx2word = paddle.dataset.wmt16.get_dict( + "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True) + + for batch_id, data in enumerate(test_data()): + batch_seqs, batch_scores = translate_batch( + exe, [item[0] for item in data], encoder_program, + encoder_input_data_names, [enc_output.name], decoder_program, + decoder_input_data_names, [predict.name], InferTaskConfig.beam_size, + InferTaskConfig.max_length, InferTaskConfig.n_best, + len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx, + ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx, + ModelHyperParams.eos_idx) + for i in range(len(batch_seqs)): + seqs = batch_seqs[i] + scores = batch_scores[i] + for seq in seqs: + print(" ".join([trg_idx2word[idx] for idx in seq])) + + +if __name__ == "__main__": + main() diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index 379a17221c3aaa4daf7f530f9553bcef89b42de6..ba5ba4470759da5fd2c6dd3b3d61b88c3468bd27 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -4,7 +4,8 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -from config import TrainTaskConfig, input_data_names, pos_enc_param_names +from config import TrainTaskConfig, pos_enc_param_names, \ + encoder_input_data_names, decoder_input_data_names, label_data_names # FIXME(guosheng): Remove out the batch_size from the model. batch_size = TrainTaskConfig.batch_size @@ -127,7 +128,9 @@ def multi_head_attention(queries, scaled_q = layers.scale(x=q, scale=d_model**-0.5) product = layers.matmul(x=scaled_q, y=k, transpose_y=True) - weights = __softmax(layers.elementwise_add(x=product, y=attn_bias)) + weights = __softmax( + layers.elementwise_add( + x=product, y=attn_bias) if attn_bias else product) if dropout_rate: weights = layers.dropout( weights, dropout_prob=dropout_rate, is_test=False) @@ -280,8 +283,15 @@ def encoder(enc_input, encoder_layer. """ for i in range(n_layer): - enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, - d_model, d_inner_hid, dropout_rate) + enc_output = encoder_layer( + enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) enc_input = enc_output return enc_output @@ -373,6 +383,62 @@ def decoder(dec_input, return dec_output +def make_inputs(input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos, + slf_attn_bias_flag, + src_attn_bias_flag, + enc_output_flag=False): + """ + Define the input data layers for the transformer model. + """ + input_layers = [] + # The shapes here act as placeholder. + # The shapes set here is to pass the infer-shape in compile time. + word = layers.data( + name=input_data_names[len(input_layers)], + shape=[batch_size * max_length, 1], + dtype="int64", + append_batch_size=False) + input_layers += [word] + # This is used for position data or label weight. 
+ pos = layers.data( + name=input_data_names[len(input_layers)], + shape=[batch_size * max_length, 1], + dtype="int64" if is_pos else "float32", + append_batch_size=False) + input_layers += [pos] + if slf_attn_bias_flag: + # This input is used to remove attention weights on paddings for the + # encoder and to remove attention weights on subsequent words for the + # decoder. + slf_attn_bias = layers.data( + name=input_data_names[len(input_layers)], + shape=[batch_size, n_head, max_length, max_length], + dtype="float32", + append_batch_size=False) + input_layers += [slf_attn_bias] + if src_attn_bias_flag: + # This input is used to remove attention weights on paddings. + src_attn_bias = layers.data( + name=input_data_names[len(input_layers)], + shape=[batch_size, n_head, max_length, max_length], + dtype="float32", + append_batch_size=False) + input_layers += [src_attn_bias] + if enc_output_flag: + enc_output = layers.data( + name=input_data_names[len(input_layers)], + shape=[batch_size, max_length, d_model], + dtype="float32", + append_batch_size=False) + input_layers += [enc_output] + return input_layers + + def transformer( src_vocab_size, trg_vocab_size, @@ -387,61 +453,72 @@ def transformer( src_pad_idx, trg_pad_idx, pos_pad_idx, ): - # The shapes here act as placeholder. - # The shapes set here is to pass the infer-shape in compile time. The actual - # shape of src_word in run time is: - # [batch_size * max_src_length_in_a_batch, 1]. - src_word = layers.data( - name=input_data_names[0], - shape=[batch_size * max_length, 1], - dtype="int64", - append_batch_size=False) - # The actual shape of src_pos in runtime is: - # [batch_size * max_src_length_in_a_batch, 1]. - src_pos = layers.data( - name=input_data_names[1], - shape=[batch_size * max_length, 1], - dtype="int64", - append_batch_size=False) - # The actual shape of trg_word is in runtime is: - # [batch_size * max_trg_length_in_a_batch, 1]. - trg_word = layers.data( - name=input_data_names[2], - shape=[batch_size * max_length, 1], - dtype="int64", - append_batch_size=False) - # The actual shape of trg_pos in runtime is: - # [batch_size * max_trg_length_in_a_batch, 1]. - trg_pos = layers.data( - name=input_data_names[3], - shape=[batch_size * max_length, 1], - dtype="int64", - append_batch_size=False) - # The actual shape of src_slf_attn_bias in runtime is: - # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch]. - # This input is used to remove attention weights on paddings. - src_slf_attn_bias = layers.data( - name=input_data_names[4], - shape=[batch_size, n_head, max_length, max_length], - dtype="float32", - append_batch_size=False) - # The actual shape of trg_slf_attn_bias in runtime is: - # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch]. - # This is used to remove attention weights on paddings and subsequent words. - trg_slf_attn_bias = layers.data( - name=input_data_names[5], - shape=[batch_size, n_head, max_length, max_length], - dtype="float32", - append_batch_size=False) - # The actual shape of trg_src_attn_bias in runtime is: - # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch]. - # This is used to remove attention weights on paddings. 
- trg_src_attn_bias = layers.data( - name=input_data_names[6], - shape=[batch_size, n_head, max_length, max_length], - dtype="float32", - append_batch_size=False) + enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model, + batch_size, max_length, True, True, False) + + enc_output = wrap_encoder( + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + src_pad_idx, + pos_pad_idx, + enc_input_layers, ) + + dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model, + batch_size, max_length, True, True, True) + + predict = wrap_decoder( + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + trg_pad_idx, + pos_pad_idx, + dec_input_layers, + enc_output, ) + # Padding index do not contribute to the total loss. The weights is used to + # cancel padding index in calculating the loss. + gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size, + max_length, False, False, False) + cost = layers.cross_entropy(input=predict, label=gold) + weighted_cost = cost * weights + return layers.reduce_sum(weighted_cost), predict + + +def wrap_encoder(src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + src_pad_idx, + pos_pad_idx, + enc_input_layers=None): + """ + The wrapper assembles together all needed layers for the encoder. + """ + if enc_input_layers is None: + # This is used to implement independent encoder program in inference. + src_word, src_pos, src_slf_attn_bias = make_inputs( + encoder_input_data_names, n_head, d_model, batch_size, max_length, + True, True, False) + else: + src_word, src_pos, src_slf_attn_bias = enc_input_layers enc_input = prepare_encoder( src_word, src_pos, @@ -460,6 +537,32 @@ def transformer( d_model, d_inner_hid, dropout_rate, ) + return enc_output + + +def wrap_decoder(trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + trg_pad_idx, + pos_pad_idx, + dec_input_layers=None, + enc_output=None): + """ + The wrapper assembles together all needed layers for the decoder. + """ + if dec_input_layers is None: + # This is used to implement independent decoder program in inference. + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs( + decoder_input_data_names, n_head, d_model, batch_size, max_length, + True, True, True, True) + else: + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers dec_input = prepare_decoder( trg_word, @@ -482,32 +585,11 @@ def transformer( d_inner_hid, dropout_rate, ) - # TODO(guosheng): Share the weight matrix between the embedding layers and - # the pre-softmax linear transformation. predict = layers.reshape( x=layers.fc(input=dec_output, size=trg_vocab_size, - param_attr=fluid.initializer.Xavier(uniform=False), bias_attr=False, num_flatten_dims=2), shape=[-1, trg_vocab_size], act="softmax") - # The actual shape of gold in runtime is: - # [batch_size * max_trg_length_in_a_batch, 1]. - gold = layers.data( - name=input_data_names[7], - shape=[batch_size * max_length, 1], - dtype="int64", - append_batch_size=False) - cost = layers.cross_entropy(input=predict, label=gold) - # The actual shape of weights in runtime is: - # [batch_size * max_trg_length_in_a_batch, 1]. - # Padding index do not contribute to the total loss. This Weight is used to - # cancel padding index in calculating the loss. 
- weights = layers.data( - name=input_data_names[8], - shape=[batch_size * max_length, 1], - dtype="float32", - append_batch_size=False) - weighted_cost = cost * weights - return layers.reduce_sum(weighted_cost) + return predict diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 0494792c5180382139635ab70ef60d3f688e213b..65de8ef7fa8421bd72175175f1cf421a4237ddd5 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -1,3 +1,4 @@ +import os import numpy as np import paddle.v2 as paddle @@ -5,86 +6,74 @@ import paddle.fluid as fluid from model import transformer, position_encoding_init from optim import LearningRateScheduler -from config import TrainTaskConfig, ModelHyperParams, \ - pos_enc_param_names, input_data_names +from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \ + encoder_input_data_names, decoder_input_data_names, label_data_names -def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, - max_length, n_head, place): +def pad_batch_data(insts, + pad_idx, + n_head, + is_target=False, + return_pos=True, + return_attn_bias=True, + return_max_len=True): """ Pad the instances to the max sequence length in batch, and generate the - corresponding position data and attention bias. Then, convert the numpy - data to tensors and return a dict mapping names to tensors. + corresponding position data and attention bias. + """ + return_list = [] + max_len = max(len(inst) for inst in insts) + inst_data = np.array( + [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) + return_list += [inst_data.astype("int64").reshape([-1, 1])] + if return_pos: + inst_pos = np.array([[ + pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst) + ] for inst in inst_data]) + + return_list += [inst_pos.astype("int64").reshape([-1, 1])] + if return_attn_bias: + if is_target: + # This is used to avoid attention on paddings and subsequent + # words. + slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len)) + slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape( + [-1, 1, max_len, max_len]) + slf_attn_bias_data = np.tile(slf_attn_bias_data, + [1, n_head, 1, 1]) * [-1e9] + else: + # This is used to avoid attention on paddings. + slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] * + (max_len - len(inst)) + for inst in insts]) + slf_attn_bias_data = np.tile( + slf_attn_bias_data.reshape([-1, 1, 1, max_len]), + [1, n_head, max_len, 1]) + return_list += [slf_attn_bias_data.astype("float32")] + if return_max_len: + return_list += [max_len] + return return_list if len(return_list) > 1 else return_list[0] + + +def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, + max_length, n_head): + """ + Put all padded data needed by training into a dict. """ - input_dict = {} - - def __pad_batch_data(insts, - pad_idx, - is_target=False, - return_pos=True, - return_attn_bias=True, - return_max_len=True): - """ - Pad the instances to the max sequence length in batch, and generate the - corresponding position data and attention bias. 
- """ - return_list = [] - max_len = max(len(inst) for inst in insts) - inst_data = np.array( - [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) - return_list += [inst_data.astype("int64").reshape([-1, 1])] - if return_pos: - inst_pos = np.array([[ - pos_i + 1 if w_i != pad_idx else 0 - for pos_i, w_i in enumerate(inst) - ] for inst in inst_data]) - - return_list += [inst_pos.astype("int64").reshape([-1, 1])] - if return_attn_bias: - if is_target: - # This is used to avoid attention on paddings and subsequent - # words. - slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, - max_len)) - slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape( - [-1, 1, max_len, max_len]) - slf_attn_bias_data = np.tile(slf_attn_bias_data, - [1, n_head, 1, 1]) * [-1e9] - else: - # This is used to avoid attention on paddings. - slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] * - (max_len - len(inst)) - for inst in insts]) - slf_attn_bias_data = np.tile( - slf_attn_bias_data.reshape([-1, 1, 1, max_len]), - [1, n_head, max_len, 1]) - return_list += [slf_attn_bias_data.astype("float32")] - if return_max_len: - return_list += [max_len] - return return_list if len(return_list) > 1 else return_list[0] - - def data_to_tensor(data_list, name_list, input_dict, place): - assert len(data_list) == len(name_list) - for i in range(len(name_list)): - tensor = fluid.LoDTensor() - tensor.set(data_list[i], place) - input_dict[name_list[i]] = tensor - - src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data( - [inst[0] for inst in insts], src_pad_idx, is_target=False) - trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data( - [inst[1] for inst in insts], trg_pad_idx, is_target=True) + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( + [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], [1, 1, trg_max_len, 1]).astype("float32") - lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False, - False, False, False) + lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head, + False, False, False, False) lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) - - data_to_tensor([ - src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias, - trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight - ], input_data_names, input_dict, place) - + input_dict = dict( + zip(input_data_names, [ + src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, + trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight + ])) return input_dict @@ -92,7 +81,7 @@ def main(): place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) - cost = transformer( + cost, predict = transformer( ModelHyperParams.src_vocab_size + 1, ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1, ModelHyperParams.n_layer, ModelHyperParams.n_head, @@ -118,6 +107,31 @@ def main(): buf_size=100000), batch_size=TrainTaskConfig.batch_size) + # Program to do validation. 
+ test_program = fluid.default_main_program().clone() + with fluid.program_guard(test_program): + test_program = fluid.io.get_inference_program([cost]) + val_data = paddle.batch( + paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size), + batch_size=TrainTaskConfig.batch_size) + + def test(exe): + test_costs = [] + for batch_id, data in enumerate(val_data()): + if len(data) != TrainTaskConfig.batch_size: + continue + data_input = prepare_batch_input( + data, encoder_input_data_names + decoder_input_data_names[:-1] + + label_data_names, ModelHyperParams.src_pad_idx, + ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, + ModelHyperParams.n_head) + test_cost = exe.run(test_program, + feed=data_input, + fetch_list=[cost])[0] + test_costs.append(test_cost) + return np.mean(test_costs) + # Initialize the parameters. exe.run(fluid.framework.default_startup_program()) for pos_enc_param_name in pos_enc_param_names: @@ -134,9 +148,10 @@ def main(): if len(data) != TrainTaskConfig.batch_size: continue data_input = prepare_batch_input( - data, input_data_names, ModelHyperParams.src_pad_idx, + data, encoder_input_data_names + decoder_input_data_names[:-1] + + label_data_names, ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, - ModelHyperParams.n_head, place) + ModelHyperParams.n_head) lr_scheduler.update_learning_rate(data_input) outs = exe.run(fluid.framework.default_main_program(), feed=data_input, @@ -145,6 +160,14 @@ def main(): cost_val = np.array(outs[0]) print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) + " cost = " + str(cost_val)) + # Validate and save the model for inference. + val_cost = test(exe) + print("pass_id = " + str(pass_id) + " val_cost = " + str(val_cost)) + fluid.io.save_inference_model( + os.path.join(TrainTaskConfig.model_dir, + "pass_" + str(pass_id) + ".infer.model"), + encoder_input_data_names + decoder_input_data_names[:-1], + [predict], exe) if __name__ == "__main__": diff --git a/fluid/object_detection/image_util.py b/fluid/object_detection/image_util.py index ba8744eda0a078acd38cad9b10ca7511185efc43..781932293e57c715d15f9b26ceec345b6b81cd26 100644 --- a/fluid/object_detection/image_util.py +++ b/fluid/object_detection/image_util.py @@ -1,4 +1,4 @@ -from PIL import Image +from PIL import Image, ImageEnhance import numpy as np import random import math @@ -159,3 +159,77 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height): sample_img = img[ymin:ymax, xmin:xmax] sample_labels = transform_labels(bbox_labels, sample_bbox) return sample_img, sample_labels + + +def random_brightness(img, settings): + prob = random.uniform(0, 1) + if prob < settings._brightness_prob: + delta = random.uniform(-settings._brightness_delta, + settings._brightness_delta) + 1 + img = ImageEnhance.Brightness(img).enhance(delta) + return img + + +def random_contrast(img, settings): + prob = random.uniform(0, 1) + if prob < settings._contrast_prob: + delta = random.uniform(-settings._contrast_delta, + settings._contrast_delta) + 1 + img = ImageEnhance.Contrast(img).enhance(delta) + return img + + +def random_saturation(img, settings): + prob = random.uniform(0, 1) + if prob < settings._saturation_prob: + delta = random.uniform(-settings._saturation_delta, + settings._saturation_delta) + 1 + img = ImageEnhance.Color(img).enhance(delta) + return img + + +def random_hue(img, settings): + prob = random.uniform(0, 1) + if prob < settings._hue_prob: + delta = 
random.uniform(-settings._hue_delta, settings._hue_delta)
+        img_hsv = np.array(img.convert('HSV'))
+        img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta
+        img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
+    return img
+
+
+def distort_image(img, settings):
+    prob = random.uniform(0, 1)
+    # Apply different distort order
+    if prob > 0.5:
+        img = random_brightness(img, settings)
+        img = random_contrast(img, settings)
+        img = random_saturation(img, settings)
+        img = random_hue(img, settings)
+    else:
+        img = random_brightness(img, settings)
+        img = random_saturation(img, settings)
+        img = random_hue(img, settings)
+        img = random_contrast(img, settings)
+    return img
+
+
+def expand_image(img, bbox_labels, img_width, img_height, settings):
+    prob = random.uniform(0, 1)
+    if prob < settings._expand_prob:
+        expand_ratio = random.uniform(1, settings._expand_max_ratio)
+        if expand_ratio - 1 >= 0.01:
+            height = int(img_height * expand_ratio)
+            width = int(img_width * expand_ratio)
+            h_off = math.floor(random.uniform(0, height - img_height))
+            w_off = math.floor(random.uniform(0, width - img_width))
+            expand_bbox = bbox(-w_off / img_width, -h_off / img_height,
+                               (width - w_off) / img_width,
+                               (height - h_off) / img_height)
+            expand_img = np.ones((height, width, 3))
+            expand_img = np.uint8(expand_img * np.squeeze(settings._img_mean))
+            expand_img = Image.fromarray(expand_img)
+            expand_img.paste(img, (int(w_off), int(h_off)))
+            bbox_labels = transform_labels(bbox_labels, expand_bbox)
+            return expand_img, bbox_labels
+    return img, bbox_labels
diff --git a/fluid/object_detection/reader.py b/fluid/object_detection/reader.py
index 4e680c29997b432c14b92ea641aa9f956de41292..6a6beb6e50f5b0a7f6b969ca53868178db2527a6 100644
--- a/fluid/object_detection/reader.py
+++ b/fluid/object_detection/reader.py
@@ -22,17 +22,38 @@ import os


 class Settings(object):
-    def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value):
+    def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value,
+                 apply_distort, apply_expand):
         self._data_dir = data_dir
         self._label_list = []
         label_fpath = os.path.join(data_dir, label_file)
         for line in open(label_fpath):
             self._label_list.append(line.strip())

+        self._apply_distort = apply_distort
+        self._apply_expand = apply_expand
         self._resize_height = resize_h
         self._resize_width = resize_w
         self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
             'float32')
+        self._expand_prob = 0.5
+        self._expand_max_ratio = 4
+        self._hue_prob = 0.5
+        self._hue_delta = 18
+        self._contrast_prob = 0.5
+        self._contrast_delta = 0.5
+        self._saturation_prob = 0.5
+        self._saturation_delta = 0.5
+        self._brightness_prob = 0.5
+        self._brightness_delta = 0.125
+
+    @property
+    def apply_expand(self):
+        return self._apply_expand
+
+    @property
+    def apply_distort(self):
+        return self._apply_distort

     @property
     def data_dir(self):
@@ -71,7 +92,6 @@ def _reader_creator(settings, file_list, mode, shuffle):

             img = Image.open(img_path)
             img_width, img_height = img.size
-            img = np.array(img)

             # layout: label | xmin | ymin | xmax | ymax | difficult
             if mode == 'train' or mode == 'test':
@@ -99,6 +119,12 @@ def _reader_creator(settings, file_list, mode, shuffle):
                 sample_labels = bbox_labels

                 if mode == 'train':
+                    if settings._apply_distort:
+                        img = image_util.distort_image(img, settings)
+                    if settings._apply_expand:
+                        img, bbox_labels = image_util.expand_image(
+                            img, bbox_labels, img_width, img_height,
+                            settings)
                     batch_sampler = []
                     # hard-code here
                     batch_sampler.append(
@@ -126,13 +152,14 @@ def 
_reader_creator(settings, file_list, mode, shuffle): sampled_bbox = image_util.generate_batch_samples( batch_sampler, bbox_labels, img_width, img_height) + img = np.array(img) if len(sampled_bbox) > 0: idx = int(random.uniform(0, len(sampled_bbox))) img, sample_labels = image_util.crop_image( img, bbox_labels, sampled_bbox[idx], img_width, img_height) - img = Image.fromarray(img) + img = Image.fromarray(img) img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS) img = np.array(img) diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py index 498bd733b2bc0b28e21ebfa5d386eb650f74999e..e1a4d75569153266d483b8013f89572fc2ccc274 100644 --- a/fluid/object_detection/train.py +++ b/fluid/object_detection/train.py @@ -75,13 +75,10 @@ def train(args, evaluate_difficult=False, ap_version='11point') - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.exponential_decay( - learning_rate=learning_rate, - decay_steps=40000, - decay_rate=0.1, - staircase=True), - momentum=0.9, + boundaries = [40000, 60000] + values = [0.001, 0.0005, 0.00025] + optimizer = fluid.optimizer.RMSProp( + learning_rate=fluid.layers.piecewise_decay(boundaries, values), regularization=fluid.regularizer.L2Decay(0.00005), ) optimizer.minimize(loss) @@ -90,7 +87,8 @@ def train(args, exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - load_model.load_paddlev1_vars(place) + load_model.load_and_set_vars(place) + #load_model.load_paddlev1_vars(place) train_reader = paddle.batch( reader.train(data_args, train_file_list), batch_size=batch_size) test_reader = paddle.batch( @@ -113,8 +111,9 @@ def train(args, loss_v = exe.run(fluid.default_main_program(), feed=feeder.feed(data), fetch_list=[loss]) - print("Pass {0}, batch {1}, loss {2}" - .format(pass_id, batch_id, loss_v[0])) + if batch_id % 20 == 0: + print("Pass {0}, batch {1}, loss {2}" + .format(pass_id, batch_id, loss_v[0])) test(pass_id) if pass_id % 10 == 0: @@ -130,6 +129,8 @@ if __name__ == '__main__': data_args = reader.Settings( data_dir='./data', label_file='label_list', + apply_distort=True, + apply_expand=True, resize_h=300, resize_w=300, mean_value=[127.5, 127.5, 127.5]) diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py index 719c0158ec0e28c46a2915e42bd81533f848673c..73616ecb36ca2661eb8e4898caf34fc2d91b9bdc 100644 --- a/fluid/ocr_recognition/crnn_ctc_model.py +++ b/fluid/ocr_recognition/crnn_ctc_model.py @@ -26,7 +26,12 @@ def conv_bn_pool(input, bias_attr=bias, is_test=is_test) tmp = fluid.layers.pool2d( - input=tmp, pool_size=2, pool_type='max', pool_stride=2, use_cudnn=True) + input=tmp, + pool_size=2, + pool_type='max', + pool_stride=2, + use_cudnn=True, + ceil_mode=True) return tmp @@ -136,26 +141,61 @@ def encoder_net(images, def ctc_train_net(images, label, args, num_classes): regularizer = fluid.regularizer.L2Decay(args.l2) gradient_clip = None - fc_out = encoder_net( - images, - num_classes, - regularizer=regularizer, - gradient_clip=gradient_clip) + if args.parallel: + places = fluid.layers.get_places() + pd = fluid.layers.ParallelDo(places) + with pd.do(): + images_ = pd.read_input(images) + label_ = pd.read_input(label) + + fc_out = encoder_net( + images_, + num_classes, + regularizer=regularizer, + gradient_clip=gradient_clip) + + cost = fluid.layers.warpctc( + input=fc_out, + label=label_, + blank=num_classes, + norm_by_times=True) + sum_cost = fluid.layers.reduce_sum(cost) + + decoded_out = fluid.layers.ctc_greedy_decoder( + input=fc_out, 
diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py
index 719c0158ec0e28c46a2915e42bd81533f848673c..73616ecb36ca2661eb8e4898caf34fc2d91b9bdc 100644
--- a/fluid/ocr_recognition/crnn_ctc_model.py
+++ b/fluid/ocr_recognition/crnn_ctc_model.py
@@ -26,7 +26,12 @@ def conv_bn_pool(input,
         bias_attr=bias,
         is_test=is_test)
     tmp = fluid.layers.pool2d(
-        input=tmp, pool_size=2, pool_type='max', pool_stride=2, use_cudnn=True)
+        input=tmp,
+        pool_size=2,
+        pool_type='max',
+        pool_stride=2,
+        use_cudnn=True,
+        ceil_mode=True)
 
     return tmp
 
@@ -136,26 +141,61 @@ def encoder_net(images,
 def ctc_train_net(images, label, args, num_classes):
     regularizer = fluid.regularizer.L2Decay(args.l2)
     gradient_clip = None
-    fc_out = encoder_net(
-        images,
-        num_classes,
-        regularizer=regularizer,
-        gradient_clip=gradient_clip)
+    if args.parallel:
+        places = fluid.layers.get_places()
+        pd = fluid.layers.ParallelDo(places)
+        with pd.do():
+            images_ = pd.read_input(images)
+            label_ = pd.read_input(label)
+
+            fc_out = encoder_net(
+                images_,
+                num_classes,
+                regularizer=regularizer,
+                gradient_clip=gradient_clip)
+
+            cost = fluid.layers.warpctc(
+                input=fc_out,
+                label=label_,
+                blank=num_classes,
+                norm_by_times=True)
+            sum_cost = fluid.layers.reduce_sum(cost)
+
+            decoded_out = fluid.layers.ctc_greedy_decoder(
+                input=fc_out, blank=num_classes)
+
+            pd.write_output(sum_cost)
+            pd.write_output(decoded_out)
+
+        sum_cost, decoded_out = pd()
+        sum_cost = fluid.layers.reduce_sum(sum_cost)
+
+    else:
+        fc_out = encoder_net(
+            images,
+            num_classes,
+            regularizer=regularizer,
+            gradient_clip=gradient_clip)
+
+        cost = fluid.layers.warpctc(
+            input=fc_out, label=label, blank=num_classes, norm_by_times=True)
+        sum_cost = fluid.layers.reduce_sum(cost)
+        decoded_out = fluid.layers.ctc_greedy_decoder(
+            input=fc_out, blank=num_classes)
 
-    cost = fluid.layers.warpctc(
-        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
-    sum_cost = fluid.layers.reduce_sum(cost)
+    casted_label = fluid.layers.cast(x=label, dtype='int64')
+    error_evaluator = fluid.evaluator.EditDistance(
+        input=decoded_out, label=casted_label)
+
+    inference_program = fluid.default_main_program().clone()
+    with fluid.program_guard(inference_program):
+        inference_program = fluid.io.get_inference_program(error_evaluator)
 
     optimizer = fluid.optimizer.Momentum(
         learning_rate=args.learning_rate, momentum=args.momentum)
-    optimizer.minimize(sum_cost)
+    _, params_grads = optimizer.minimize(sum_cost)
 
-    decoded_out = fluid.layers.ctc_greedy_decoder(
-        input=fc_out, blank=num_classes)
-    casted_label = fluid.layers.cast(x=label, dtype='int64')
-    error_evaluator = fluid.evaluator.EditDistance(
-        input=decoded_out, label=casted_label)
-    return sum_cost, error_evaluator
+    return sum_cost, error_evaluator, inference_program
 
 
 def ctc_infer(images, num_classes):
diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py
index 85b1d2e708f73d7ac049af276626a38e76d19399..c2d8fd26bbdeb3ad5c9fb2c1ade3b2b22a0dfd44 100644
--- a/fluid/ocr_recognition/ctc_train.py
+++ b/fluid/ocr_recognition/ctc_train.py
@@ -1,5 +1,4 @@
 """Trainer for OCR CTC model."""
-import paddle.v2 as paddle
 import paddle.fluid as fluid
 import dummy_reader
 import ctc_reader
@@ -24,12 +23,12 @@
 add_arg('momentum',        float, 0.9,   "Momentum.")
 add_arg('rnn_hidden_size', int,   200,   "Hidden size of rnn layers.")
 add_arg('device',          int,   0,     "Device id. '-1' means running on CPU, "
                                          "while '0' means GPU-0.")
+add_arg('parallel',        bool,  True,  "Whether to use parallel training.")
 # yapf: disable
 
 def load_parameter(place):
     params = load_param('./name.map', './data/model/results_without_avg_window/pass-00000/')
     for name in params:
-        # print "param: %s" % name
         t = fluid.global_scope().find_var(name).get_tensor()
         t.set(params[name], place)
 
@@ -41,7 +40,8 @@ def train(args, data_reader=dummy_reader):
     # define network
     images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
-    sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
+    sum_cost, error_evaluator, inference_program = ctc_train_net(images, label, args, num_classes)
+
     # data reader
     train_reader = data_reader.train(args.batch_size)
     test_reader = data_reader.test()
@@ -51,11 +51,8 @@ def train(args, data_reader=dummy_reader):
     place = fluid.CUDAPlace(args.device)
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
-
     #load_parameter(place)
-    inference_program = fluid.io.get_inference_program(error_evaluator)
-
     for pass_id in range(args.pass_num):
         error_evaluator.reset(exe)
         batch_id = 1
@@ -78,7 +75,6 @@ def train(args, data_reader=dummy_reader):
             sys.stdout.flush()
             batch_id += 1
 
-        # evaluate model on test data
         error_evaluator.reset(exe)
         for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
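One change above that is easy to miss is `ceil_mode=True` in `conv_bn_pool`: the pooled output size is rounded up instead of down, so an odd-sized feature map keeps its last, partially covered window. A small sketch of the standard un-padded output-size formula (an illustration of the rounding, not PaddlePaddle's internal code):

```python
import math

def pooled_size(in_size, pool_size=2, stride=2, ceil_mode=False):
    # Un-padded pooling output size; ceil keeps the last partial window.
    rounding = math.ceil if ceil_mode else math.floor
    return int(rounding((in_size - pool_size) / float(stride))) + 1

print(pooled_size(25))                  # 12 -> the last row is dropped
print(pooled_size(25, ceil_mode=True))  # 13 -> it is kept in a partial window
```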
diff --git a/fluid/sequence_tagging_for_ner/README.md b/fluid/sequence_tagging_for_ner/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1f634da4e2e385b06589cde0c6979812ff52e450
--- /dev/null
+++ b/fluid/sequence_tagging_for_ner/README.md
@@ -0,0 +1,120 @@
+# Named Entity Recognition
+
+A brief overview of this example's layout:
+
+```text
+.
+├── data              # data this example depends on, obtained externally
+├── network_conf.py   # model definition
+├── reader.py         # data reading interface, obtained externally
+├── README.md         # this document
+├── train.py          # training script
+├── infer.py          # inference script
+├── utils.py          # common utility functions, obtained externally
+└── utils_extend.py   # extensions to utils.py
+```
+
+
+## Introduction and model
+
+The PaddlePaddle v2 [Named Entity Recognition](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) example describes the NER task in detail, so the description is not repeated here.
+The model keeps the structure of the v2 version; the only difference is that an LSTM replaces the original RNN.
+
+## Obtaining the data
+
+Follow the data-acquisition steps in the PaddlePaddle v2 [Named Entity Recognition](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) example: copy that example's data folder into this directory, then run its download.sh script to obtain the training and test data.
+
+## Obtaining the common scripts
+
+Copy the data-reading script [reader.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/reader.py) and the utility script [utils.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/utils.py) (dictionary loading and other shared helpers) from the PaddlePaddle v2 example into this directory; both are used by this example.
+
+## Training
+
+1. Run `sh data/download.sh`.
+2. Edit the `main` function of `train.py` to set the data paths:
+
+   ```python
+   main(
+       train_data_file="data/train",
+       test_data_file="data/test",
+       vocab_file="data/vocab.txt",
+       target_file="data/target.txt",
+       emb_file="data/wordVectors.txt",
+       model_save_dir="models",
+       num_passes=1000,
+       use_gpu=False,
+       parallel=False)
+   ```
+
+3. Run `python train.py`. **Note: run as-is, this trains on the sample data; replace it with real labeled data.**
+
+   ```text
+   Pass 127, Batch 9525, Cost 4.0867705, Precision 0.3954984, Recall 0.37846154, F1_score0.38679245
+   Pass 127, Batch 9530, Cost 3.137265, Precision 0.42971888, Recall 0.38351256, F1_score0.405303
+   Pass 127, Batch 9535, Cost 3.6240938, Precision 0.4272152, Recall 0.41795665, F1_score0.4225352
+   Pass 127, Batch 9540, Cost 3.5352352, Precision 0.48464164, Recall 0.4536741, F1_score0.46864685
+   Pass 127, Batch 9545, Cost 4.1130385, Precision 0.40131578, Recall 0.3836478, F1_score0.39228293
+   Pass 127, Batch 9550, Cost 3.6826708, Precision 0.43333334, Recall 0.43730888, F1_score0.43531203
+   Pass 127, Batch 9555, Cost 3.6363933, Precision 0.42424244, Recall 0.3962264, F1_score0.4097561
+   Pass 127, Batch 9560, Cost 3.6101768, Precision 0.51363635, Recall 0.353125, F1_score0.41851854
+   Pass 127, Batch 9565, Cost 3.5935276, Precision 0.5152439, Recall 0.5, F1_score0.5075075
+   Pass 127, Batch 9570, Cost 3.4987144, Precision 0.5, Recall 0.4330218, F1_score0.46410686
+   Pass 127, Batch 9575, Cost 3.4659843, Precision 0.39864865, Recall 0.38064516, F1_score0.38943896
+   Pass 127, Batch 9580, Cost 3.1702557, Precision 0.5, Recall 0.4490446, F1_score0.47315437
+   Pass 127, Batch 9585, Cost 3.1587276, Precision 0.49377593, Recall 0.4089347, F1_score0.4473684
+   Pass 127, Batch 9590, Cost 3.5043538, Precision 0.4556962, Recall 0.4600639, F1_score0.45786962
+   Pass 127, Batch 9595, Cost 2.981989, Precision 0.44981414, Recall 0.45149255, F1_score0.4506518
+   [TrainSet] pass_id:127 pass_precision:[0.46023396] pass_recall:[0.43197003] pass_f1_score:[0.44565433]
+   [TestSet] pass_id:127 pass_precision:[0.4708409] pass_recall:[0.47971722] pass_f1_score:[0.4752376]
+   ```
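As a quick sanity check on logs like the above, the F1 column is just the harmonic mean of the precision and recall columns; for example, for the Batch 9565 line:

```python
p, r = 0.5152439, 0.5
print(2 * p * r / (p + r))  # ~0.5075075, matching the logged F1_score
```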
+## Inference
+
+1. Edit the `infer` function of [infer.py](./infer.py) to specify the path of the model to test, the test data, the vocabulary file, and the target (label) file. The defaults are:
+
+   ```python
+   infer(
+       model_path="models/params_pass_0",
+       batch_size=6,
+       test_data_file="data/test",
+       vocab_file="data/vocab.txt",
+       target_file="data/target.txt",
+       use_gpu=False
+   )
+   ```
+
+2. Run `python infer.py` in a terminal to start inference; you will see predictions like the following (partial output of a model trained for 70 passes):
+
+   ```text
+   leicestershire B-ORG B-LOC
+   extended O O
+   their O O
+   first O O
+   innings O O
+   by O O
+   DGDG O O
+   runs O O
+   before O O
+   being O O
+   bowled O O
+   out O O
+   for O O
+   296 O O
+   with O O
+   england B-LOC B-LOC
+   discard O O
+   andy B-PER B-PER
+   caddick I-PER I-PER
+   taking O O
+   three O O
+   for O O
+   DGDG O O
+   . O O
+   ```
+
+   The output has three columns separated by "\t": the input word, the gold label, and the predicted label. Input sequences are separated by blank lines.
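Because the output is three tab-separated columns with blank lines between sequences, a per-tag accuracy over such a dump takes only a few lines (a sketch; `pred.txt` is a hypothetical file capturing infer.py's stdout):

```python
correct = total = 0
with open('pred.txt') as f:  # hypothetical dump of infer.py's output
    for line in f:
        parts = line.rstrip('\n').split('\t')
        if len(parts) != 3:  # blank separator lines between sequences
            continue
        word, gold, pred = parts
        total += 1
        correct += (gold == pred)
print('tag accuracy: %.4f' % (float(correct) / total))
```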
+## Sample results
+
+[figure: imgs/convergence_curve.png — Figure 1. Learning curve: the x-axis is the number of training passes, the y-axis is the F1 score.]
diff --git a/fluid/sequence_tagging_for_ner/imgs/convergence_curve.png b/fluid/sequence_tagging_for_ner/imgs/convergence_curve.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b862b751dd7ec0ef761dce78b9515769366d5f4
Binary files /dev/null and b/fluid/sequence_tagging_for_ner/imgs/convergence_curve.png differ
diff --git a/fluid/sequence_tagging_for_ner/infer.py b/fluid/sequence_tagging_for_ner/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d0bd9496ed2ec1db019a0124905093e0b12531a
--- /dev/null
+++ b/fluid/sequence_tagging_for_ner/infer.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+from network_conf import ner_net
+import reader
+from utils import load_dict, load_reverse_dict
+from utils_extend import to_lodtensor
+
+
+def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
+          use_gpu):
+    """
+    Use the model saved under model_path to predict on the test data;
+    the results are printed to the screen. Returns nothing.
+    """
+    word_dict = load_dict(vocab_file)
+    word_reverse_dict = load_reverse_dict(vocab_file)
+
+    label_dict = load_dict(target_file)
+    label_reverse_dict = load_reverse_dict(target_file)
+
+    test_data = paddle.batch(
+        reader.data_reader(test_data_file, word_dict, label_dict),
+        batch_size=batch_size)
+    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(model_path, exe)
+        for data in test_data():
+            word = to_lodtensor(map(lambda x: x[0], data), place)
+            mark = to_lodtensor(map(lambda x: x[1], data), place)
+            target = to_lodtensor(map(lambda x: x[2], data), place)
+            crf_decode = exe.run(
+                inference_program,
+                feed={"word": word,
+                      "mark": mark,
+                      "target": target},
+                fetch_list=fetch_targets,
+                return_numpy=False)
+            # LoD offsets delimit each sentence inside the flat decode result.
+            lod_info = (crf_decode[0].lod())[0]
+            np_data = np.array(crf_decode[0])
+            assert len(data) == len(lod_info) - 1
+            for sen_index in xrange(len(data)):
+                assert len(data[sen_index][0]) == lod_info[
+                    sen_index + 1] - lod_info[sen_index]
+                word_index = 0
+                for tag_index in xrange(lod_info[sen_index],
+                                        lod_info[sen_index + 1]):
+                    word = word_reverse_dict[data[sen_index][0][word_index]]
+                    gold_tag = label_reverse_dict[data[sen_index][2][
+                        word_index]]
+                    tag = label_reverse_dict[np_data[tag_index][0]]
+                    print word + "\t" + gold_tag + "\t" + tag
+                    word_index += 1
+            print ""
+
+
+if __name__ == "__main__":
+    infer(
+        model_path="models/params_pass_0",
+        batch_size=6,
+        test_data_file="data/test",
+        vocab_file="data/vocab.txt",
+        target_file="data/target.txt",
+        use_gpu=False)
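The LoD handling in infer.py's inner loop is plain offset slicing: sentence `i` occupies rows `[lod[i], lod[i+1])` of the flat decode result. A self-contained sketch with toy lists standing in for `crf_decode[0]`:

```python
lod = [0, 3, 7]  # offsets: sentence i occupies rows [lod[i], lod[i + 1])
flat_tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'O', 'O', 'O']

sentences = [flat_tags[lod[i]:lod[i + 1]] for i in range(len(lod) - 1)]
print(sentences)  # [['B-PER', 'I-PER', 'O'], ['B-LOC', 'O', 'O', 'O']]
```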
diff --git a/fluid/sequence_tagging_for_ner/network_conf.py b/fluid/sequence_tagging_for_ner/network_conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eaa704f67641bd9bb98bbac162a0adb7a72c246
--- /dev/null
+++ b/fluid/sequence_tagging_for_ner/network_conf.py
@@ -0,0 +1,127 @@
+import math
+
+import paddle.fluid as fluid
+from paddle.fluid.initializer import NormalInitializer
+
+from utils import logger, load_dict, get_embedding
+
+
+def ner_net(word_dict_len, label_dict_len, parallel, stack_num=2):
+    mark_dict_len = 2
+    word_dim = 50
+    mark_dim = 5
+    hidden_dim = 300
+    IS_SPARSE = True
+    embedding_name = 'emb'
+
+    def _net_conf(word, mark, target):
+        word_embedding = fluid.layers.embedding(
+            input=word,
+            size=[word_dict_len, word_dim],
+            dtype='float32',
+            is_sparse=IS_SPARSE,
+            param_attr=fluid.ParamAttr(
+                name=embedding_name, trainable=False))
+
+        mark_embedding = fluid.layers.embedding(
+            input=mark,
+            size=[mark_dict_len, mark_dim],
+            dtype='float32',
+            is_sparse=IS_SPARSE)
+
+        word_caps_vector = fluid.layers.concat(
+            input=[word_embedding, mark_embedding], axis=1)
+        mix_hidden_lr = 1
+
+        rnn_para_attr = fluid.ParamAttr(
+            initializer=NormalInitializer(
+                loc=0.0, scale=0.0),
+            learning_rate=mix_hidden_lr)
+        hidden_para_attr = fluid.ParamAttr(
+            initializer=NormalInitializer(
+                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
+            learning_rate=mix_hidden_lr)
+
+        hidden = fluid.layers.fc(
+            input=word_caps_vector,
+            name="__hidden00__",
+            size=hidden_dim,
+            act="tanh",
+            bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
+            param_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))
+
+        fea = []
+        for direction in ["fwd", "bwd"]:
+            for i in range(stack_num):
+                if i != 0:
+                    hidden = fluid.layers.fc(
+                        name="__hidden%02d_%s__" % (i, direction),
+                        size=hidden_dim,
+                        act="stanh",
+                        bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                            loc=0.0, scale=1.0)),
+                        input=[hidden, rnn[0], rnn[1]],
+                        param_attr=[
+                            hidden_para_attr, rnn_para_attr, rnn_para_attr
+                        ])
+                rnn = fluid.layers.dynamic_lstm(
+                    name="__rnn%02d_%s__" % (i, direction),
+                    input=hidden,
+                    size=hidden_dim,
+                    candidate_activation='relu',
+                    gate_activation='sigmoid',
+                    cell_activation='sigmoid',
+                    bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                        loc=0.0, scale=1.0)),
+                    is_reverse=(i % 2) if direction == "fwd" else not i % 2,
+                    param_attr=rnn_para_attr)
+                fea += [hidden, rnn[0], rnn[1]]
+
+        rnn_fea = fluid.layers.fc(
+            size=hidden_dim,
+            bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
+            act="stanh",
+            input=fea,
+            param_attr=[hidden_para_attr, rnn_para_attr, rnn_para_attr] * 2)
+
+        emission = fluid.layers.fc(
+            size=label_dict_len,
+            input=rnn_fea,
+            param_attr=fluid.ParamAttr(initializer=NormalInitializer(
+                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))
+
+        crf_cost = fluid.layers.linear_chain_crf(
+            input=emission,
+            label=target,
+            param_attr=fluid.ParamAttr(
+                name='crfw',
+                initializer=NormalInitializer(
+                    loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
+                learning_rate=mix_hidden_lr))
+        avg_cost = fluid.layers.mean(x=crf_cost)
+        return avg_cost, emission
+
+    word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
+    mark = fluid.layers.data(name='mark', shape=[1], dtype='int64', lod_level=1)
+    target = fluid.layers.data(
+        name="target", shape=[1], dtype='int64', lod_level=1)
+
+    if parallel:
+        places = fluid.layers.get_places()
+        pd = fluid.layers.ParallelDo(places)
+        with pd.do():
+            word_ = pd.read_input(word)
+            mark_ = pd.read_input(mark)
+            target_ = pd.read_input(target)
+            avg_cost, emission_base = _net_conf(word_, mark_, target_)
+            pd.write_output(avg_cost)
+            pd.write_output(emission_base)
+        avg_cost_list, emission = pd()
+        avg_cost = fluid.layers.mean(x=avg_cost_list)
+        emission.stop_gradient = True
+    else:
+        avg_cost, emission = _net_conf(word, mark, target)
+
+    return avg_cost, emission, word, mark, target
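The `is_reverse` expression above is the subtle part of the stacked network: layers within each named stack alternate orientation, and the two stacks start on opposite parities, so every depth sees both directions. A sketch that just evaluates the parity logic:

```python
stack_num = 2
for direction in ["fwd", "bwd"]:
    for i in range(stack_num):
        # Same expression as in ner_net's dynamic_lstm call.
        is_reverse = bool(i % 2) if direction == "fwd" else not i % 2
        print("%s %d %r" % (direction, i, is_reverse))
# fwd 0 False / fwd 1 True / bwd 0 True / bwd 1 False
```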
diff --git a/fluid/sequence_tagging_for_ner/train.py b/fluid/sequence_tagging_for_ner/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ed77cd5ca1d504a8b79b4f87349242b5051c539
--- /dev/null
+++ b/fluid/sequence_tagging_for_ner/train.py
@@ -0,0 +1,122 @@
+import os
+import math
+import numpy as np
+
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+
+import reader
+from network_conf import ner_net
+from utils import logger, load_dict
+from utils_extend import to_lodtensor, get_embedding
+
+
+def test(exe, chunk_evaluator, inference_program, test_data, place):
+    chunk_evaluator.reset(exe)
+    for data in test_data():
+        word = to_lodtensor(map(lambda x: x[0], data), place)
+        mark = to_lodtensor(map(lambda x: x[1], data), place)
+        target = to_lodtensor(map(lambda x: x[2], data), place)
+        # Running the inference program updates the evaluator's state;
+        # nothing needs to be fetched here.
+        exe.run(inference_program,
+                feed={"word": word,
+                      "mark": mark,
+                      "target": target})
+    return chunk_evaluator.eval(exe)
+
+
+def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
+         model_save_dir, num_passes, use_gpu, parallel):
+    if not os.path.exists(model_save_dir):
+        os.mkdir(model_save_dir)
+
+    BATCH_SIZE = 200
+    word_dict = load_dict(vocab_file)
+    label_dict = load_dict(target_file)
+
+    word_vector_values = get_embedding(emb_file)
+
+    word_dict_len = len(word_dict)
+    label_dict_len = len(label_dict)
+
+    avg_cost, feature_out, word, mark, target = ner_net(
+        word_dict_len, label_dict_len, parallel)
+
+    sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
+    sgd_optimizer.minimize(avg_cost)
+
+    crf_decode = fluid.layers.crf_decoding(
+        input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
+
+    chunk_evaluator = fluid.evaluator.ChunkEvaluator(
+        input=crf_decode,
+        label=target,
+        chunk_scheme="IOB",
+        num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
+
+    inference_program = fluid.default_main_program().clone()
+    with fluid.program_guard(inference_program):
+        test_target = chunk_evaluator.metrics + chunk_evaluator.states
+        inference_program = fluid.io.get_inference_program(test_target)
+
+    train_reader = paddle.batch(
+        paddle.reader.shuffle(
+            reader.data_reader(train_data_file, word_dict, label_dict),
+            buf_size=20000),
+        batch_size=BATCH_SIZE)
+    test_reader = paddle.batch(
+        paddle.reader.shuffle(
+            reader.data_reader(test_data_file, word_dict, label_dict),
+            buf_size=20000),
+        batch_size=BATCH_SIZE)
+
+    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+    feeder = fluid.DataFeeder(feed_list=[word, mark, target], place=place)
+    exe = fluid.Executor(place)
+
+    exe.run(fluid.default_startup_program())
+
+    embedding_name = 'emb'
+    embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
+    embedding_param.set(word_vector_values, place)
+
+    batch_id = 0
+    for pass_id in xrange(num_passes):
+        chunk_evaluator.reset(exe)
+        for data in train_reader():
+            cost, batch_precision, batch_recall, batch_f1_score = exe.run(
+                fluid.default_main_program(),
+                feed=feeder.feed(data),
+                fetch_list=[avg_cost] + chunk_evaluator.metrics)
+            if batch_id % 5 == 0:
+                print("Pass " + str(pass_id) + ", Batch " + str(
+                    batch_id) + ", Cost " + str(cost[0]) + ", Precision " +
+                      str(batch_precision[0]) + ", Recall " +
+                      str(batch_recall[0]) + ", F1_score" +
+                      str(batch_f1_score[0]))
+            batch_id = batch_id + 1
+
+        pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe)
+        print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
+            pass_precision) + " pass_recall:" + str(pass_recall) +
+              " pass_f1_score:" + str(pass_f1_score))
+
+        pass_precision, pass_recall, pass_f1_score = test(
+            exe, chunk_evaluator, inference_program, test_reader, place)
+        print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
+            pass_precision) + " pass_recall:" + str(pass_recall) +
+              " pass_f1_score:" + str(pass_f1_score))
+
+        save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id)
+        fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
+                                      [crf_decode], exe)
+
+
+if __name__ == "__main__":
+    main(
+        train_data_file="data/train",
+        test_data_file="data/test",
+        vocab_file="data/vocab.txt",
+        target_file="data/target.txt",
+        emb_file="data/wordVectors.txt",
+        model_save_dir="models",
+        num_passes=1000,
+        use_gpu=False,
+        parallel=False)
diff --git a/fluid/sequence_tagging_for_ner/utils_extend.py b/fluid/sequence_tagging_for_ner/utils_extend.py
new file mode 100644
index 0000000000000000000000000000000000000000..03e7e62fd5f8496d4a9436ad34ec7763b46b460d
--- /dev/null
+++ b/fluid/sequence_tagging_for_ner/utils_extend.py
@@ -0,0 +1,28 @@
+import numpy as np
+
+import paddle.fluid as fluid
+
+
+def get_embedding(emb_file='data/wordVectors.txt'):
+    """
+    Get the trained word vectors.
+    """
+    return np.loadtxt(emb_file, dtype='float32')
+
+
+def to_lodtensor(data, place):
+    """
+    Convert a minibatch of variable-length sequences to a LoDTensor.
+    """
+    seq_lens = [len(seq) for seq in data]
+    cur_len = 0
+    lod = [cur_len]
+    for l in seq_lens:
+        cur_len += l
+        lod.append(cur_len)
+    flattened_data = np.concatenate(data, axis=0).astype("int64")
+    flattened_data = flattened_data.reshape([len(flattened_data), 1])
+    res = fluid.LoDTensor()
+    res.set(flattened_data, place)
+    res.set_lod([lod])
+    return res
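`to_lodtensor` above packs a minibatch of variable-length sequences into one flat [N, 1] tensor plus cumulative offsets; the offset construction is easy to see with toy data (numpy only):

```python
import numpy as np

data = [[1, 2, 3], [4, 5], [6]]  # three sequences of token ids
lod, cur_len = [0], 0
for seq in data:
    cur_len += len(seq)
    lod.append(cur_len)
flattened = np.concatenate(data).astype("int64").reshape([-1, 1])
print(lod)                # [0, 3, 5, 6] -> sequence i is rows [lod[i], lod[i+1])
print(flattened.ravel())  # [1 2 3 4 5 6]
```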
diff --git a/globally_normalized_reader/README.cn.md b/globally_normalized_reader/README.cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1d3910754538ffb2743a7eb80ee7225eabcd534
--- /dev/null
+++ b/globally_normalized_reader/README.cn.md
@@ -0,0 +1,59 @@
+The code examples in this directory require PaddlePaddle v0.11.0 or later. If your PaddlePaddle version is earlier than v0.11.0, [please update it](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
+
+---
+
+# Globally Normalized Reader
+
+This model implements the work in:
+
+Jonathan Raiman and John Miller. Globally Normalized Reader. Empirical Methods in Natural Language Processing (EMNLP), 2017.
+
+If you use the dataset/code in your research, please cite the above paper:
+
+```text
+@inproceedings{raiman2015gnr,
+    author={Raiman, Jonathan and Miller, John},
+    booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
+    title={Globally Normalized Reader},
+    year={2017},
+}
+```
+
+You can also visit https://github.com/baidu-research/GloballyNormalizedReader for more information.
+
+
+# Installation
+
+1. Install the latest PaddlePaddle using the [docker image](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html):
+   ```bash
+   docker pull paddledev/paddle
+   ```
+2. Download all the required data:
+   ```bash
+   cd data && ./download.sh && cd ..
+   ```
+3. Preprocess and featurize the data:
+   ```bash
+   python featurize.py --datadir data --outdir data/featurized --glove-path data/glove.840B.300d.txt
+   ```
+
+# Training a model
+
+- Configure the model by editing config.py as needed, then run:
+
+  ```bash
+  python train.py 2>&1 | tee train.log
+  ```
+
+# Inference with a trained model
+
+- Run inference with a trained model as follows:
+  ```bash
+  python infer.py \
+    --model_path models/pass_00000.tar.gz \
+    --data_dir data/featurized/ \
+    --batch_size 2 \
+    --use_gpu 0 \
+    --trainer_count 1 \
+    2>&1 | tee infer.log
+  ```
diff --git a/sequence_tagging_for_ner/README.md b/sequence_tagging_for_ner/README.md
index 38e187554537bc5b83a5c658d639c9743047f085..9870e3cf2edb4e0a0514b33c59c91d861a8caf5d 100644
--- a/sequence_tagging_for_ner/README.md
+++ b/sequence_tagging_for_ner/README.md
@@ -1,4 +1,4 @@
-Running the examples in this directory requires PaddlePaddle v0.10.0. If your installed PaddlePaddle is older, please update it following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
+Running the examples in this directory requires PaddlePaddle v0.10.0. If your installed PaddlePaddle is older, please update it following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
 
 ---
 
@@ -25,16 +25,16 @@
 
 Named Entity Recognition (NER), also known as "proper-name recognition", identifies entities with specific meaning in text — mainly person names, place names, organization names, and other proper nouns — and is a fundamental problem in natural language processing. The NER task usually involves detecting entity boundaries and determining entity types, and can be solved as a sequence-labeling problem.
 
-Sequence labeling can be divided into Sequence Classification, Segment Classification, and Temporal Classification [[1](#references)]. This example only considers Segment Classification, i.e. assigning a label to every element of the input sequence. Since NER needs to mark boundaries, the label set defined by the [BIO scheme](http://book.paddlepaddle.org/07.label_semantic_roles/) is commonly used. Below is an example NER labeling result:
+Sequence labeling can be divided into Sequence Classification, Segment Classification, and Temporal Classification [[1](#references)]. This example only considers Segment Classification, i.e. assigning a label to every element of the input sequence. Since NER needs to mark boundaries, the label set defined by the [BIO scheme](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html) is commonly used. Below is an example NER labeling result:
 
 [figure: images/BIO tag example.png — Figure 1. An example of BIO tagging]
 
-Entity boundaries and entity types can be read directly off the sequence-labeling result. Similarly, word segmentation, part-of-speech tagging, chunking, and [semantic role labeling](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html) can all be solved via sequence labeling. The usual neural-network approach is that the earlier layers learn a feature representation of the input and the last layer completes the task on top of those features; for sequence labeling, an RNN-based network typically learns the features, which are then fed into a CRF that performs the labeling. This effectively replaces the linear model of a traditional CRF with a nonlinear neural network. The reason for keeping the CRF is that it uses sentence-level likelihood, which better addresses the label-bias problem [[2](#references)]. This example builds its model along these lines; although NER serves as the example task, the model applies to other sequence-labeling tasks as well.
+Entity boundaries and entity types can be read directly off the sequence-labeling result. Similarly, word segmentation, part-of-speech tagging, chunking, and [semantic role labeling](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html) can all be solved via sequence labeling. The usual neural-network approach is that the earlier layers learn a feature representation of the input and the last layer completes the task on top of those features; for sequence labeling, an RNN-based network typically learns the features, which are then fed into a CRF that performs the labeling. This effectively replaces the linear model of a traditional CRF with a nonlinear neural network. The reason for keeping the CRF is that it uses sentence-level likelihood, which better addresses the label-bias problem [[2](#references)]. This example builds its model along these lines; although NER serves as the example task, the model applies to other sequence-labeling tasks as well.
 
-Because sequence labeling is such a widespread problem, classic sequence models such as [CRF](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html) were developed; most of them can only use local information or require hand-engineered features. With the progress of deep learning research, sequence models such as Recurrent Neural Networks (RNNs) can model dependencies between sequence elements and learn feature representations directly from raw input text, making them better suited to sequence labeling; for more background see the [semantic role labeling](https://github.com/PaddlePaddle/book/blob/develop/07.label_semantic_roles/README.cn.md) chapter of PaddleBook.
+Because sequence labeling is such a widespread problem, classic sequence models such as [CRF](http://www.paddlepaddle.org/docs/develop/book/07.label_semantic_roles/index.cn.html) were developed; most of them can only use local information or require hand-engineered features. With the progress of deep learning research, sequence models such as Recurrent Neural Networks (RNNs) can model dependencies between sequence elements and learn feature representations directly from raw input text, making them better suited to sequence labeling; for more background see the [semantic role labeling](https://github.com/PaddlePaddle/book/blob/develop/07.label_semantic_roles/README.cn.md) chapter of PaddleBook.
 
 ## Model details
 
diff --git a/sequence_tagging_for_ner/images/BIO tag example.png b/sequence_tagging_for_ner/images/BIO tag example.png
new file mode 100644
index 0000000000000000000000000000000000000000..88ee9e84b7cc9ed8fd794c66c0929c1351d34d8e
Binary files /dev/null and b/sequence_tagging_for_ner/images/BIO tag example.png differ
diff --git a/sequence_tagging_for_ner/images/ner_model_en.png b/sequence_tagging_for_ner/images/ner_model_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..da541cda7e9632cfdac86df6f3f7d3e4c462b85b
Binary files /dev/null and b/sequence_tagging_for_ner/images/ner_model_en.png differ