Commit 56e2729c authored by G guosheng

Remove hapi.text apis' reuse parameter args for coverage.

test=develop
Parent 6e962618
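For callers, the change below removes the reused_* and *_weights keyword arguments from the hapi.text constructors; the layers now always create their own parameters. A minimal sketch of what this looks like at the call site, using PrePostProcessLayer with illustrative sizes (values not taken from the diff):

import paddle.fluid as fluid
from hapi.text.text import PrePostProcessLayer

with fluid.dygraph.guard():
    # Before this commit an existing LayerNorm could be shared via a keyword arg:
    #   PrePostProcessLayer("n", 128, 0.1, reused_layer_norm=shared_layer_norm)
    # After this commit the layer always builds its own LayerNorm:
    process = PrePostProcessLayer("n", 128, 0.1)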
@@ -16,8 +16,9 @@ from paddle.fluid.dygraph.nn import Linear, Embedding
from paddle.fluid.dygraph.base import to_variable
import numpy as np
from hapi.model import Model
from hapi.text.text import GRUEncoder as BiGRUEncoder
from hapi.text.test import BOWEncoder, CNNEncoder, GRUEncoder
from hapi.text.text import _GRUEncoder as GRUEncoder
from hapi.text.text import _GRUEncoder as BiGRUEncoder
from hapi.text.test import BOWEncoder, CNNEncoder
class CNN(Model):
......
@@ -28,47 +28,6 @@ from hapi.model import Model, Input, set_device
from hapi.text.text import *
def sigmoid(x):
return 1. / (1. + np.exp(-x))
def tanh(x):
return 2. * sigmoid(2. * x) - 1.
def lstm_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
i, j, f, o = np.split(gate_input, indices_or_sections=4, axis=1)
new_cell = pre_cell * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
new_hidden = tanh(new_cell) * sigmoid(o)
return new_hidden, new_cell
def gru_step(step_in, pre_hidden, gate_w, gate_b, candidate_w, candidate_b):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
gate_input = sigmoid(gate_input)
r, u = np.split(gate_input, indices_or_sections=2, axis=1)
r_hidden = r * pre_hidden
candidate = np.matmul(np.concatenate([step_in, r_hidden], 1), candidate_w)
candidate += candidate_b
c = tanh(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
class ModuleApiTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
......
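The numpy helpers above (lstm_step, gru_step) are single-step reference implementations the tests use to check the fluid cells. A quick hedged sanity check of gru_step with arbitrary sizes, assuming the helpers above are in scope:

import numpy as np

batch, in_size, hidden = 4, 16, 32
step_in = np.random.rand(batch, in_size).astype("float32")
pre_hidden = np.random.rand(batch, hidden).astype("float32")
# reset and update gates are concatenated, hence 2 * hidden columns
gate_w = np.random.rand(in_size + hidden, 2 * hidden).astype("float32")
gate_b = np.zeros(2 * hidden, dtype="float32")
candidate_w = np.random.rand(in_size + hidden, hidden).astype("float32")
candidate_b = np.zeros(hidden, dtype="float32")

new_hidden = gru_step(step_in, pre_hidden, gate_w, gate_b, candidate_w, candidate_b)
assert new_hidden.shape == (batch, hidden)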
@@ -37,9 +37,6 @@ from hapi.text.text import TransformerDecoder as TransformerDecoder
from hapi.text.text import TransformerCell as TransformerCell
from hapi.text.text import TransformerBeamSearchDecoder as TransformerBeamSearchDecoder
from hapi.text.text import GRUCell as GRUCell
from hapi.text.text import GRUEncoderCell as GRUEncoderCell
from hapi.text.text import BiGRU as BiGRU
from hapi.text.text import LinearChainCRF as LinearChainCRF
from hapi.text.text import CRFDecoding as CRFDecoding
from hapi.text.text import SequenceTagging as SequenceTagging
@@ -16,33 +16,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import collections
import six
import sys
if six.PY2:
reload(sys)
sys.setdefaultencoding('utf8')
from functools import partial, reduce
import ast
import time
import argparse as argparse
import numpy as np
import multiprocessing
import collections
import copy
from functools import partial, reduce
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.utils as utils
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.dygraph import Embedding, Linear, LayerNorm, GRUUnit, Conv2D, Pool2D
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
from paddle.fluid.layers import BeamSearchDecoder
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.dygraph import Layer, Embedding, Linear, LayerNorm, GRUUnit, Conv2D, Pool2D
from paddle.fluid.data_feeder import convert_dtype
__all__ = [
'RNNCell',
@@ -72,7 +61,6 @@ __all__ = [
'LinearChainCRF',
'CRFDecoding',
'SequenceTagging',
'GRUEncoder',
]
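Since 'GRUEncoder' leaves __all__ and the class is renamed to _GRUEncoder further down, external code that still needs it has to alias the private name, as the test file at the top of this diff now does. A minimal sketch of the import change:

# public import before this commit:
# from hapi.text.text import GRUEncoder
# after this commit only the private class remains:
from hapi.text.text import _GRUEncoder as GRUEncoder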
@@ -234,25 +222,6 @@ class BasicLSTMCell(RNNCell):
forget_bias(float, optional): forget bias used when computing the forget gate.
Default 1.0
dtype(string, optional): The data type used in this cell. Default float32.
forget_gate_weights (dict, optional): A dict includes `w`, `h` and `b`
as keys, and the corresponding values should be instances of Parameter
which represent :math:`W_{x_{f}}, W_{h_{f}}, b_{f}` and have shape
[input_size, hidden_size], [hidden_size, hidden_size], [hidden_size]
separately. It is used for reusing and sharing weights when provided,
otherwise create these parameters. Note that parameters from input
gate, forget gate and cell would be concatenated in implementation.
input_gate_weights (dict, optional): A dict includes `w`, `h` and `b` as keys,
and the corresponding values should be instances of Parameter which
represent :math:`W_{x_{i}}, W_{h_{i}}, b_{i}` separately. It has the
same usage as :attr:`forget_gate_weights`.
output_gate_weights (dict, optional): A dict includes `w`, `h` and `b` as keys,
and the corresponding values should be instances of Parameter which
represent :math:`W_{x_{o}}, W_{h_{o}}, b_{o}` separately. It has the
same usage as :attr:`forget_gate_weights`.
cell_weights (dict, optional): A dict includes `w`, `h` and `b` as keys,
and the corresponding values should be instances of Parameter which
represent :math:`W_{x_{c}}, W_{h_{c}}, b_{c}` separately. It has the
same usage as :attr:`forget_gate_weights`.
"""
def __init__(self,
@@ -263,19 +232,7 @@ class BasicLSTMCell(RNNCell):
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32',
forget_gate_weights={"w": None,
"h": None,
"b": None},
input_gate_weights={"w": None,
"h": None,
"b": None},
output_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
dtype='float32'):
super(BasicLSTMCell, self).__init__()
self._hidden_size = hidden_size
@@ -290,225 +247,43 @@ class BasicLSTMCell(RNNCell):
self._dtype = dtype
self._input_size = input_size
self.use_customized_weight = False
for _weights in [
forget_gate_weights, input_gate_weights, output_gate_weights,
cell_weights
]:
for _key in _weights:
if _weights[_key] is not None:
self.use_customized_weight = True
break
if self.use_customized_weight:
break
if not self.use_customized_weight:
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[
self._input_size + self._hidden_size, 4 * self._hidden_size
],
dtype=self._dtype)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
else:
if "w" in forget_gate_weights and forget_gate_weights[
"w"] is not None:
self.fg_w = forget_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_w"
else:
tmp_param_attr = self._param_attr
self.fg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in forget_gate_weights and forget_gate_weights[
"h"] is not None:
self.fg_h = forget_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_forget_gate_h"
else:
tmp_param_attr = self._param_attr
self.fg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in forget_gate_weights and forget_gate_weights[
"b"] is not None:
self.fg_b = forget_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_forget_gate_b"
else:
tmp_param_attr = self._bias_attr
self.fg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in input_gate_weights and input_gate_weights[
"w"] is not None:
self.ig_w = input_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_w"
else:
tmp_param_attr = self._param_attr
self.ig_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in input_gate_weights and input_gate_weights[
"h"] is not None:
self.ig_h = input_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_input_gate_h"
else:
tmp_param_attr = self._param_attr
self.ig_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in input_gate_weights and input_gate_weights[
"b"] is not None:
self.ig_b = input_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_input_gate_b"
else:
tmp_param_attr = self._bias_attr
self.ig_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in output_gate_weights and output_gate_weights[
"w"] is not None:
self.og_w = output_gate_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_w"
else:
tmp_param_attr = self._param_attr
self.og_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in output_gate_weights and output_gate_weights[
"h"] is not None:
self.og_h = output_gate_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_output_gate_h"
else:
tmp_param_attr = self._param_attr
self.og_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in output_gate_weights and output_gate_weights[
"b"] is not None:
self.og_b = output_gate_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_output_gate_b"
else:
tmp_param_attr = self._bias_attr
self.og_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = self._param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if self._param_attr is not None and self._param_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = self._param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if self._bias_attr is not None and self._bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(self._bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = self._bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
if self.use_customized_weight:
weight_w = fluid.layers.concat(
[self.ig_w, self.c_w, self.fg_w, self.og_w], axis=-1)
weight_h = fluid.layers.concat(
[self.ig_h, self.c_h, self.fg_h, self.og_h], axis=-1)
_weight = fluid.layers.concat([weight_w, weight_h], axis=0)
_bias = fluid.layers.concat(
[self.ig_b, self.c_b, self.fg_b, self.og_b])
else:
_weight = self._weight
_bias = self._bias
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[
self._input_size + self._hidden_size, 4 * self._hidden_size
],
dtype=self._dtype)
pre_hidden, pre_cell = state
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=_weight)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
gate_input = layers.elementwise_add(gate_input, _bias)
def forward(self, inputs, states):
"""
Performs single step LSTM calculations.
Parameters:
inputs (Variable): A tensor with shape `[batch_size, input_size]`,
corresponding to :math:`x_t` in the formula. The data type
should be float32 or float64.
states (Variable): A list containing two tensors, each shaped
`[batch_size, hidden_size]`, corresponding to :math:`h_{t-1}, c_{t-1}`
in the formula. The data type should be float32 or float64.
Returns:
tuple: A tuple( :code:`(outputs, new_states)` ), where `outputs` is \
a tensor with shape `[batch_size, hidden_size]`, corresponding \
to :math:`h_{t}` in the formula; `new_states` is a list containing \
two tensor variables shaped `[batch_size, hidden_size]`, corresponding \
to :math:`h_{t}, c_{t}` in the formula. The data type of these \
tensors is the same as that of `states`.
"""
pre_hidden, pre_cell = states
concat_input_hidden = layers.concat([inputs, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
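With the customized-weight branch gone, BasicLSTMCell always owns a single concatenated weight and bias. A minimal dygraph sketch of one step, with keyword names and shapes assumed rather than taken from the diff:

import numpy as np
import paddle.fluid as fluid
from hapi.text.text import BasicLSTMCell

with fluid.dygraph.guard():
    cell = BasicLSTMCell(input_size=16, hidden_size=32)
    step_in = fluid.dygraph.to_variable(np.random.rand(4, 16).astype("float32"))
    init_h = fluid.dygraph.to_variable(np.zeros((4, 32), dtype="float32"))
    init_c = fluid.dygraph.to_variable(np.zeros((4, 32), dtype="float32"))
    # returns (outputs, [new_hidden, new_cell]) per the forward docstring
    out, (new_h, new_c) = cell(step_in, [init_h, init_c])  # out: [4, 32]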
@@ -564,21 +339,6 @@ class BasicGRUCell(RNNCell):
GRU, that is :math:`act_c` in the formula. Default: None,
representing 'fluid.layers.tanh'.
dtype(string, optional): The data type used in this cell. Default float32.
update_gate_weights (dict, optional): A dict includes `w`, `h` and `b`
as keys, and the corresponding values should be instances of Parameter
which represent :math:`W_{ux}, W_{uh}, b_{u}` and have shape
[input_size, hidden_size], [hidden_size, hidden_size], [hidden_size]
separately. It is used for reusing and sharing weights when provided,
otherwise create these parameters. Note that parameters from update
gate and reset gate would be concatenated in implementation.
reset_gate_weights (dict, optional): A dict includes `w`, `h` and `b` as keys,
and the corresponding values should be instances of Parameter which
represent :math:`W_{rx}, W_{rh}, b_{r}` separately. It has the
same usage as :attr:`update_gate_weights`.
cell_weights (dict, optional): A dict includes `w`, `h` and `b` as keys,
and the corresponding values should be instances of Parameter which
represent :math:`W_{cx}, W_{ch}, b_{c}`` separately. It has the
same usage as :attr:`update_gate_weights`.
"""
def __init__(self,
@@ -588,16 +348,7 @@ class BasicGRUCell(RNNCell):
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32',
update_gate_weights={"w": None,
"h": None,
"b": None},
reset_gate_weights={"w": None,
"h": None,
"b": None},
cell_weights={"w": None,
"h": None,
"b": None}):
dtype='float32'):
super(BasicGRUCell, self).__init__()
self._input_size = input_size
self._hidden_size = hidden_size
@@ -607,20 +358,6 @@ class BasicGRUCell(RNNCell):
self._activation = activation or layers.tanh
self._dtype = dtype
assert isinstance(update_gate_weights, dict)
assert isinstance(reset_gate_weights, dict)
assert isinstance(cell_weights, dict)
self.use_customized_weight = False
for _weights in [
update_gate_weights, reset_gate_weights, cell_weights
]:
for _key in _weights:
if _weights[_key] is not None:
self.use_customized_weight = True
if self.use_customized_weight:
break
if self._param_attr is not None and self._param_attr.name is not None:
gate_param_attr = copy.deepcopy(self._param_attr)
candidate_param_attr = copy.deepcopy(self._param_attr)
@@ -630,194 +367,62 @@ class BasicGRUCell(RNNCell):
gate_param_attr = self._param_attr
candidate_param_attr = self._param_attr
if not self.use_customized_weight:
self._gate_weight = self.create_parameter(
attr=gate_param_attr,
shape=[
self._input_size + self._hidden_size, 2 * self._hidden_size
],
dtype=self._dtype)
self._candidate_weight = self.create_parameter(
attr=candidate_param_attr,
shape=[
self._input_size + self._hidden_size, self._hidden_size
],
dtype=self._dtype)
if self._bias_attr is not None and self._bias_attr.name is not None:
gate_bias_attr = copy.deepcopy(self._bias_attr)
candidate_bias_attr = copy.deepcopy(self._bias_attr)
gate_bias_attr.name += "_gate"
candidate_bias_attr.name += "_candidate"
else:
gate_bias_attr = self._bias_attr
candidate_bias_attr = self._bias_attr
self._gate_bias = self.create_parameter(
attr=gate_bias_attr,
shape=[2 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
self._candidate_bias = self.create_parameter(
attr=candidate_bias_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
self._gate_weight = self.create_parameter(
attr=gate_param_attr,
shape=[
self._input_size + self._hidden_size, 2 * self._hidden_size
],
dtype=self._dtype)
else:
self._candidate_weight = self.create_parameter(
attr=candidate_param_attr,
shape=[self._input_size + self._hidden_size, self._hidden_size],
dtype=self._dtype)
# create the parameters of gates in gru
if "w" in update_gate_weights and update_gate_weights[
"w"] is not None:
self.ug_w = update_gate_weights["w"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_update_gate_w"
else:
tmp_param_attr = gate_param_attr
self.ug_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in update_gate_weights and update_gate_weights[
"h"] is not None:
self.ug_h = update_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_update_gate_h"
else:
tmp_param_attr = gate_param_attr
self.ug_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in update_gate_weights and update_gate_weights[
"b"] is not None:
self.ug_b = update_gate_weights["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_update_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.ug_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# reset gate parameters
if "w" in reset_gate_weights and reset_gate_weights[
"w"] is not None:
self.rg_w = reset_gate_weights["w"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_w"
else:
tmp_param_attr = gate_param_attr
self.rg_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in reset_gate_weights and reset_gate_weights[
"h"] is not None:
self.rg_h = reset_gate_weights["h"]
else:
if gate_param_attr is not None and gate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_param_attr)
tmp_param_attr.name += "_reset_gate_h"
else:
tmp_param_attr = gate_param_attr
self.rg_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in reset_gate_weights and reset_gate_weights[
"b"] is not None:
self.rg_b = reset_gate_weights["b"]
else:
if gate_bias_attr is not None and gate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(gate_bias_attr)
tmp_param_attr.name += "_reset_gate_b"
else:
tmp_param_attr = gate_bias_attr
self.rg_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
# cell parameters
if "w" in cell_weights and cell_weights["w"] is not None:
self.c_w = cell_weights["w"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_w"
else:
tmp_param_attr = gate_param_attr
self.c_w = self.create_parameter(
attr=tmp_param_attr,
shape=[self._input_size, self._hidden_size],
dtype=self._dtype)
if "h" in cell_weights and cell_weights["h"] is not None:
self.c_h = cell_weights["h"]
else:
if candidate_param_attr is not None and candidate_param_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_param_attr)
tmp_param_attr.name += "_cell_h"
else:
tmp_param_attr = gate_param_attr
self.c_h = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size, self._hidden_size],
dtype=self._dtype)
if "b" in cell_weights and cell_weights["b"] is not None:
self.c_b = cell_weights["b"]
else:
if candidate_bias_attr is not None and candidate_bias_attr.name is not None:
tmp_param_attr = copy.deepcopy(candidate_bias_attr)
tmp_param_attr.name += "_cell_b"
else:
tmp_param_attr = gate_bias_attr
self.c_b = self.create_parameter(
attr=tmp_param_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
if self.use_customized_weight:
rg_weights = layers.concat([self.rg_w, self.rg_h], axis=0)
ug_weights = layers.concat([self.ug_w, self.ug_h], axis=0)
_gate_weight = layers.concat([rg_weights, ug_weights], axis=-1)
_candidate_weight = layers.concat([self.c_w, self.c_h], axis=0)
_gate_bias = layers.concat([self.rg_b, self.ug_b], axis=0)
_candidate_bias = self.c_b
if self._bias_attr is not None and self._bias_attr.name is not None:
gate_bias_attr = copy.deepcopy(self._bias_attr)
candidate_bias_attr = copy.deepcopy(self._bias_attr)
gate_bias_attr.name += "_gate"
candidate_bias_attr.name += "_candidate"
else:
_gate_weight = self._gate_weight
_gate_bias = self._gate_bias
_candidate_weight = self._candidate_weight
_candidate_bias = self._candidate_bias
gate_bias_attr = self._bias_attr
candidate_bias_attr = self._bias_attr
self._gate_bias = self.create_parameter(
attr=gate_bias_attr,
shape=[2 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
self._candidate_bias = self.create_parameter(
attr=candidate_bias_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
pre_hidden = state
concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
def forward(self, inputs, states):
"""
Performs single step GRU calculations.
gate_input = layers.matmul(x=concat_input_hidden, y=_gate_weight)
Parameters:
inputs (Variable): A tensor with shape `[batch_size, input_size]`,
corresponding to :math:`x_t` in the formula. The data type
should be float32 or float64.
states (Variable): A tensor with shape `[batch_size, hidden_size]`,
corresponding to :math:`h_{t-1}` in the formula. The data type
should be float32 or float64.
gate_input = layers.elementwise_add(gate_input, _gate_bias)
Returns:
tuple: A tuple( :code:`(outputs, new_states)` ), where `outputs` and \
`new_states` are the same tensor shaped `[batch_size, hidden_size]`, \
corresponding to :math:`h_t` in the formula. The data type of the \
tensor is the same as that of `states`.
"""
pre_hidden = states
concat_input_hidden = layers.concat([inputs, pre_hidden], axis=1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = layers.elementwise_add(gate_input, self._gate_bias)
gate_input = self._gate_activation(gate_input)
r, u = layers.split(gate_input, num_or_sections=2, dim=1)
@@ -825,8 +430,8 @@ class BasicGRUCell(RNNCell):
r_hidden = r * pre_hidden
candidate = layers.matmul(
layers.concat([input, r_hidden], 1), _candidate_weight)
candidate = layers.elementwise_add(candidate, _candidate_bias)
layers.concat([inputs, r_hidden], 1), self._candidate_weight)
candidate = layers.elementwise_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
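BasicGRUCell gets the same simplification: it always builds the concatenated gate and candidate parameters. A hedged single-step sketch with illustrative shapes:

import numpy as np
import paddle.fluid as fluid
from hapi.text.text import BasicGRUCell

with fluid.dygraph.guard():
    cell = BasicGRUCell(input_size=16, hidden_size=32)
    step_in = fluid.dygraph.to_variable(np.random.rand(4, 16).astype("float32"))
    init_h = fluid.dygraph.to_variable(np.zeros((4, 32), dtype="float32"))
    # outputs and new_states are the same hidden tensor per the forward docstring
    out, new_h = cell(step_in, init_h)  # both [4, 32]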
@@ -2650,6 +2255,7 @@ class TransformerCell(Layer):
class Embedder(fluid.dygraph.Layer):
def __init__(self):
super(Embedder, self).__init__()
self.word_embedder = Embedding(size=[1000, 128])
self.pos_embedder = Embedding(size=[500, 128])
@@ -2999,11 +2605,7 @@ class PrePostProcessLayer(Layer):
out = process(x) # [2, 4, 32]
"""
def __init__(self,
process_cmd,
d_model,
dropout_rate=0.1,
reused_layer_norm=None):
def __init__(self, process_cmd, d_model, dropout_rate=0.1):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
self.functors = []
@@ -3012,15 +2614,12 @@ class PrePostProcessLayer(Layer):
self.functors.append(
lambda x, y: x + y if y is not None else x)
elif cmd == "n": # add layer normalization
if reused_layer_norm is not None:
layer_norm = reused_layer_norm
else:
layer_norm = LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
layer_norm = LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
self.functors.append(
self.add_sublayer(
@@ -3091,16 +2690,7 @@ class MultiHeadAttention(Layer):
output = multi_head_attn(query, attn_bias=attn_bias) # [2, 4, 128]
"""
def __init__(self,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.0,
reused_query_fc=None,
reused_key_fc=None,
reused_value_fc=None,
reused_proj_fc=None):
def __init__(self, d_key, d_value, d_model, n_head, dropout_rate=0.1):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
@@ -3109,30 +2699,14 @@ class MultiHeadAttention(Layer):
self.d_model = d_model
self.dropout_rate = dropout_rate
if reused_query_fc is not None:
self.q_fc = reused_query_fc
else:
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
if reused_key_fc is not None:
self.k_fc = reused_key_fc
else:
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
if reused_value_fc is not None:
self.v_fc = reused_value_fc
else:
self.v_fc = Linear(
input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
if reused_proj_fc is not None:
self.proj_fc = reused_proj_fc
else:
self.proj_fc = Linear(
input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.v_fc = Linear(
input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
self.proj_fc = Linear(
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
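MultiHeadAttention now always creates its own query/key/value and output projections. A hedged usage sketch mirroring the docstring example above, with a [2, 4, 128] query used for self-attention (the attn_bias shape is assumed, not taken from the diff):

import numpy as np
import paddle.fluid as fluid
from hapi.text.text import MultiHeadAttention

with fluid.dygraph.guard():
    multi_head_attn = MultiHeadAttention(d_key=64, d_value=64, d_model=128, n_head=2)
    query = fluid.dygraph.to_variable(np.random.rand(2, 4, 128).astype("float32"))
    # zero attention bias, one entry per head and query/key position (shape assumed)
    attn_bias = fluid.dygraph.to_variable(np.zeros((2, 2, 4, 4), dtype="float32"))
    output = multi_head_attn(query, attn_bias=attn_bias)  # [2, 4, 128]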
def _prepare_qkv(self, queries, keys, values, cache=None):
"""
@@ -3322,24 +2896,12 @@ class FFN(Layer):
out = ffn(x) # [2, 4, 32]
"""
def __init__(self,
d_inner_hid,
d_model,
dropout_rate=0.1,
fc1_act="relu",
reused_fc1=None,
reused_fc2=None):
def __init__(self, d_inner_hid, d_model, dropout_rate=0.1, fc1_act="relu"):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
if reused_fc1 is not None:
self.fc1 = reused_fc1
else:
self.fc1 = Linear(
input_dim=d_model, output_dim=d_inner_hid, act=fc1_act)
if reused_fc2 is not None:
self.fc2 = reused_fc2
else:
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
self.fc1 = Linear(
input_dim=d_model, output_dim=d_inner_hid, act=fc1_act)
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
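FFN likewise always creates fc1 and fc2 itself. A hedged sketch matching the docstring example above, with a [2, 4, 32] input and illustrative sizes:

import numpy as np
import paddle.fluid as fluid
from hapi.text.text import FFN

with fluid.dygraph.guard():
    ffn = FFN(d_inner_hid=128, d_model=32, dropout_rate=0.1)
    x = fluid.dygraph.to_variable(np.random.rand(2, 4, 32).astype("float32"))
    out = ffn(x)  # [2, 4, 32]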
def forward(self, x):
"""
@@ -3422,51 +2984,22 @@ class TransformerEncoderLayer(Layer):
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
ffn_fc1_act="relu",
reused_pre_selatt_layernorm=None,
reused_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
ffn_fc1_act="relu"):
super(TransformerEncoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_selatt_layernorm)
self.self_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_multihead_att_weights["reused_query_fc"],
reused_key_fc=reused_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_multihead_att_weights["reused_value_fc"],
reused_proj_fc=reused_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_ffn_layernorm)
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
fc1_act=ffn_fc1_act,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout, fc1_act=ffn_fc1_act)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout,
reused_post_ffn_layernorm)
prepostprocess_dropout)
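TransformerEncoderLayer is now built purely from its size and dropout arguments. A hedged construction-and-call sketch; the leading argument names are assumed from how they are forwarded to MultiHeadAttention and FFN above, and all values are illustrative:

import numpy as np
import paddle.fluid as fluid
from hapi.text.text import TransformerEncoderLayer

with fluid.dygraph.guard():
    encoder_layer = TransformerEncoderLayer(
        n_head=2, d_key=64, d_value=64, d_model=128, d_inner_hid=256)
    enc_input = fluid.dygraph.to_variable(np.random.rand(2, 4, 128).astype("float32"))
    # zero self-attention bias (shape assumed)
    attn_bias = fluid.dygraph.to_variable(np.zeros((2, 2, 4, 4), dtype="float32"))
    enc_output = encoder_layer(enc_input, attn_bias)  # [2, 4, 128]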
def forward(self, enc_input, attn_bias=None):
"""
@@ -3667,83 +3200,33 @@ class TransformerDecoderLayer(Layer):
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
reused_pre_selfatt_layernorm=None,
reused_self_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_selfatt_layernorm=None,
reused_pre_crossatt_layernorm=None,
reused_cross_multihead_att_weights={
"reused_query_fc": None,
"reused_key_fc": None,
"reused_value_fc": None,
"reused_proj_fc": None
},
reused_post_crossatt_layernorm=None,
reused_pre_ffn_layernorm=None,
reused_ffn_weights={"reused_fc1": None,
"reused_fc2": None},
reused_post_ffn_layernorm=None):
ffn_fc1_act="relu"):
super(TransformerDecoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_selfatt_layernorm)
self.self_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_self_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_self_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_self_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_self_multihead_att_weights["reused_proj_fc"])
self.postprocesser1 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_selfatt_layernorm)
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_crossatt_layernorm)
self.cross_attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
attention_dropout,
reused_query_fc=reused_cross_multihead_att_weights[
"reused_query_fc"],
reused_key_fc=reused_cross_multihead_att_weights["reused_key_fc"],
reused_value_fc=reused_cross_multihead_att_weights[
"reused_value_fc"],
reused_proj_fc=reused_cross_multihead_att_weights[
"reused_proj_fc"])
self.postprocesser2 = PrePostProcessLayer(
postprocess_cmd, d_model, prepostprocess_dropout,
reused_post_crossatt_layernorm)
prepostprocess_dropout)
self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout,
reused_pre_ffn_layernorm)
self.ffn = FFN(d_inner_hid,
d_model,
relu_dropout,
reused_fc1=reused_ffn_weights["reused_fc1"],
reused_fc2=reused_ffn_weights["reused_fc2"])
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout, fc1_act=ffn_fc1_act)
self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout,
reused_post_ffn_layernorm)
prepostprocess_dropout)
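TransformerDecoderLayer follows the same pattern, with prepostprocess_dropout, attention_dropout and relu_dropout gaining 0.1 defaults and a new ffn_fc1_act argument. A hedged construction sketch (the required argument names are assumed; the forward call is omitted):

import paddle.fluid as fluid
from hapi.text.text import TransformerDecoderLayer

with fluid.dygraph.guard():
    decoder_layer = TransformerDecoderLayer(
        n_head=2, d_key=64, d_value=64, d_model=128, d_inner_hid=256)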
def forward(self,
dec_input,
@@ -3991,101 +3474,6 @@ class TransformerDecoder(Layer):
} for i in range(self.n_layer)]
#TODO: we should merge GRUCell with BasicGRUCell
class GRUCell(RNNCell):
def __init__(self,
input_size,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False):
super(GRUCell, self).__init__()
self.hidden_size = hidden_size
self.fc_layer = Linear(
input_size, hidden_size * 3, param_attr=param_attr)
self.gru_unit = GRUUnit(
hidden_size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
def forward(self, inputs, states):
# for GRUCell, `step_outputs` and `new_states` both are hidden
x = self.fc_layer(inputs)
hidden, _, _ = self.gru_unit(x, states)
return hidden, hidden
@property
def state_shape(self):
return [self.hidden_size]
#TODO: we should merge GRUCell with BasicGRUCell
class GRUEncoderCell(RNNCell):
def __init__(self,
num_layers,
input_size,
hidden_size,
dropout_prob=0.,
init_scale=0.1):
super(GRUEncoderCell, self).__init__()
self.dropout_prob = dropout_prob
# use add_sublayer to add multi-layers
self.gru_cells = []
for i in range(num_layers):
self.gru_cells.append(
self.add_sublayer(
"gru_%d" % i,
#BasicGRUCell(
GRUCell(
input_size=input_size if i == 0 else hidden_size,
hidden_size=hidden_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))))
def forward(self, step_input, states):
new_states = []
for i, gru_cell in enumerate(self.gru_cells):
out, state = gru_cell(step_input, states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation='upscale_in_train'
) if self.dropout_prob > 0 else out
new_states.append(step_input)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.gru_cells]
class BiGRU(fluid.dygraph.Layer):
def __init__(self, input_dim, grnn_hidden_dim, init_bound, h_0=None):
super(BiGRU, self).__init__()
self.gru = RNN(GRUEncoderCell(1, input_dim, grnn_hidden_dim, 0.0,
init_bound),
is_reverse=False,
time_major=False)
self.gru_r = RNN(GRUEncoderCell(1, input_dim, grnn_hidden_dim, 0.0,
init_bound),
is_reverse=True,
time_major=False)
def forward(self, input_feature):
pre_gru, pre_state = self.gru(input_feature)
gru_r, r_state = self.gru_r(input_feature)
bi_merge = fluid.layers.concat(input=[pre_gru, gru_r], axis=-1)
return bi_merge
class LinearChainCRF(Layer):
"""
Computes the negative log-likelihood of tag sequences in a linear chain CRF.
@@ -4349,7 +3737,7 @@ class CRFDecoding(Layer):
return viterbi_path
class GRUEncoder(Layer):
class _GRUEncoder(Layer):
"""
A multi-layer bidirectional GRU encoder used by SequenceTagging.
"""
@@ -4360,7 +3748,7 @@ class GRUEncoder(Layer):
init_bound,
num_layers=1,
is_bidirection=False):
super(GRUEncoder, self).__init__()
super(_GRUEncoder, self).__init__()
self.num_layers = num_layers
self.is_bidirection = is_bidirection
self.gru_list = []
@@ -4475,7 +3863,7 @@ class SequenceTagging(Layer):
initializer=fluid.initializer.Uniform(
low=-self.init_bound, high=self.init_bound)))
self.gru_encoder = GRUEncoder(
self.gru_encoder = _GRUEncoder(
input_dim=self.grnn_hidden_dim,
grnn_hidden_dim=self.grnn_hidden_dim,
init_bound=self.init_bound,
......