提交 50816a2d 编写于 作者: 0 0YuanZhang0

update_sequence_tagging

上级 dc437431
......@@ -25,7 +25,8 @@ from hapi.text.text import TransformerDecoderLayer as TransformerDecoderLayer
from hapi.text.text import TransformerEncoder as TransformerEncoder
from hapi.text.text import TransformerDecoder as TransformerDecoder
from hapi.text.text import TransformerBeamSearchDecoder as TransformerBeamSearchDecoder
from hapi.text.text import DynamicGRU as DynamicGRU
from hapi.text.text import GRUCell as GRUCell
from hapi.text.text import GRUEncoderCell as GRUEncoderCell
from hapi.text.text import BiGRU as BiGRU
from hapi.text.text import Linear_chain_crf as Linear_chain_crf
from hapi.text.text import Crf_decoding as Crf_decoding
......
......@@ -31,8 +31,6 @@ import multiprocessing
import collections
import copy
import six
import sys
from functools import partial, reduce
import paddle
......@@ -46,11 +44,12 @@ from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
from paddle.fluid.layers import BeamSearchDecoder
__all__ = [
'RNNCell', 'BasicLSTMCell', 'BasicGRUCell', 'RNN', 'DynamicDecode',
'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
'TransformerDecoder', 'TransformerBeamSearchDecoder', 'DynamicGRU',
'TransformerDecoder', 'TransformerBeamSearchDecoder', 'GRUCell', 'GRUEncoderCell',
'BiGRU', 'Linear_chain_crf', 'Crf_decoding', 'SequenceTagging'
]
......@@ -220,19 +219,7 @@ class BasicLSTMCell(RNNCell):
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32',
forget_gate_weights={"w": None,
"h": None,
"b": None},
input_gate_weights={"w": None,
"h": None,
"b": None},
output_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
dtype='float32'):
super(BasicLSTMCell, self).__init__()
self._hidden_size = hidden_size
......@@ -246,188 +233,19 @@ class BasicLSTMCell(RNNCell):
self._dtype = dtype
self._input_size = input_size
assert isinstance(forget_gate_weights, dict)
assert isinstance(input_gate_weights, dict)
assert isinstance(output_gate_weights, dict)
assert isinstance(cell_weights, dict)
# forgot get parameters
if "w" in forget_gate_weights and forget_gate_weights["w"] is not None:
self.fg_w = forget_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_w"
else:
tmp_param_attr = self._param_attr
self.fg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in forget_gate_weights and forget_gate_weights["h"] is not None:
self.fg_h = forget_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_h"
else:
tmp_param_attr = self._param_attr
self.fg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in forget_gate_weights and forget_gate_weights["b"] is not None:
self.fg_b = forget_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_forget_gate_b"
else:
tmp_param_attr = self._bias_attr
self.fg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# input gate parameters
if "w" in input_gate_weights and input_gate_weights["w"] is not None:
self.ig_w = input_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_w"
else:
tmp_param_attr = self._param_attr
self.ig_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in input_gate_weights and input_gate_weights["h"] is not None:
self.ig_h = input_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_h"
else:
tmp_param_attr = self._param_attr
self.ig_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in input_gate_weights and input_gate_weights["b"] is not None:
self.ig_b = input_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_input_gate_b"
else:
tmp_param_attr = self._bias_attr
self.ig_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# output gate parameters
if "w" in output_gate_weights and output_gate_weights["w"] is not None:
self.og_w = output_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_w"
else:
tmp_param_attr = self._param_attr
self.og_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in output_gate_weights and output_gate_weights["h"] is not None:
self.og_h = output_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_h"
else:
tmp_param_attr = self._param_attr
self.og_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in output_gate_weights and output_gate_weights["b"] is not None:
self.og_b = output_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_output_gate_b"
else:
tmp_param_attr = self._bias_attr
self.og_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# cell parameters
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = self._param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = self._param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[
self._input_size + self._hidden_size, 4 * self._hidden_size
],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = self._bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
# the weight is concated here in order to make the computation more efficent.
weight_w = fluid.layers.concat(
[self.ig_w, self.c_w, self.fg_w, self.og_w], axis=-1)
weight_h = fluid.layers.concat(
[self.ig_h, self.c_h, self.fg_h, self.og_h], axis=-1)
self._weight = fluid.layers.concat([weight_w, weight_h], axis=0)
self._bias = fluid.layers.concat(
[self.ig_b, self.c_b, self.fg_b, self.og_b])
def forward(self, input, state):
pre_hidden, pre_cell = state
concat_input_hidden = layers.concat([input, pre_hidden], 1)
......@@ -490,30 +308,16 @@ class BasicGRUCell(RNNCell):
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32',
update_gate_weights={"w": None,
"h": None,
"b": None},
reset_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
dtype='float32'):
super(BasicGRUCell, self).__init__()
self._input_size = input_size
self._hiden_size = hidden_size
self._hidden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._dtype = dtype
assert isinstance(update_gate_weights, dict)
assert isinstance(reset_gate_weights, dict)
assert isinstance(cell_weights, dict)
if self._param_attr is not None and self._param_attr.name is not None:
gate_param_attr = copy.deepcopy(self._param_attr)
candidate_param_attr = copy.deepcopy(self._param_attr)
......@@ -523,6 +327,16 @@ class BasicGRUCell(RNNCell):
gate_param_attr = self._param_attr
candidate_param_attr = self._param_attr
self._gate_weight = self.create_parameter(
attr=gate_param_attr,
shape=[self._input_size + self._hidden_size, 2 * self._hidden_size],
dtype=self._dtype)
self._candidate_weight = self.create_parameter(
attr=candidate_param_attr,
shape=[self._input_size + self._hidden_size, self._hidden_size],
dtype=self._dtype)
if self._bias_attr is not None and self._bias_attr.name is not None:
gate_bias_attr = copy.deepcopy(self._bias_attr)
candidate_bias_attr = copy.deepcopy(self._bias_attr)
......@@ -532,140 +346,17 @@ class BasicGRUCell(RNNCell):
gate_bias_attr = self._bias_attr
candidate_bias_attr = self._bias_attr
# create the parameters of gates in gru
if "w" in update_gate_weights and update_gate_weights["w"] is not None:
self.ug_w = update_gate_weights["w"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_update_gate_w"
else:
tmp_param_attr = gate_param_attr
self.ug_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in update_gate_weights and update_gate_weights["h"] is not None:
self.ug_h = update_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_update_gate_h"
else:
tmp_param_attr = gate_param_attr
self.ug_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in update_gate_weights and update_gate_weights["b"] is not None:
self.ug_b = update_gate_weights["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_update_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.ug_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# reset gate parameters
if "w" in reset_gate_weights and reset_gate_weights["w"] is not None:
self.rg_w = reset_gate_weights["w"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_w"
else:
tmp_param_attr = gate_param_attr
self.rg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in reset_gate_weights and reset_gate_weights["h"] is not None:
self.rg_h = reset_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_h"
else:
tmp_param_attr = gate_param_attr
self.rg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in reset_gate_weights and reset_gate_weights["b"] is not None:
self.rg_b = reused_params["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_reset_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.rg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
self._gate_bias = self.create_parameter(
attr=gate_bias_attr,
shape=[2 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
# cell parameters
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = gate_param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = gate_param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if candidate_bias_attr is not None and candidate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = gate_bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
self._candidate_bias = self.create_parameter(
attr=candidate_bias_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
rg_weights = layers.concat([self.rg_w, self.rg_h], axis=0)
ug_weights = layers.concat([self.ug_w, self.ug_h], axis=0)
self._gate_weight = layers.concat([rg_weights, ug_weights], axis=-1)
self._candidate_weight = layers.concat([self.c_w, self.c_h], axis=0)
self._gate_bias = layers.concat([self.rg_b, self.ug_b], axis=0)
self._candidate_bias = self.c_b
def forward(self, input, state):
pre_hidden = state
concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
......@@ -870,7 +561,6 @@ class DynamicDecode(Layer):
# To confirm states.finished/finished be consistent with
# next_finished.
layers.assign(next_finished, finished)
next_sequence_lengths = layers.elementwise_add(
sequence_lengths,
layers.cast(
......@@ -1010,11 +700,7 @@ class PrePostProcessLayer(Layer):
PrePostProcessLayer
"""
def __init__(self,
process_cmd,
d_model,
dropout_rate,
reused_layer_norm=None):
def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
self.functors = []
......@@ -1022,21 +708,16 @@ class PrePostProcessLayer(Layer):
if cmd == "a": # add residual connection
self.functors.append(lambda x, y: x + y if y else x)
elif cmd == "n": # add layer normalization
if reused_layer_norm is not None:
layer_norm = reused_layer_norm
else:
layer_norm = LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
self.functors.append(
self.add_sublayer(
"layer_norm_%d" % len(
self.sublayers(include_sublayers=False)),
layer_norm))
LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))))
elif cmd == "d": # add dropout
self.functors.append(lambda x: layers.dropout(
x, dropout_prob=dropout_rate, is_test=False)
......@@ -1056,48 +737,21 @@ class MultiHeadAttention(Layer):
Multi-Head Attention
"""
def __init__(self,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.0,
reused_query_fc=None,
reused_key_fc=None,
reused_value_fc=None,
reused_proj_fc=None):
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
if reused_query_fc is not None:
self.q_fc = reused_query_fc
else:
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
if reused_key_fc is not None:
self.k_fc = reused_key_fc
else:
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
if reused_value_fc is not None:
self.v_fc = reused_value_fc
else:
self.v_fc = Linear(
input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
if reused_proj_fc is not None:
self.proj_fc = reused_proj_fc
else:
input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
self.proj_fc = Linear(
input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
......@@ -1174,23 +828,11 @@ class FFN(Layer):
Feed-Forward Network
"""
def __init__(self,
d_inner_hid,
d_model,
dropout_rate,
fc1_act="relu",
reused_fc1=None,
reused_fc2=None):
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
if reused_fc1 is not None:
self.fc1 = reused_fc1
else:
self.fc1 = Linear(
input_dim=d_model, output_dim=d_inner_hid, act=fc1_act)
if reused_fc2 is not None:
self.fc2 = reused_fc2
else:
input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
def forward(self, x):
......@@ -1217,52 +859,22 @@ class TransformerEncoderLayer(Layer):
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
ffn_fc1_act="relu",
reused_pre_selatt_layernorm=None,
reused_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
postprocess_cmd="da"):
super(TransformerEncoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_selatt_layernorm)
self.self_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_multihead_att_weights["reused_query_fc"],
reused_key_fc=reused_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_multihead_att_weights["reused_value_fc"],
reused_proj_fc=reused_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_ffn_layernorm)
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
fc1_act=ffn_fc1_act,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout,
reused_post_ffn_layernorm)
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(
......@@ -1290,8 +902,7 @@ class TransformerEncoder(Layer):
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
ffn_fc1_act="relu"):
postprocess_cmd="da"):
super(TransformerEncoder, self).__init__()
......@@ -1301,17 +912,9 @@ class TransformerEncoder(Layer):
self.add_sublayer(
"layer_%d" % i,
TransformerEncoderLayer(
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
ffn_fc1_act=ffn_fc1_act)))
n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd)))
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
......@@ -1338,79 +941,28 @@ class TransformerDecoderLayer(Layer):
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
reused_pre_selfatt_layernorm=None,
reused_self_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_crossatt_layernorm=None,
reused_cross_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_crossatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
postprocess_cmd="da"):
super(TransformerDecoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_selfatt_layernorm)
self.self_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_self_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_self_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_self_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_self_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_crossatt_layernorm)
self.cross_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_cross_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_cross_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_cross_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_cross_multihead_att_weights[
"reused_proj_fc"])
self.postprocesser2 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_crossatt_layernorm)
prepostprocess_dropout)
self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_ffn_layernorm)
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout,
reused_post_ffn_layernorm)
prepostprocess_dropout)
def forward(self,
dec_input,
......@@ -1479,98 +1031,99 @@ class TransformerDecoder(Layer):
]
class DynamicGRU(fluid.dygraph.Layer):
class GRUCell(RNNCell):
def __init__(self,
size,
h_0=None,
input_size,
hidden_size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
init_size=None):
super(DynamicGRU, self).__init__()
origin_mode=False):
super(GRUCell, self).__init__()
self.hidden_size = hidden_size
self.fc_layer = Linear(
input_size,
hidden_size * 3,
param_attr=param_attr)
self.gru_unit = GRUUnit(
size * 3,
hidden_size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs, states):
# for GRUCell, `step_outputs` and `new_states` both are hidden
x = self.fc_layer(inputs)
hidden, _, _ = self.gru_unit(x, states)
return hidden, hidden
def forward(self, inputs):
hidden = self.h_0
res = []
@property
def state_shape(self):
return [self.hidden_size]
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[:, i:i + 1, :]
input_ = fluid.layers.reshape(
input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(
hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class GRUEncoderCell(RNNCell):
def __init__(self,
num_layers,
input_size,
hidden_size,
dropout_prob=0.,
init_scale=0.1):
super(GRUEncoderCell, self).__init__()
self.dropout_prob = dropout_prob
# use add_sublayer to add multi-layers
self.gru_cells = []
for i in range(num_layers):
self.gru_cells.append(
self.add_sublayer(
"gru_%d" % i,
#BasicGRUCell(
GRUCell(
input_size=input_size if i == 0 else hidden_size,
hidden_size=hidden_size,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))))
def forward(self, step_input, states):
new_states = []
for i, gru_cell in enumerate(self.gru_cells):
out, state = gru_cell(step_input, states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation='upscale_in_train'
) if self.dropout_prob > 0 else out
new_states.append(step_input)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.gru_cells]
class BiGRU(fluid.dygraph.Layer):
def __init__(self, input_dim, grnn_hidden_dim, init_bound, h_0=None):
super(BiGRU, self).__init__()
self.gru = RNN(GRUEncoderCell(1, input_dim,
grnn_hidden_dim, 0.0, init_bound),
is_reverse=False,
time_major=False)
self.pre_gru = Linear(
input_dim=input_dim,
output_dim=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
self.gru = DynamicGRU(
size=grnn_hidden_dim,
h_0=h_0,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
self.pre_gru_r = Linear(
input_dim=input_dim,
output_dim=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
self.gru_r = DynamicGRU(
size=grnn_hidden_dim,
self.gru_r = RNN(GRUEncoderCell(1, input_dim,
grnn_hidden_dim, 0.0, init_bound),
is_reverse=True,
h_0=h_0,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
time_major=False)
def forward(self, input_feature):
res_pre_gru = self.pre_gru(input_feature)
res_gru = self.gru(res_pre_gru)
res_pre_gru_r = self.pre_gru_r(input_feature)
res_gru_r = self.gru_r(res_pre_gru_r)
bi_merge = fluid.layers.concat(input=[res_gru, res_gru_r], axis=-1)
pre_gru, pre_state = self.gru(input_feature)
gru_r, r_state = self.gru_r(input_feature)
bi_merge = fluid.layers.concat(input=[pre_gru, gru_r], axis=-1)
return bi_merge
......@@ -1610,7 +1163,7 @@ class Linear_chain_crf(fluid.dygraph.Layer):
"Transition": self._transition,
"Label": [label]
}
if length:
if length is not None:
this_inputs['Length'] = [length]
self._helper.append_op(
type='linear_chain_crf',
......@@ -1655,7 +1208,7 @@ class Crf_decoding(fluid.dygraph.Layer):
"Transition": self._transition,
"Label": label
}
if length:
if length is not None:
this_inputs['Length'] = [length]
self._helper.append_op(
type='crf_decoding',
......@@ -1767,7 +1320,7 @@ class SequenceTagging(fluid.dygraph.Layer):
emission = self.fc(bigru_output)
if target:
if target is not None:
crf_cost = self.linear_chain_crf(
input=emission, label=target, length=lengths)
avg_cost = fluid.layers.mean(x=crf_cost)
......@@ -1775,5 +1328,6 @@ class SequenceTagging(fluid.dygraph.Layer):
crf_decode = self.crf_decoding(input=emission, length=lengths)
return crf_decode, avg_cost, lengths
else:
self.linear_chain_crf.weight = self.crf_decoding.weight
crf_decode = self.crf_decoding(input=emission, length=lengths)
return crf_decode, lengths
......@@ -6,7 +6,7 @@ Sequence Tagging,是一个序列标注模型,模型可用于实现,分词
|模型|Precision|Recall|F1-score|
|:-:|:-:|:-:|:-:|
|Lexical Analysis|89.2%|89.4%|89.3%|
|Lexical Analysis|88.26%|89.20%|88.73%|
## 2. 快速开始
......@@ -139,7 +139,7 @@ python predict.py \
--init_from_checkpoint model_baseline/params \
--output_file predict.result \
--mode predict \
--device gpu \
--device cpu \
-d
# -d: 是否使用动态图模式进行训练,如果使用静态图训练,命令行请删除-d参数
......@@ -157,7 +157,7 @@ python eval.py \
--label_dict_path ./conf/tag.dic \
--word_rep_dict_path ./conf/q2b.dic \
--init_from_checkpoint ./model_baseline/params \
--device gpu \
--device cpu \
-d
# -d: 是否使用动态图模式进行训练,如果使用静态图训练,命令行请删除-d参数
......@@ -189,7 +189,10 @@ python eval.py \
### 模型原理介绍
上面介绍的模型原理如下图所示:<br />
![GRU-CRF-MODEL](./images/gru-crf-model.png)
<p align="center">
<img src="./images/gru-crf-model.png" width = "340" height = "300" /> <br />
Overall Architecture of GRU-CRF-MODEL
</p>
### 数据格式
训练使用的数据可以由用户根据实际的应用场景,自己组织数据。除了第一行是 `text_a\tlabel` 固定的开头,后面的每行数据都是由两列组成,以制表符分隔,第一列是 utf-8 编码的中文文本,以 `\002` 分割,第二列是对应每个字的标注,以 `\002` 分隔。我们采用 IOB2 标注体系,即以 X-B 作为类型为 X 的词的开始,以 X-I 作为类型为 X 的词的持续,以 O 表示不关注的字(实际上,在词性、专名联合标注中,不存在 O )。示例如下:
......
......@@ -25,8 +25,9 @@ import math
import argparse
import numpy as np
from train import SeqTagging, Chunk_eval
from train import SeqTagging
from utils.check import check_gpu, check_version
from utils.metrics import chunk_count
from reader import LacDataset, create_lexnet_data_generator, create_dataloader
work_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
......@@ -42,14 +43,13 @@ def main(args):
place = set_device(args.device)
fluid.enable_dygraph(place) if args.dynamic else None
inputs = [Input([None, args.max_seq_len], 'int64', name='words'),
inputs = [Input([None, None], 'int64', name='words'),
Input([None], 'int64', name='length')]
feed_list = None if args.dynamic else [x.forward() for x in inputs]
dataset = LacDataset(args)
eval_path = args.test_file
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
......@@ -69,21 +69,19 @@ def main(args):
model.mode = "test"
model.prepare(inputs=inputs)
model.load(args.init_from_checkpoint, skip_mismatch=True)
model.load(args.init_from_checkpoint)
f = open(args.output_file, "wb")
for data in eval_dataset():
words, lens, targets, targets = data
crf_decode, length = model.test(inputs=flatten(data))
crf_decode = fluid.dygraph.to_variable(crf_decode)
length = fluid.dygraph.to_variable(length)
(num_infer_chunks, num_label_chunks, num_correct_chunks) = chunk_eval(
input=crf_decode,
label=targets,
seq_length=length)
print(num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
chunk_evaluator.update(num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
if len(data) == 1:
batch_data = data[0]
targets = np.array(batch_data[2])
else:
batch_data = data
targets = batch_data[2].numpy()
inputs_data = [batch_data[0], batch_data[1]]
crf_decode, length = model.test(inputs=inputs_data)
num_infer_chunks, num_label_chunks, num_correct_chunks = chunk_count(crf_decode, targets, length, dataset.id2label_dict)
chunk_evaluator.update(num_infer_chunks, num_label_chunks, num_correct_chunks)
precision, recall, f1 = chunk_evaluator.eval()
print("[test] P: %.5f, R: %.5f, F1: %.5f" % (precision, recall, f1))
......@@ -176,7 +174,8 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
check_gpu(args.device)
use_gpu = True if args.device == "gpu" else False
check_gpu(use_gpu)
check_version()
main(args)
......@@ -42,7 +42,7 @@ def main(args):
place = set_device(args.device)
fluid.enable_dygraph(place) if args.dynamic else None
inputs = [Input([None, args.max_seq_len], 'int64', name='words'),
inputs = [Input([None, None], 'int64', name='words'),
Input([None], 'int64', name='length')]
feed_list = None if args.dynamic else [x.forward() for x in inputs]
......@@ -70,8 +70,11 @@ def main(args):
f = open(args.output_file, "wb")
for data in predict_dataset():
results, length = model.test(inputs=flatten(data))
#length_list = np.fromstring(length, dtype=str)
if len(data) == 1:
input_data = data[0]
else:
input_data = data
results, length = model.test(inputs=flatten(input_data))
for i in range(len(results)):
word_len = length[i]
word_ids = results[i][: word_len]
......@@ -162,7 +165,8 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
check_gpu(args.device)
use_gpu = True if args.device == "gpu" else False
check_gpu(use_gpu)
check_version()
main(args)
......@@ -21,7 +21,7 @@ from __future__ import print_function
import io
import numpy as np
import paddle.fluid as fluid
import paddle
class LacDataset(object):
......@@ -120,7 +120,7 @@ class LacDataset(object):
def wrapper():
fread = io.open(filename, "r", encoding="utf-8")
if mode == "train" or mode == "test":
if mode == "train":
headline = next(fread)
headline = headline.strip().split('\t')
assert len(headline) == 2 and headline[0] == "text_a" and headline[
......@@ -133,6 +133,8 @@ class LacDataset(object):
word_ids = self.word_to_ids(words.split("\002"))
label_ids = self.label_to_ids(labels.split("\002"))
assert len(word_ids) == len(label_ids)
words_len = np.int64(len(word_ids))
word_ids = word_ids[0:max_seq_len]
words_len = np.int64(len(word_ids))
word_ids += [0 for _ in range(max_seq_len - words_len)]
......@@ -140,6 +142,21 @@ class LacDataset(object):
label_ids += [0 for _ in range(max_seq_len - words_len)]
assert len(word_ids) == len(label_ids)
yield word_ids, label_ids, words_len
elif mode == "test":
headline = next(fread)
headline = headline.strip().split('\t')
assert len(headline) == 2 and headline[0] == "text_a" and headline[
1] == "label"
buf = []
for line in fread:
words, labels = line.strip("\n").split("\t")
if len(words) < 1:
continue
word_ids = self.word_to_ids(words.split("\002"))
label_ids = self.label_to_ids(labels.split("\002"))
assert len(word_ids) == len(label_ids)
words_len = np.int64(len(word_ids))
yield word_ids, label_ids, words_len
else:
for line in fread:
words = line.strip("\n").split('\t')[0]
......@@ -158,8 +175,15 @@ class LacDataset(object):
def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
def padding_data(max_len, batch_data):
padding_batch_data = []
for data in batch_data:
data += [0 for _ in range(max_len - len(data))]
padding_batch_data.append(data)
return padding_batch_data
def wrapper():
if mode == "train" or mode == "test":
if mode == "train":
batch_words, batch_labels, seq_lens = [], [], []
for epoch in xrange(args.epoch):
for instance in reader.file_reader(
......@@ -175,6 +199,26 @@ def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
if len(seq_lens) > 0:
yield batch_words, seq_lens, batch_labels, batch_labels
elif mode == "test":
batch_words, batch_labels, seq_lens, max_len = [], [], [], 0
for instance in reader.file_reader(
file_name, mode, max_seq_len=args.max_seq_len)():
words, labels, words_len = instance
max_len = words_len if words_len > max_len else max_len
if len(seq_lens) < args.batch_size:
batch_words.append(words)
seq_lens.append(words_len)
batch_labels.append(labels)
if len(seq_lens) == args.batch_size:
padding_batch_words = padding_data(max_len, batch_words)
padding_batch_labels = padding_data(max_len, batch_labels)
yield padding_batch_words, seq_lens, padding_batch_labels, padding_batch_labels
batch_words, batch_labels, seq_lens, max_len = [], [], [], 0
if len(seq_lens) > 0:
padding_batch_words = padding_data(max_len, batch_words)
padding_batch_labels = padding_data(max_len, batch_labels)
yield padding_batch_words, seq_lens, padding_batch_labels, padding_batch_labels
else:
batch_words, seq_lens, max_len = [], [], 0
for instance in reader.file_reader(
......@@ -183,20 +227,13 @@ def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
if len(seq_lens) < args.batch_size:
batch_words.append(words)
seq_lens.append(words_len)
if words_len > max_len:
max_len = words_len
max_len = words_len if words_len > max_len else max_len
if len(seq_lens) == args.batch_size:
padding_batch_words = []
for words in batch_words:
words += [0 for _ in range(max_len - len(words))]
padding_batch_words.append(words)
padding_batch_words = padding_data(max_len, batch_words)
yield padding_batch_words, seq_lens
batch_words, seq_lens, max_len = [], [], 0
if len(seq_lens) > 0:
padding_batch_words = []
for words in batch_words:
words += [0 for _ in range(max_len - len(words))]
padding_batch_words.append(words)
padding_batch_words = padding_data(max_len, batch_words)
yield padding_batch_words, seq_lens
return wrapper
......@@ -204,13 +241,13 @@ def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
def create_dataloader(generator, place, feed_list=None):
if not feed_list:
data_loader = fluid.io.DataLoader.from_generator(
data_loader = paddle.io.DataLoader.from_generator(
capacity=50,
use_double_buffer=True,
iterable=True,
return_list=True)
else:
data_loader = fluid.io.DataLoader.from_generator(
data_loader = paddle.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=50,
use_double_buffer=True,
......
......@@ -154,9 +154,10 @@ class ChunkEval(Metric):
int(math.ceil((num_labels - 1) / 2.0)), "IOB")
self.reset()
def add_metric_op(self, pred, label, *args, **kwargs):
crf_decode = pred[0]
lengths = pred[2]
def add_metric_op(self, *args):
crf_decode = args[0]
lengths = args[2]
label = args[3]
(num_infer_chunks, num_label_chunks,
num_correct_chunks) = self.chunk_eval(
input=crf_decode, label=label, seq_length=lengths)
......@@ -204,11 +205,11 @@ def main(args):
place = set_device(args.device)
fluid.enable_dygraph(place) if args.dynamic else None
inputs = [Input([None, args.max_seq_len], 'int64', name='words'),
inputs = [Input([None, None], 'int64', name='words'),
Input([None], 'int64', name='length'),
Input([None, args.max_seq_len], 'int64', name='target')]
Input([None, None], 'int64', name='target')]
labels = [Input([None, args.max_seq_len], 'int64', name='labels')]
labels = [Input([None, None], 'int64', name='labels')]
feed_list = None if args.dynamic else [x.forward() for x in inputs + labels]
dataset = LacDataset(args)
......@@ -343,7 +344,8 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
check_gpu(args.device)
use_gpu = True if args.device == "gpu" else False
check_gpu(use_gpu)
check_version()
main(args)
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 train.py \
--train_file ./data/train.tsv \
--test_file ./data/test.tsv \
--word_dict_path ./data/word.dic \
--label_dict_path ./data/tag.dic \
--word_rep_dict_path ./data/q2b.dic \
--device gpu \
--grnn_hidden_dim 128 \
--word_emb_dim 128 \
--bigru_num 2 \
--base_learning_rate 1e-3 \
--batch_size 300 \
--epoch 10 \
--save_dir ./model \
-d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
__all__ = ['chunk_count', "build_chunk"]
def build_chunk(data_list, id2label_dict):
"""
Assembly entity
"""
tag_list = [id2label_dict.get(str(id)) for id in data_list]
ner_dict = {}
ner_str = ""
ner_start = 0
for i in range(len(tag_list)):
tag = tag_list[i]
if tag == u"O":
if i != 0:
key = "%d_%d" % (ner_start, i - 1)
ner_dict[key] = ner_str
ner_start = i
ner_str = tag
elif tag.endswith(u"B"):
if i != 0:
key = "%d_%d" % (ner_start, i - 1)
ner_dict[key] = ner_str
ner_start = i
ner_str = tag.split('-')[0]
elif tag.endswith(u"I"):
if tag.split('-')[0] != ner_str:
if i != 0:
key = "%d_%d" % (ner_start, i - 1)
ner_dict[key] = ner_str
ner_start = i
ner_str = tag.split('-')[0]
return ner_dict
def chunk_count(infer_numpy, label_numpy, seq_len, id2label_dict):
"""
calculate num_correct_chunks num_error_chunks total_num for metrics
"""
num_infer_chunks, num_label_chunks, num_correct_chunks = 0, 0, 0
assert infer_numpy.shape[0] == label_numpy.shape[0]
for i in range(infer_numpy.shape[0]):
infer_list = infer_numpy[i][: seq_len[i]]
label_list = label_numpy[i][: seq_len[i]]
infer_dict = build_chunk(infer_list, id2label_dict)
num_infer_chunks += len(infer_dict)
label_dict = build_chunk(label_list, id2label_dict)
num_label_chunks += len(label_dict)
for key in infer_dict:
if key in label_dict and label_dict[key] == infer_dict[key]:
num_correct_chunks += 1
return num_infer_chunks, num_label_chunks, num_correct_chunks
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册