Commit 56b08830 authored by chenjiawen

Add finetune code

Parent a9c4cbb9
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
import os
cell_clip = 3.0
proj_clip = 3.0
hidden_size = 4096
vocab_size = 52445
emb_size = 512
def dropout(input):
    dropout_rate = 0.5
return layers.dropout(
input,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name, proj_size, args):
    # An LSTM encoder implementation with projection (LSTMP).
    # The linear transformations for the input, output, and forget gates
    # and the cell activation vector must be done outside of dynamic_lstmp,
    # so the fc output size is 4 times gate_size.
init = None
init_b = None
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=init),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=proj_clip,
cell_clip=cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=init),
bias_attr=fluid.ParamAttr(initializer=init_b))
return hidden, cell, input_proj
def emb(x, vocab_size=52445, emb_size=512):
x_emb = layers.embedding(
input=x,
size=[vocab_size, emb_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
return x_emb
def encoder_1(x_emb,
vocab_size,
emb_size,
init_hidden=None,
init_cell=None,
para_name='',
args=None):
rnn_input = x_emb
#rnn_input.stop_gradient = True
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
    num_layers = 2
for i in range(num_layers):
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1), emb_size, args)
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out.stop_gradient = True
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
return rnn_outs, rnn_outs_ori
def weight_layers(lm_embeddings, name="", l2_coef=0.0):
'''
Weight the layers of a biLM with trainable scalar weights to
compute ELMo representations.
    Input:
        lm_embeddings(list): representations of the layers from the biLM.
        name(str): a prefix used for the trainable variable names.
        l2_coef(float): the l2 regularization coefficient $\lambda$.
            Pass None or 0.0 for no regularization.
    Output:
        weighted_lm_layers: weighted embeddings from the biLM.
    '''
n_lm_layers = len(lm_embeddings)
W = layers.create_parameter([n_lm_layers, ], dtype="float32", name=name+"ELMo_w",
attr=fluid.ParamAttr(name=name+"ELMo_w",
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(l2_coef)))
    normed_weights = layers.softmax(W + 1.0 / n_lm_layers)
splited_normed_weights = layers.split(normed_weights, n_lm_layers, dim=0)
# compute the weighted, normalized LM activations
pieces = []
for w, t in zip(splited_normed_weights, lm_embeddings):
pieces.append(t * w)
sum_pieces = layers.sums(pieces)
# scale the weighted sum by gamma
gamma = layers.create_parameter([1], dtype="float32", name=name+"ELMo_gamma",
attr=fluid.ParamAttr(name=name+"ELMo_gamma",
initializer=fluid.initializer.Constant(1.0)))
weighted_lm_layers = sum_pieces * gamma
return weighted_lm_layers
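# Sanity check of the weighting (a sketch, not part of the graph): with
# n_lm_layers = 3 and W initialized to zeros, softmax(W + 1/3) is uniform,
# so training starts from gamma * mean(layers) and then learns task-specific
# layer weights from there.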
def elmo_encoder(x_emb, elmo_l2_coef):
    x_emb_r = fluid.layers.sequence_reverse(x_emb, name=None)
fw_hiddens, fw_hiddens_ori = encoder_1(
x_emb,
vocab_size,
emb_size,
para_name='fw_',
args=None)
bw_hiddens, bw_hiddens_ori = encoder_1(
x_emb_r,
vocab_size,
emb_size,
para_name='bw_',
args=None)
num_layers = len(fw_hiddens_ori)
token_embeddings = layers.concat(input=[x_emb, x_emb], axis=1)
token_embeddings.stop_gradient = True
concate_embeddings = [token_embeddings]
for index in range(num_layers):
        embedding = layers.concat(
            input=[fw_hiddens_ori[index], bw_hiddens_ori[index]], axis=1)
        embedding = dropout(embedding)
        embedding.stop_gradient = True
        concate_embeddings.append(embedding)
    weighted_emb = weight_layers(concate_embeddings, l2_coef=elmo_l2_coef)
    return weighted_emb
def init_pretraining_params(exe,
pretraining_params_path,
main_program):
    assert os.path.exists(pretraining_params_path), \
        "[%s] can't be found." % pretraining_params_path
def if_exist(var):
path = os.path.join(pretraining_params_path, var.name)
exist = os.path.exists(path)
if exist:
print('Load model: %s' % path)
return exist
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=if_exist)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
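
# A minimal usage sketch (hypothetical tensor names), assuming an int64
# word-id input with lod_level=1:
#
#   word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
#   word_emb = emb(word)                        # [T, emb_size] token embeddings
#   elmo_rep = elmo_encoder(word_emb, 0.001)    # weighted biLM representation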
 
、 ,
。 .
— -
~ ~
‖ |
… .
‘ '
’ '
“ "
” "
〔 (
〕 )
〈 <
〉 >
「 '
」 '
『 "
』 "
〖 [
〗 ]
【 [
】 ]
∶ :
$ $
! !
" "
# #
% %
& &
' '
( (
) )
* *
+ +
, ,
- -
. .
/ /
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
; ;
< <
= =
> >
? ?
@ @
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
[ [
\ \
] ]
^ ^
_ _
` `
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
{ {
| |
} }
 ̄ ~
〝 "
〞 "
﹐ ,
﹑ ,
﹒ .
﹔ ;
﹕ :
﹖ ?
﹗ !
﹙ (
﹚ )
﹛ {
﹜ {
﹝ [
﹞ ]
﹟ #
﹠ &
﹡ *
﹢ +
﹣ -
﹤ <
﹥ >
﹦ =
﹨ \
﹩ $
﹪ %
﹫ @
,
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
This diff is collapsed.
a
a-B
a-I
ad
ad-B
ad-I
an
an-B
an-I
c
d
d-B
d-I
f
f-B
f-I
m
m-B
m-I
n
n-B
n-I
nr
nr-B
nr-I
ns
ns-B
ns-I
nt
nt-B
nt-I
nw
nw-B
nw-I
nz
nz-B
nz-I
p
q
r
r-B
r-I
s
s-B
s-I
t
t-B
t-I
u
v
v-B
v-I
vd
vd-B
vd-I
vn
vn-B
vn-I
w
xc
\ No newline at end of file
This diff is collapsed.
from __future__ import print_function
import numpy as np
import reader
import paddle.fluid as fluid
import paddle
import argparse
import time
import sys
import io
if sys.version_info > (3,):
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
else:
reload(sys)
sys.setdefaultencoding("utf8")
def parse_args():
parser = argparse.ArgumentParser("Run inference.")
parser.add_argument(
'--batch_size',
type=int,
default=5,
help='The size of a batch. (default: %(default)d)'
)
parser.add_argument(
'--model_path',
type=str,
default='./conf/model',
help='A path to the model. (default: %(default)s)'
)
parser.add_argument(
'--test_data_dir',
type=str,
default='./data/test_data',
help='A directory with test data files. (default: %(default)s)'
)
parser.add_argument(
"--word_dict_path",
type=str,
default="./conf/word.dic",
help="The path of the word dictionary. (default: %(default)s)"
)
parser.add_argument(
"--label_dict_path",
type=str,
default="./conf/tag.dic",
help="The path of the label dictionary. (default: %(default)s)"
)
parser.add_argument(
"--word_rep_dict_path",
type=str,
default="./conf/q2b.dic",
help="The path of the word replacement Dictionary. (default: %(default)s)"
)
args = parser.parse_args()
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def get_real_tag(origin_tag):
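    # strip the IOB suffix, e.g. "nr-B" -> "nr", "nr-I" -> "nr"; "O" stays "O"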
if origin_tag == "O":
return "O"
    return origin_tag[:-2]
def to_lodtensor(data, place):
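    # build a LoDTensor from a list of sequences, e.g. [[1, 2, 3], [4, 5]]
    # yields lod [[0, 3, 5]] and a flattened int64 tensor of shape [5, 1]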
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def infer(args):
id2word_dict = reader.load_dict(args.word_dict_path)
word2id_dict = reader.load_reverse_dict(args.word_dict_path)
id2label_dict = reader.load_dict(args.label_dict_path)
label2id_dict = reader.load_reverse_dict(args.label_dict_path)
q2b_dict = reader.load_dict(args.word_rep_dict_path)
test_data = paddle.batch(
reader.test_reader(args.test_data_dir,
word2id_dict,
label2id_dict,
q2b_dict),
        batch_size=args.batch_size)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(args.model_path, exe)
for data in test_data():
full_out_str = ""
word_idx = to_lodtensor([x[0] for x in data], place)
word_list = [x[1] for x in data]
(crf_decode, ) = exe.run(inference_program,
feed={"word":word_idx},
fetch_list=fetch_targets,
return_numpy=False)
lod_info = (crf_decode.lod())[0]
np_data = np.array(crf_decode)
assert len(data) == len(lod_info) - 1
for sen_index in range(len(data)):
assert len(data[sen_index][0]) == lod_info[
sen_index + 1] - lod_info[sen_index]
word_index = 0
outstr = ""
cur_full_word = ""
cur_full_tag = ""
words = word_list[sen_index]
for tag_index in range(lod_info[sen_index],
lod_info[sen_index + 1]):
cur_word = words[word_index]
cur_tag = id2label_dict[str(np_data[tag_index][0])]
if cur_tag.endswith("-B") or cur_tag.endswith("O"):
if len(cur_full_word) != 0:
outstr += cur_full_word + u"/" + cur_full_tag + u" "
cur_full_word = cur_word
cur_full_tag = get_real_tag(cur_tag)
else:
cur_full_word += cur_word
word_index += 1
outstr += cur_full_word + u"/" + cur_full_tag + u" "
outstr = outstr.strip()
full_out_str += outstr + u"\n"
print(full_out_str.strip(), file=sys.stdout)
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
infer(args)
"""
The function lex_net(args) define the lexical analysis network structure
"""
import sys
import os
import math
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer
import paddle.fluid.layers as layers
from bilm import elmo_encoder
from bilm import emb
def lex_net(args, word_dict_len, label_dict_len):
"""
define the lexical analysis network structure
"""
word_emb_dim = args.word_emb_dim
grnn_hidden_dim = args.grnn_hidden_dim
emb_lr = args.emb_learning_rate
crf_lr = args.crf_learning_rate
bigru_num = args.bigru_num
init_bound = 0.1
IS_SPARSE = True
def _bigru_layer(input_feature):
"""
define the bidirectional gru layer
"""
pre_gru = fluid.layers.fc(
input=input_feature,
size=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru = fluid.layers.dynamic_gru(
input=pre_gru,
size=grnn_hidden_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
pre_gru_r = fluid.layers.fc(
input=input_feature,
size=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru_r = fluid.layers.dynamic_gru(
input=pre_gru_r,
size=grnn_hidden_dim,
is_reverse=True,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
bi_merge = fluid.layers.concat(input=[gru, gru_r], axis=1)
return bi_merge
def _net_conf(word, target):
"""
Configure the network
"""
word_embedding = fluid.layers.embedding(
input=word,
size=[word_dict_len, word_emb_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="word_emb",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
        # look up the (frozen) ELMo embedding and encode it with the biLM
        elmo_embedding = emb(word)
        elmo_enc = elmo_encoder(elmo_embedding, args.elmo_l2_coef)
        # concatenate the ELMo representation with the task word embedding
        input_feature = layers.concat(input=[elmo_enc, word_embedding], axis=1)
for i in range(bigru_num):
bigru_output = _bigru_layer(input_feature)
input_feature = bigru_output
emission = fluid.layers.fc(
size=label_dict_len,
input=bigru_output,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=target,
param_attr=fluid.ParamAttr(
name='crfw',
learning_rate=crf_lr))
crf_decode = fluid.layers.crf_decoding(
input=emission, param_attr=fluid.ParamAttr(name='crfw'))
avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, crf_decode
word = fluid.layers.data(
name='word', shape=[1], dtype='int64', lod_level=1)
    target = fluid.layers.data(
        name="target", shape=[1], dtype='int64', lod_level=1)
    avg_cost, crf_decode = _net_conf(word, target)
    return avg_cost, crf_decode, word, target
#coding: utf-8
"""
The file_reader converts raw corpus to input.
"""
import os
import io
def file_reader(file_dir,
word2id_dict,
label2id_dict,
word_replace_dict,
filename_feature=""):
"""
define the reader to read files in file_dir
"""
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
def reader():
"""
the data generator
"""
index = 0
for root, dirs, files in os.walk(file_dir):
for filename in files:
for line in io.open(os.path.join(root, filename), 'r', encoding='utf8'):
index += 1
line = line.strip("\n")
if len(line) == 0:
continue
seg_tag = line.rfind("\t")
                    # TODO: support word- and character-level models
word_part = line[0:seg_tag].strip().split(' ')
label_part = line[seg_tag + 1:]
word_idx = []
words = word_part
for word in words:
if word in word_replace_dict:
word = word_replace_dict[word]
if word in word2id_dict:
word_idx.append(int(word2id_dict[word]))
else:
word_idx.append(int(word2id_dict["<UNK>"]))
target_idx = []
labels = label_part.strip().split(" ")
for label in labels:
if label in label2id_dict:
target_idx.append(int(label2id_dict[label]))
else:
target_idx.append(int(label2id_dict["O"]))
if len(word_idx) != len(target_idx):
print(line)
continue
yield word_idx, target_idx
return reader
def test_reader(file_dir,
word2id_dict,
label2id_dict,
word_replace_dict,
filename_feature=""):
"""
define the reader to read test files in file_dir
"""
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
def reader():
"""
the data generator
"""
index = 0
for root, dirs, files in os.walk(file_dir):
for filename in files:
if not filename.startswith(filename_feature):
continue
for line in io.open(os.path.join(root, filename), 'r', encoding='utf8'):
index += 1
line = line.strip("\n")
if len(line) == 0:
continue
seg_tag = line.rfind("\t")
if seg_tag == -1:
seg_tag = len(line)
word_part = line[0:seg_tag]
label_part = line[seg_tag + 1:]
word_idx = []
words = word_part
for word in words:
if ord(word) < 0x20:
word = ' '
if word in word_replace_dict:
word = word_replace_dict[word]
if word in word2id_dict:
word_idx.append(int(word2id_dict[word]))
else:
word_idx.append(int(word2id_dict["OOV"]))
yield word_idx, words
return reader
def load_reverse_dict(dict_path):
"""
    Load a dict mapping each line's term to its 0-based line index.
"""
result_dict = {}
    # TODO: support character- and word-level models
for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
terms = line.strip("\n")
result_dict[terms] = idx
return result_dict
def load_dict(dict_path):
"""
    Load a dict mapping each 0-based line index to that line's term.
"""
result_dict = {}
for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
terms = line.strip("\n")
result_dict[idx] = terms
return result_dict
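
# A minimal usage sketch (hypothetical path): for a dictionary file with one
# term per line, load_dict returns {0: 'term0', 1: 'term1', ...} and
# load_reverse_dict returns the inverse {'term0': 0, 'term1': 1, ...}.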
export FLAGS_fraction_of_gpu_memory_to_use=0.5
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
export CUDA_VISIBLE_DEVICES=4
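# the flags above limit the initial GPU memory allocation to 50% and enable
# eager deletion (garbage collection) of dead tensors to reduce memory pressure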
python train.py \
--traindata_dir data/train \
--model_save_dir model \
--use_gpu 1 \
--corpus_type_list train \
--corpus_proportion_list 1 \
--num_iterations 200000 \
    --testdata_dir data/dev $@
"""
This file is used to train the model.
"""
import os
import sys
import math
import time
import random
import argparse
import numpy as np
import paddle
import paddle.fluid as fluid
import reader
from network import lex_net
from bilm import init_pretraining_params
def parse_args():
"""
Parsing the input parameters.
"""
parser = argparse.ArgumentParser("Training for lexical analyzer.")
parser.add_argument(
"--traindata_dir",
type=str,
default="data/train_data",
help="The folder where the training data is located.")
parser.add_argument(
"--testdata_dir",
type=str,
default="data/test_data",
help="The folder where the training data is located.")
parser.add_argument(
"--model_save_dir",
type=str,
default="./models",
help="The model will be saved in this path.")
parser.add_argument(
"--save_model_per_batchs",
type=int,
default=1000,
help="Save the model once per xxxx batch of training")
parser.add_argument(
"--eval_window",
type=int,
default=20,
help="Training will be suspended when the evaluation indicators on the validation set" \
" no longer increase. The eval_window specifies the scope of the evaluation.")
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="The number of sequences contained in a mini-batch, or the maximum" \
"number of tokens (include paddings) contained in a mini-batch.")
parser.add_argument(
"--corpus_type_list",
type=str,
default=["human", "feed", "query", "title", "news"],
nargs='+',
help="The pattern list of different types of corpus used in training.")
parser.add_argument(
"--corpus_proportion_list",
type=float,
default=[0.2, 0.2, 0.2, 0.2, 0.2],
nargs='+',
help="The proportion list of different types of corpus used in training.")
parser.add_argument(
"--use_gpu",
type=int,
        default=0,
help="Whether or not to use GPU. 0-->CPU 1-->GPU")
parser.add_argument(
"--traindata_shuffle_buffer",
type=int,
default=200000,
help="The buffer size used in shuffle the training data.")
parser.add_argument(
"--word_emb_dim",
type=int,
default=128,
help="The dimension in which a word is embedded.")
parser.add_argument(
"--grnn_hidden_dim",
type=int,
default=256,
help="The number of hidden nodes in the GRNN layer.")
parser.add_argument(
"--bigru_num",
type=int,
default=2,
help="The number of bi_gru layers in the network.")
parser.add_argument(
"--base_learning_rate",
type=float,
default=1e-3,
help="The basic learning rate that affects the entire network.")
parser.add_argument(
"--emb_learning_rate",
type=float,
default=5,
help="The real learning rate of the embedding layer will be" \
" (emb_learning_rate * base_learning_rate)."
)
parser.add_argument(
"--crf_learning_rate",
type=float,
default=0.2,
help="The real learning rate of the embedding layer will be" \
" (crf_learning_rate * base_learning_rate)."
)
parser.add_argument(
"--word_dict_path",
type=str,
default="../data/vocabulary_min5k.txt",
help="The path of the word dictionary."
)
parser.add_argument(
"--label_dict_path",
type=str,
default="data/tag.dic",
help="The path of the label dictionary."
)
parser.add_argument(
"--word_rep_dict_path",
type=str,
default="conf/q2b.dic",
help="The path of the word replacement Dictionary."
)
parser.add_argument(
"--num_iterations",
type=int,
default=40000,
help="The maximum number of iterations. If set to 0 (default), do not limit the number."
)
#add elmo args
parser.add_argument(
"--elmo_l2_coef",
type=float,
default=0.001,
help="Weight decay. (default: %(default)f)"
)
parser.add_argument(
"--elmo_dict_dir",
default='data/vocabulary_min5k.txt',
help="If set, load elmo dict."
)
parser.add_argument(
'--pretrain_elmo_model_path',
default="data/baike_elmo_checkpoint",
help="If set, load elmo checkpoint."
)
args = parser.parse_args()
if len(args.corpus_proportion_list) != len(args.corpus_type_list):
sys.stderr.write(
"The length of corpus_proportion_list should be equal to the length of corpus_type_list.\n"
)
exit(-1)
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def to_lodtensor(data, place):
"""
Convert data in list into lodtensor.
"""
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def test(exe, chunk_evaluator, save_dirname, test_data, place):
"""
Test the network in training.
"""
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
chunk_evaluator.reset()
for data in test_data():
            word = to_lodtensor(list(map(lambda x: x[0], data)), place)
            target = to_lodtensor(list(map(lambda x: x[1], data)), place)
result_list = exe.run(
inference_program,
feed={
"word": word,
"target": target
},
fetch_list=fetch_targets)
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
number_correct = np.array(result_list[2])
chunk_evaluator.update(int(number_infer[0]), int(number_label[0]),
int(number_correct[0]))
return chunk_evaluator.eval()
def train(args):
"""
Train the network.
"""
if not os.path.exists(args.model_save_dir):
os.mkdir(args.model_save_dir)
word2id_dict = reader.load_reverse_dict(args.word_dict_path)
label2id_dict = reader.load_reverse_dict(args.label_dict_path)
word_rep_dict = reader.load_dict(args.word_rep_dict_path)
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
    avg_cost, crf_decode, word, target = lex_net(args, word_dict_len, label_dict_len)
adam_optimizer = fluid.optimizer.Adam(learning_rate=args.base_learning_rate)
adam_optimizer.minimize(avg_cost)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
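    # with the IOB scheme each chunk type contributes a "-B" and a "-I" label
    # and there is a single "O" label, hence ceil((label_dict_len - 1) / 2)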
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
train_reader_list = []
corpus_num = len(args.corpus_type_list)
for i in range(corpus_num):
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.file_reader(args.traindata_dir,
word2id_dict,
label2id_dict,
word_rep_dict,
args.corpus_type_list[i]),
buf_size=args.traindata_shuffle_buffer),
batch_size=int(args.batch_size * args.corpus_proportion_list[i]))
train_reader_list.append(train_reader)
test_reader = paddle.batch(
reader.file_reader(args.testdata_dir, word2id_dict, label2id_dict, word_rep_dict),
batch_size=args.batch_size)
train_reader_itr_list = []
for train_reader in train_reader_list:
cur_reader_itr = train_reader()
train_reader_itr_list.append(cur_reader_itr)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[word, target], place=place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
    # load the pretrained ELMo parameters for finetuning
    init_pretraining_params(exe, args.pretrain_elmo_model_path,
                            fluid.default_main_program())
batch_id = 0
start_time = time.time()
eval_list = []
iter = 0
while True:
full_batch = []
cur_batch = []
for i in range(corpus_num):
reader_itr = train_reader_itr_list[i]
try:
cur_batch = next(reader_itr)
except StopIteration:
print(args.corpus_type_list[i] +
" corpus finish a pass of training")
new_reader = train_reader_list[i]
train_reader_itr_list[i] = new_reader()
cur_batch = next(train_reader_itr_list[i])
full_batch += cur_batch
random.shuffle(full_batch)
cost_var, nums_infer, nums_label, nums_correct = exe.run(
fluid.default_main_program(),
fetch_list=[
avg_cost, num_infer_chunks, num_label_chunks,
num_correct_chunks
],
feed=feeder.feed(full_batch))
print("batch_id:" + str(batch_id) + ", avg_cost:" + str(cost_var[0]))
chunk_evaluator.update(nums_infer, nums_label, nums_correct)
batch_id += 1
if (batch_id % args.save_model_per_batchs == 1):
save_exe = fluid.Executor(place)
            temp_save_model = os.path.join(args.model_save_dir,
                                           "temp_model_for_test")
            fluid.io.save_inference_model(
                temp_save_model, ['word', 'target'],
                [num_infer_chunks, num_label_chunks, num_correct_chunks],
                save_exe)
precision, recall, f1_score = chunk_evaluator.eval()
print("[train] batch_id:" + str(batch_id) + ", precision:" +
str(precision) + ", recall:" + str(recall) + ", f1:" +
str(f1_score))
chunk_evaluator.reset()
p, r, f1 = test(
exe, chunk_evaluator, temp_save_model, test_reader, place)
chunk_evaluator.reset()
print("[test] batch_id:" + str(batch_id) + ", precision:" +
str(p) + ", recall:" + str(r) + ", f1:" + str(f1))
end_time = time.time()
print("cur_batch_id:" + str(batch_id) + ", last " +
str(args.save_model_per_batchs) + " batchs, time_cost:" +
str(end_time - start_time))
start_time = time.time()
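            # early stopping: once 2 * eval_window test F1 scores are collected,
            # stop when the mean F1 of the latest window no longer exceeds the
            # mean of the previous window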
if len(eval_list) < 2 * args.eval_window:
eval_list.append(f1)
else:
eval_list.pop(0)
eval_list.append(f1)
last_avg_f1 = sum(
eval_list[0:args.eval_window]) / args.eval_window
cur_avg_f1 = sum(eval_list[
args.eval_window:2 * args.eval_window]) / args.eval_window
if cur_avg_f1 <= last_avg_f1:
return
else:
print("keep training!")
iter += 1
if (iter == args.num_iterations):
return
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
train(args)
\ No newline at end of file