From e387cdba77018805fcc3a0f5897e9747253798d7 Mon Sep 17 00:00:00 2001
From: alvations
Date: Fri, 21 Oct 2016 11:28:46 +0800
Subject: [PATCH] Added Bidi-LSTM and DB-LSTM to quick_start demo (#226)

---
 demo/quick_start/train.sh                    |  2 +
 demo/quick_start/trainer_config.bidi-lstm.py | 62 +++++++++++++++++
 demo/quick_start/trainer_config.db-lstm.py   | 73 ++++++++++++++++++++
 3 files changed, 137 insertions(+)
 create mode 100644 demo/quick_start/trainer_config.bidi-lstm.py
 create mode 100644 demo/quick_start/trainer_config.db-lstm.py

diff --git a/demo/quick_start/train.sh b/demo/quick_start/train.sh
index 1f0a137c8b..ea4e32249a 100755
--- a/demo/quick_start/train.sh
+++ b/demo/quick_start/train.sh
@@ -18,6 +18,8 @@ cfg=trainer_config.lr.py
 #cfg=trainer_config.emb.py
 #cfg=trainer_config.cnn.py
 #cfg=trainer_config.lstm.py
+#cfg=trainer_config.bidi-lstm.py
+#cfg=trainer_config.db-lstm.py
 paddle train \
   --config=$cfg \
   --save_dir=./output \
diff --git a/demo/quick_start/trainer_config.bidi-lstm.py b/demo/quick_start/trainer_config.bidi-lstm.py
new file mode 100644
index 0000000000..3be3d37342
--- /dev/null
+++ b/demo/quick_start/trainer_config.bidi-lstm.py
@@ -0,0 +1,62 @@
+# edit-mode: -*- python -*-
+
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+dict_file = "./data/dict.txt"
+word_dict = dict()
+with open(dict_file, 'r') as f:
+    for i, line in enumerate(f):
+        w = line.strip().split()[0]
+        word_dict[w] = i
+
+is_predict = get_config_arg('is_predict', bool, False)
+trn = 'data/train.list' if not is_predict else None
+tst = 'data/test.list' if not is_predict else 'data/pred.list'
+process = 'process' if not is_predict else 'process_predict'
+define_py_data_sources2(train_list=trn,
+                        test_list=tst,
+                        module="dataprovider_emb",
+                        obj=process,
+                        args={"dictionary": word_dict})
+
+batch_size = 128 if not is_predict else 1
+settings(
+    batch_size=batch_size,
+    learning_rate=2e-3,
+    learning_method=AdamOptimizer(),
+    regularization=L2Regularization(8e-4),
+    gradient_clipping_threshold=25
+)
+
+bias_attr = ParamAttr(initial_std=0., l2_rate=0.)
+data = data_layer(name="word", size=len(word_dict))
+emb = embedding_layer(input=data, size=128)
+
+bi_lstm = bidirectional_lstm(input=emb, size=128)
+dropout = dropout_layer(input=bi_lstm, dropout_rate=0.5)
+
+output = fc_layer(input=dropout, size=2,
+                  bias_attr=bias_attr,
+                  act=SoftmaxActivation())
+
+if is_predict:
+    maxid = maxid_layer(output)
+    outputs([maxid, output])
+else:
+    label = data_layer(name="label", size=2)
+    cls = classification_cost(input=output, label=label)
+    outputs(cls)
diff --git a/demo/quick_start/trainer_config.db-lstm.py b/demo/quick_start/trainer_config.db-lstm.py
new file mode 100644
index 0000000000..b35bdf5a61
--- /dev/null
+++ b/demo/quick_start/trainer_config.db-lstm.py
@@ -0,0 +1,73 @@
+# edit-mode: -*- python -*-
+
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+dict_file = "./data/dict.txt"
+word_dict = dict()
+with open(dict_file, 'r') as f:
+    for i, line in enumerate(f):
+        w = line.strip().split()[0]
+        word_dict[w] = i
+
+is_predict = get_config_arg('is_predict', bool, False)
+trn = 'data/train.list' if not is_predict else None
+tst = 'data/test.list' if not is_predict else 'data/pred.list'
+process = 'process' if not is_predict else 'process_predict'
+define_py_data_sources2(train_list=trn,
+                        test_list=tst,
+                        module="dataprovider_emb",
+                        obj=process,
+                        args={"dictionary": word_dict})
+
+batch_size = 128 if not is_predict else 1
+settings(
+    batch_size=batch_size,
+    learning_rate=2e-3,
+    learning_method=AdamOptimizer(),
+    regularization=L2Regularization(8e-4),
+    gradient_clipping_threshold=25
+)
+
+bias_attr = ParamAttr(initial_std=0., l2_rate=0.)
+
+data = data_layer(name="word", size=len(word_dict))
+emb = embedding_layer(input=data, size=128)
+
+hidden_0 = mixed_layer(size=128, input=[full_matrix_projection(input=emb)])
+lstm_0 = lstmemory(input=hidden_0, layer_attr=ExtraAttr(drop_rate=0.1))
+
+input_layers = [hidden_0, lstm_0]
+
+for i in range(1, 8):
+    fc = fc_layer(input=input_layers, size=128)
+    lstm = lstmemory(input=fc, layer_attr=ExtraAttr(drop_rate=0.1),
+                     reverse=(i % 2) == 1)
+    input_layers = [fc, lstm]
+
+lstm_last = pooling_layer(input=lstm, pooling_type=MaxPooling())
+
+output = fc_layer(input=lstm_last, size=2,
+                  bias_attr=bias_attr,
+                  act=SoftmaxActivation())
+
+if is_predict:
+    maxid = maxid_layer(output)
+    outputs([maxid, output])
+else:
+    label = data_layer(name="label", size=2)
+    cls = classification_cost(input=output, label=label)
+    outputs(cls)
-- 
GitLab
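
The heart of trainer_config.db-lstm.py is the for-loop that stacks eight LSTM levels whose scan direction alternates via reverse=(i % 2) == 1, each level's fc_layer reading the previous level's fc/lstm pair. Below is a minimal sketch in plain Python (no PaddlePaddle required) that prints the wiring plan that loop produces; the function and layer names are illustrative stand-ins, not part of the patch.

def db_lstm_plan(depth=8):
    """Describe the layer graph produced by the config's stacking loop."""
    # level 0: a projection of the embedding feeding a forward LSTM
    plan = [("fc(emb) -> lstm_0", "forward")]
    inputs = ["hidden_0", "lstm_0"]
    for i in range(1, depth):
        # odd levels scan in reverse, matching reverse=(i % 2) == 1
        direction = "reverse" if (i % 2) == 1 else "forward"
        plan.append(("fc(%s) -> lstm_%d" % (" + ".join(inputs), i), direction))
        # the next level consumes this level's fc/lstm pair
        inputs = ["fc_%d" % i, "lstm_%d" % i]
    return plan

if __name__ == "__main__":
    for wiring, direction in db_lstm_plan():
        print("%-35s %s" % (wiring, direction))

Running it shows why only the last lstm is max-pooled for classification: after eight alternating levels, that layer has seen the sequence through both forward and reverse passes of the whole stack.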