Commit 85e422bb authored by xyzhou-puck

refine text.py

Parent ed14907e
@@ -22,7 +22,7 @@ import sys
if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf8')
import ast
import time
import argparse as argparse
@@ -44,13 +44,12 @@ from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
from paddle.fluid.layers import BeamSearchDecoder

__all__ = [
    'RNNCell', 'BasicLSTMCell', 'BasicGRUCell', 'RNN', 'DynamicDecode',
    'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
    'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
    'TransformerDecoder', 'TransformerBeamSearchDecoder', 'BiGRU',
    'Linear_chain_crf', 'Crf_decoding', 'SequenceTagging'
]
@@ -219,7 +218,19 @@ class BasicLSTMCell(RNNCell):
                 gate_activation=None,
                 activation=None,
                 forget_bias=1.0,
                 dtype='float32',
forget_gate_weights={"w": None,
"h": None,
"b": None},
input_gate_weights={"w": None,
"h": None,
"b": None},
output_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
        super(BasicLSTMCell, self).__init__()

        self._hidden_size = hidden_size
@@ -233,25 +244,225 @@ class BasicLSTMCell(RNNCell):
        self._dtype = dtype
        self._input_size = input_size

        self.use_customized_weight = False
        for _weights in [
                forget_gate_weights, input_gate_weights, output_gate_weights,
                cell_weights
        ]:
            for _key in _weights:
                if _weights[_key] is not None:
                    self.use_customized_weight = True
                    break
            if self.use_customized_weight:
                break

if not self.use_customized_weight:
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[
self._input_size + self._hidden_size, 4 * self._hidden_size
],
dtype=self._dtype)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
else:
if "w" in forget_gate_weights and forget_gate_weights[
"w"] is not None:
self.fg_w = forget_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_w"
else:
tmp_param_attr = self._param_attr
self.fg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in forget_gate_weights and forget_gate_weights[
"h"] is not None:
self.fg_h = forget_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_h"
else:
tmp_param_attr = self._param_attr
self.fg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in forget_gate_weights and forget_gate_weights[
"b"] is not None:
self.fg_b = forget_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_forget_gate_b"
else:
tmp_param_attr = self._bias_attr
self.fg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in input_gate_weights and input_gate_weights[
"w"] is not None:
self.ig_w = input_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_w"
else:
tmp_param_attr = self._param_attr
self.ig_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in input_gate_weights and input_gate_weights[
"h"] is not None:
self.ig_h = input_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_h"
else:
tmp_param_attr = self._param_attr
self.ig_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in input_gate_weights and input_gate_weights[
"b"] is not None:
self.ig_b = input_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_input_gate_b"
else:
tmp_param_attr = self._bias_attr
self.ig_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in output_gate_weights and output_gate_weights[
"w"] is not None:
self.og_w = output_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_w"
else:
tmp_param_attr = self._param_attr
self.og_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in output_gate_weights and output_gate_weights[
"h"] is not None:
self.og_h = output_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_h"
else:
tmp_param_attr = self._param_attr
self.og_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in output_gate_weights and output_gate_weights[
"b"] is not None:
self.og_b = output_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_output_gate_b"
else:
tmp_param_attr = self._bias_attr
self.og_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = self._param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = self._param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = self._bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
    def forward(self, input, state):
if self.use_customized_weight:
weight_w = fluid.layers.concat(
[self.ig_w, self.c_w, self.fg_w, self.og_w], axis=-1)
weight_h = fluid.layers.concat(
[self.ig_h, self.c_h, self.fg_h, self.og_h], axis=-1)
_weight = fluid.layers.concat([weight_w, weight_h], axis=0)
_bias = fluid.layers.concat(
[self.ig_b, self.c_b, self.fg_b, self.og_b])
else:
_weight = self._weight
_bias = self._bias
        pre_hidden, pre_cell = state
        concat_input_hidden = layers.concat([input, pre_hidden], 1)
        gate_input = layers.matmul(x=concat_input_hidden, y=_weight)
        gate_input = layers.elementwise_add(gate_input, _bias)
        i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
        new_cell = layers.elementwise_add(
            layers.elementwise_mul(
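
The new forget/input/output/cell gate-weight arguments accept externally created parameters, so several cells can share per-gate weights instead of each building its own fused weight matrix. A minimal sketch of one possible use, assuming the existing (input_size, hidden_size) leading arguments of BasicLSTMCell; SharedGateWeights and the sizes below are made-up names for illustration, not part of this commit:

import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer

class SharedGateWeights(Layer):
    """Hypothetical holder for one gate's parameters:
    w: [input, hidden], h: [hidden, hidden], b: [hidden]."""
    def __init__(self, input_size, hidden_size, dtype='float32'):
        super(SharedGateWeights, self).__init__()
        self.w = self.create_parameter(
            attr=None, shape=[input_size, hidden_size], dtype=dtype)
        self.h = self.create_parameter(
            attr=None, shape=[hidden_size, hidden_size], dtype=dtype)
        self.b = self.create_parameter(
            attr=None, shape=[hidden_size], dtype=dtype, is_bias=True)

with fluid.dygraph.guard():
    shared = SharedGateWeights(input_size=32, hidden_size=64)
    forget_gate = {"w": shared.w, "h": shared.h, "b": shared.b}
    # both cells read and update the same forget-gate parameters;
    # the remaining gates fall back to per-cell parameters.
    cell_a = BasicLSTMCell(32, 64, forget_gate_weights=forget_gate)
    cell_b = BasicLSTMCell(32, 64, forget_gate_weights=forget_gate)
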
@@ -308,7 +519,16 @@ class BasicGRUCell(RNNCell):
                 bias_attr=None,
                 gate_activation=None,
                 activation=None,
                 dtype='float32',
update_gate_weights={"w": None,
"h": None,
"b": None},
reset_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
        super(BasicGRUCell, self).__init__()
        self._input_size = input_size
        self._hidden_size = hidden_size
@@ -318,6 +538,20 @@ class BasicGRUCell(RNNCell):
        self._activation = activation or layers.tanh
        self._dtype = dtype

assert isinstance(update_gate_weights, dict)
assert isinstance(reset_gate_weights, dict)
assert isinstance(cell_weights, dict)
self.use_customized_weight = False
for _weights in [
update_gate_weights, reset_gate_weights, cell_weights
]:
for _key in _weights:
if _weights[_key] is not None:
self.use_customized_weight = True
if self.use_customized_weight:
break
        if self._param_attr is not None and self._param_attr.name is not None:
            gate_param_attr = copy.deepcopy(self._param_attr)
            candidate_param_attr = copy.deepcopy(self._param_attr)
@@ -327,43 +561,194 @@ class BasicGRUCell(RNNCell):
            gate_param_attr = self._param_attr
            candidate_param_attr = self._param_attr

        if not self.use_customized_weight:
            self._gate_weight = self.create_parameter(
                attr=gate_param_attr,
                shape=[
                    self._input_size + self._hidden_size, 2 * self._hidden_size
                ],
                dtype=self._dtype)

            self._candidate_weight = self.create_parameter(
                attr=candidate_param_attr,
                shape=[
                    self._input_size + self._hidden_size, self._hidden_size
                ],
                dtype=self._dtype)

            if self._bias_attr is not None and self._bias_attr.name is not None:
                gate_bias_attr = copy.deepcopy(self._bias_attr)
                candidate_bias_attr = copy.deepcopy(self._bias_attr)
                gate_bias_attr.name += "_gate"
                candidate_bias_attr.name += "_candidate"
            else:
                gate_bias_attr = self._bias_attr
                candidate_bias_attr = self._bias_attr

            self._gate_bias = self.create_parameter(
                attr=gate_bias_attr,
                shape=[2 * self._hidden_size],
                dtype=self._dtype,
                is_bias=True)
            self._candidate_bias = self.create_parameter(
                attr=candidate_bias_attr,
                shape=[self._hidden_size],
                dtype=self._dtype,
                is_bias=True)

        else:
            # create the parameters of gates in gru
            if "w" in update_gate_weights and update_gate_weights[
                    "w"] is not None:
                self.ug_w = update_gate_weights["w"]
            else:
                if gate_param_attr is not None and gate_param_attr.name is not None:
                    tmp_param_attr = copy.deepcopy(gate_param_attr)
                    tmp_param_attr.name += "_update_gate_w"
                else:
                    tmp_param_attr = gate_param_attr
                self.ug_w = self.create_parameter(
                    attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in update_gate_weights and update_gate_weights[
"h"] is not None:
self.ug_h = update_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_update_gate_h"
else:
tmp_param_attr = gate_param_attr
self.ug_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in update_gate_weights and update_gate_weights[
"b"] is not None:
self.ug_b = update_gate_weights["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_update_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.ug_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# reset gate parameters
if "w" in reset_gate_weights and reset_gate_weights[
"w"] is not None:
self.rg_w = reset_gate_weights["w"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_w"
else:
tmp_param_attr = gate_param_attr
self.rg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in reset_gate_weights and reset_gate_weights[
"h"] is not None:
self.rg_h = reset_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_h"
else:
tmp_param_attr = gate_param_attr
self.rg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in reset_gate_weights and reset_gate_weights[
"b"] is not None:
                self.rg_b = reset_gate_weights["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_reset_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.rg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# cell parameters
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = gate_param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = gate_param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if candidate_bias_attr is not None and candidate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = gate_bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
    def forward(self, input, state):
if self.use_customized_weight:
rg_weights = layers.concat([self.rg_w, self.rg_h], axis=0)
ug_weights = layers.concat([self.ug_w, self.ug_h], axis=0)
_gate_weight = layers.concat([rg_weights, ug_weights], axis=-1)
_candidate_weight = layers.concat([self.c_w, self.c_h], axis=0)
_gate_bias = layers.concat([self.rg_b, self.ug_b], axis=0)
_candidate_bias = self.c_b
else:
_gate_weight = self._gate_weight
_gate_bias = self._gate_bias
_candidate_weight = self._candidate_weight
_candidate_bias = self._candidate_bias
        pre_hidden = state
        concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
        gate_input = layers.matmul(x=concat_input_hidden, y=_gate_weight)
        gate_input = layers.elementwise_add(gate_input, _gate_bias)
        gate_input = self._gate_activation(gate_input)
        r, u = layers.split(gate_input, num_or_sections=2, dim=1)
@@ -371,8 +756,8 @@ class BasicGRUCell(RNNCell):
        r_hidden = r * pre_hidden
        candidate = layers.matmul(
            layers.concat([input, r_hidden], 1), _candidate_weight)
        candidate = layers.elementwise_add(candidate, _candidate_bias)
        c = self._activation(candidate)
        new_hidden = u * pre_hidden + (1 - u) * c
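
For reference, the customized-weight branch of forward() above rebuilds the fused GRU gate weight by stacking each gate's input and hidden matrices and then concatenating the reset and update gates along the last axis. A small numpy sketch of the resulting shapes (the sizes are arbitrary):

import numpy as np

input_size, hidden_size = 32, 64
rg_w = np.zeros([input_size, hidden_size], dtype='float32')
rg_h = np.zeros([hidden_size, hidden_size], dtype='float32')
ug_w = np.zeros([input_size, hidden_size], dtype='float32')
ug_h = np.zeros([hidden_size, hidden_size], dtype='float32')

rg_weights = np.concatenate([rg_w, rg_h], axis=0)   # [input + hidden, hidden]
ug_weights = np.concatenate([ug_w, ug_h], axis=0)   # [input + hidden, hidden]
gate_weight = np.concatenate([rg_weights, ug_weights], axis=-1)

# matches the fused _gate_weight created in the non-customized branch
assert gate_weight.shape == (input_size + hidden_size, 2 * hidden_size)
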
@@ -700,7 +1085,11 @@ class PrePostProcessLayer(Layer):
    PrePostProcessLayer
    """

    def __init__(self,
process_cmd,
d_model,
dropout_rate,
reused_layer_norm=None):
        super(PrePostProcessLayer, self).__init__()
        self.process_cmd = process_cmd
        self.functors = []
@@ -708,16 +1097,21 @@ class PrePostProcessLayer(Layer):
            if cmd == "a":  # add residual connection
                self.functors.append(lambda x, y: x + y if y else x)
            elif cmd == "n":  # add layer normalization
if reused_layer_norm is not None:
layer_norm = reused_layer_norm
else:
layer_norm = LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
                self.functors.append(
                    self.add_sublayer(
                        "layer_norm_%d" % len(
                            self.sublayers(include_sublayers=False)),
                        layer_norm))
elif cmd == "d": # add dropout elif cmd == "d": # add dropout
self.functors.append(lambda x: layers.dropout( self.functors.append(lambda x: layers.dropout(
x, dropout_prob=dropout_rate, is_test=False) x, dropout_prob=dropout_rate, is_test=False)
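
With the new reused_layer_norm argument, several process layers can point at a single LayerNorm instead of each creating its own. A minimal construction sketch; d_model=512 and the 0.1 dropout rate are arbitrary example values:

import paddle.fluid as fluid
from paddle.fluid.dygraph import LayerNorm

with fluid.dygraph.guard():
    shared_ln = LayerNorm(
        normalized_shape=512,
        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(1.)),
        bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(0.)))
    # both layers normalize with the same scale and bias parameters
    pre_norm_a = PrePostProcessLayer("n", 512, 0.1, reused_layer_norm=shared_ln)
    pre_norm_b = PrePostProcessLayer("n", 512, 0.1, reused_layer_norm=shared_ln)
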
@@ -737,21 +1131,48 @@ class MultiHeadAttention(Layer):
    Multi-Head Attention
    """

    def __init__(self,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.0,
reused_query_fc=None,
reused_key_fc=None,
reused_value_fc=None,
reused_proj_fc=None):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.dropout_rate = dropout_rate

        if reused_query_fc is not None:
            self.q_fc = reused_query_fc
        else:
            self.q_fc = Linear(
                input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
        if reused_key_fc is not None:
            self.k_fc = reused_key_fc
else:
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
if reused_value_fc is not None:
self.v_fc = reused_value_fc
else:
self.v_fc = Linear(
input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
if reused_proj_fc is not None:
self.proj_fc = reused_proj_fc
else:
self.proj_fc = Linear(
input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
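
The reused_*_fc arguments allow the query/key/value/output projections of one attention module to be tied to another, e.g. for cross-layer or cross-model weight sharing. A construction sketch with arbitrary sizes:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    base_attn = MultiHeadAttention(d_key=64, d_value=64, d_model=512, n_head=8)
    tied_attn = MultiHeadAttention(
        d_key=64,
        d_value=64,
        d_model=512,
        n_head=8,
        reused_query_fc=base_attn.q_fc,
        reused_key_fc=base_attn.k_fc,
        reused_value_fc=base_attn.v_fc,
        reused_proj_fc=base_attn.proj_fc)
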
    def _prepare_qkv(self, queries, keys, values, cache=None):
        if keys is None:  # self-attention
@@ -828,12 +1249,24 @@ class FFN(Layer):
    Feed-Forward Network
    """

    def __init__(self,
d_inner_hid,
d_model,
dropout_rate,
fc1_act="relu",
reused_fc1=None,
reused_fc2=None):
        super(FFN, self).__init__()
        self.dropout_rate = dropout_rate
        if reused_fc1 is not None:
            self.fc1 = reused_fc1
        else:
self.fc1 = Linear(
input_dim=d_model, output_dim=d_inner_hid, act=fc1_act)
if reused_fc2 is not None:
self.fc2 = reused_fc2
else:
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
    def forward(self, x):
        hidden = self.fc1(x)
@@ -859,22 +1292,52 @@ class TransformerEncoderLayer(Layer):
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd="n",
                 postprocess_cmd="da",
ffn_fc1_act="relu",
reused_pre_selatt_layernorm=None,
reused_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
        super(TransformerEncoderLayer, self).__init__()

        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout,
                                                 reused_pre_selatt_layernorm)
        self.self_attn = MultiHeadAttention(
            d_key,
            d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_multihead_att_weights["reused_query_fc"],
reused_key_fc=reused_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_multihead_att_weights["reused_value_fc"],
reused_proj_fc=reused_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout,
                                                 reused_pre_ffn_layernorm)
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
fc1_act=ffn_fc1_act,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
                                                  prepostprocess_dropout,
reused_post_ffn_layernorm)
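
Taken together, the reused_* arguments let one encoder layer borrow another layer's attention projections and FFN, which is one way to get cross-layer parameter sharing. A construction sketch with arbitrary hyperparameters:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    layer0 = TransformerEncoderLayer(8, 64, 64, 512, 2048, 0.1, 0.1, 0.1)
    layer1 = TransformerEncoderLayer(
        8, 64, 64, 512, 2048, 0.1, 0.1, 0.1,
        reused_multihead_att_weights={
            "reused_query_fc": layer0.self_attn.q_fc,
            "reused_key_fc": layer0.self_attn.k_fc,
            "reused_value_fc": layer0.self_attn.v_fc,
            "reused_proj_fc": layer0.self_attn.proj_fc
        },
        reused_ffn_weights={
            "reused_fc1": layer0.ffn.fc1,
            "reused_fc2": layer0.ffn.fc2
        })
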
    def forward(self, enc_input, attn_bias):
        attn_output = self.self_attn(
@@ -902,7 +1365,8 @@ class TransformerEncoder(Layer):
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd="n",
                 postprocess_cmd="da",
ffn_fc1_act="relu"):
        super(TransformerEncoder, self).__init__()
@@ -912,9 +1376,17 @@ class TransformerEncoder(Layer):
                self.add_sublayer(
                    "layer_%d" % i,
                    TransformerEncoderLayer(
                        n_head,
                        d_key,
                        d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
ffn_fc1_act=ffn_fc1_act)))
        self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
                                             prepostprocess_dropout)
@@ -941,28 +1413,79 @@ class TransformerDecoderLayer(Layer):
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd="n",
                 postprocess_cmd="da",
reused_pre_selfatt_layernorm=None,
reused_self_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_crossatt_layernorm=None,
reused_cross_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_crossatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
        super(TransformerDecoderLayer, self).__init__()

        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout,
                                                 reused_pre_selfatt_layernorm)
        self.self_attn = MultiHeadAttention(
            d_key,
            d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_self_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_self_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_self_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_self_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout,
                                                 reused_pre_crossatt_layernorm)
        self.cross_attn = MultiHeadAttention(
            d_key,
            d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_cross_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_cross_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_cross_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_cross_multihead_att_weights[
"reused_proj_fc"])
self.postprocesser2 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_crossatt_layernorm)
        self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout,
                                                 reused_pre_ffn_layernorm)
        self.ffn = FFN(d_inner_hid,
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
        self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
                                                  prepostprocess_dropout,
reused_post_ffn_layernorm)
    def forward(self,
                dec_input,
@@ -1031,7 +1554,7 @@ class TransformerDecoder(Layer):
        ]


class GRUCell(RNNCell):
    def __init__(self,
                 input_size,
@@ -1044,9 +1567,7 @@ class GRUCell(RNNCell):
        super(GRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.fc_layer = Linear(
            input_size, hidden_size * 3, param_attr=param_attr)
        self.gru_unit = GRUUnit(
            hidden_size * 3,
@@ -1067,7 +1588,8 @@ class GRUCell(RNNCell):
        return [self.hidden_size]

#TODO: we should merge GRUCell with BasicGRUCell
class GRUEncoderCell(RNNCell):
    def __init__(self,
                 num_layers,
                 input_size,
@@ -1086,8 +1608,9 @@ class GRUEncoderCell(RNNCell):
                    GRUCell(
                        input_size=input_size if i == 0 else hidden_size,
                        hidden_size=hidden_size,
                        param_attr=fluid.ParamAttr(
                            initializer=fluid.initializer.UniformInitializer(
                                low=-init_scale, high=init_scale)))))
    def forward(self, step_input, states):
        new_states = []
@@ -1109,18 +1632,17 @@ class GRUEncoderCell(RNNCell):
class BiGRU(fluid.dygraph.Layer):
    def __init__(self, input_dim, grnn_hidden_dim, init_bound, h_0=None):
        super(BiGRU, self).__init__()
        self.gru = RNN(GRUEncoderCell(1, input_dim, grnn_hidden_dim, 0.0,
                                      init_bound),
                       is_reverse=False,
                       time_major=False)
        self.gru_r = RNN(GRUEncoderCell(1, input_dim, grnn_hidden_dim, 0.0,
                                        init_bound),
                         is_reverse=True,
                         time_major=False)
    def forward(self, input_feature):
        pre_gru, pre_state = self.gru(input_feature)
        gru_r, r_state = self.gru_r(input_feature)
        bi_merge = fluid.layers.concat(input=[pre_gru, gru_r], axis=-1)
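
A short usage sketch for BiGRU in dygraph mode; the input is batch-major because the cells are wrapped with time_major=False, the sizes are arbitrary, and forward() is assumed to return the concatenated bi_merge features:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    bigru = BiGRU(input_dim=128, grnn_hidden_dim=256, init_bound=0.1)
    feature = fluid.dygraph.to_variable(
        np.random.uniform(size=[4, 10, 128]).astype('float32'))
    merged = bigru(feature)  # expected shape: [4, 10, 2 * 256]
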
@@ -1320,14 +1842,14 @@ class SequenceTagging(fluid.dygraph.Layer):
        emission = self.fc(bigru_output)

        if target is not None:
            crf_cost = self.linear_chain_crf(
                input=emission, label=target, length=lengths)
            avg_cost = fluid.layers.mean(x=crf_cost)
            self.crf_decoding.weight = self.linear_chain_crf.weight
            crf_decode = self.crf_decoding(input=emission, length=lengths)
            return crf_decode, avg_cost, lengths
        else:
            self.linear_chain_crf.weight = self.crf_decoding.weight
            crf_decode = self.crf_decoding(input=emission, length=lengths)
            return crf_decode, lengths