Commit 75dcc161 authored by J jinyuKing

update text.py

Parent a4cb497e
......@@ -37,7 +37,7 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.utils as utils
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.dygraph import to_variable, Embedding, Linear, LayerNorm, GRUUnit, Conv2D
from paddle.fluid.dygraph import to_variable, Embedding, Linear, LayerNorm, GRUUnit, Conv2D, Pool2D
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid import layers
......@@ -49,8 +49,8 @@ __all__ = [
'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
'TransformerDecoder', 'TransformerBeamSearchDecoder', 'Linear_chain_crf',
'Crf_decoding', 'SequenceTagging', 'GRUEncoderLayer', 'SimCNNEncoder',
'SimBOWEncoder', 'SimpleConvPoolLayer', 'SimGRUEncoder', 'DynamicGRU', 'SimLSTMEncoder'
'Crf_decoding', 'SequenceTagging', 'GRUEncoderLayer', 'Conv1dPoolLayer',
'CNNEncoder'
]
......@@ -1898,226 +1898,202 @@ class SequenceTagging(fluid.dygraph.Layer):
crf_decode = self.crf_decoding(input=emission, length=lengths)
return crf_decode, lengths
class SimpleConvPoolLayer(Layer):
class Conv1dPoolLayer(Layer):
"""
This interface is used to construct a callable object of the ``Conv1dPoolLayer`` class. The ``Conv1dPoolLayer`` is composed of a ``Conv2D`` and a ``Pool2D``.
For more details, refer to the code examples. The ``Conv1dPoolLayer`` layer calculates the output based on the input and the filter, strides, paddings, dilations,
groups, global_pooling, pool_type, ceil_mode and exclusive parameters. Input and output are in NCH format, where N is the batch size, C is the number of feature maps,
and H is the height of the feature map. The data type of the input and output is 'float32' or 'float64'.
Args:
input(Variable): 3-D Tensor with shape [N, C, H]; the data type can be float32 or float64.
num_channels(int): The number of channels in the input data.
num_filters(int): The number of filters. It is the same as the output channels.
filter_size (int): The filter size of Conv1dPoolLayer.
pool_size (int): The pooling size of Conv1dPoolLayer.
conv_stride (int): The stride size of the conv layer in Conv1dPoolLayer. Default: 1
pool_stride (int): The stride size of the pool layer in Conv1dPoolLayer. Default: 1
conv_padding (int): The padding size of the conv layer in Conv1dPoolLayer. Default: 0
pool_padding (int): The padding size of the pool layer in Conv1dPoolLayer. Default: 0
pool_type (str): Pooling type, which can be `max` for max-pooling or `avg` for average-pooling. Default: `max`
global_pooling (bool): Whether to use global pooling. If global_pooling = True, pool_size and pool_padding will be ignored. Default: False
dilation (int): The dilation size of the conv layer. Default: 1.
groups (int): The number of groups of the conv layer. According to grouped convolution in Alex Krizhevsky's Deep CNN paper: when groups=2,
the first half of the filters is only connected to the first half of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: 1.
param_attr (ParamAttr|None): The parameter attribute for the learnable parameters/weights of the conv layer. If it is set to None or one attribute of
ParamAttr, the conv layer will create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized
with :math:`Normal(0.0, std)`, where :math:`std` is :math:`(\\frac{2.0}{filter\\_elem\\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of the conv layer. If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, the conv layer will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not
set, the bias is initialized to zero. Default: None.
name(str, optional): Normally there is no need for the user to set this property. Default: None
act (str): Activation type for the conv layer. If it is set to None, no activation is appended. Default: None.
use_cudnn (bool): Whether to use the cuDNN kernel; valid only when the cuDNN library is installed. Default: False
ceil_mode (bool, optional): Whether to use the ceil function to calculate the output height and width.
If it is set to False, the floor function will be used. Default: False.
exclusive (bool, optional): Whether to exclude padding points in average pooling mode. Default: True.
Return:
3-D Tensor, the result of the input after conv and pool, with the same data type as ``input``.
Return Type:
Variable
Example:
```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from hapi.text import Conv1dPoolLayer

test = np.random.uniform(-1, 1, [2, 3, 4]).astype('float32')
with fluid.dygraph.guard():
    paddle_input = to_variable(test)
    print(paddle_input.shape)  # [2, 3, 4]
    conv_pool = Conv1dPoolLayer(3, 4, 2, 2)
    paddle_out = conv_pool(paddle_input)
    print(paddle_out.shape)  # [2, 4, 2]
```
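A further sketch (added for illustration, not part of the commit) showing how the conv and pool hyper-parameters determine the output length; the formulas assume the defaults dilation=1 and ceil_mode=False:
```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from hapi.text import Conv1dPoolLayer

# H_conv = (H + 2*conv_padding - filter_size) // conv_stride + 1
# H_out  = (H_conv + 2*pool_padding - pool_size) // pool_stride + 1
test = np.random.uniform(-1, 1, [2, 3, 8]).astype('float32')
with fluid.dygraph.guard():
    layer = Conv1dPoolLayer(num_channels=3, num_filters=5, filter_size=3,
                            pool_size=2, pool_stride=2)
    out = layer(to_variable(test))
    # H_conv = (8 - 3) // 1 + 1 = 6; H_out = (6 - 2) // 2 + 1 = 3
    print(out.shape)  # [2, 5, 3]
```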
"""
def __init__(self,
num_channels,
num_filters,
filter_size,
pool_size,
conv_stride=1,
pool_stride=1,
conv_padding=0,
pool_padding=0,
pool_type='max',
global_pooling=False,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
act=None,
use_cudnn=False,
ceil_mode=False,
exclusive=True,
):
super(SimpleConvPoolLayer, self).__init__()
super(Conv1dPoolLayer, self).__init__()
self._conv2d = Conv2D(num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
padding=[1, 1],
filter_size=[filter_size,1],
stride=[conv_stride,1],
padding=[conv_padding,0],
dilation=[dilation,1],
groups=groups,
param_attr=param_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act)
def forward(self, input):
x = self._conv2d(input)
x = fluid.layers.reduce_max(x, dim=-1)
x = fluid.layers.reshape(x, shape=[x.shape[0], -1])
self._pool2d = Pool2D(pool_size=[pool_size,1],
pool_type=pool_type,
pool_stride=[pool_stride,1],
pool_padding=[pool_padding,0],
global_pooling=global_pooling,
use_cudnn=use_cudnn,
ceil_mode=ceil_mode,
exclusive=exclusive
)
def forward(self, inputs):
# unsqueeze [N, C, H] to [N, C, H, 1] so Conv2D/Pool2D can process 1-D data
x = fluid.layers.unsqueeze(inputs, axes=[-1])
x = self._conv2d(x)
x = self._pool2d(x)
# squeeze [N, num_filters, H_out, 1] back to [N, num_filters, H_out]
x = fluid.layers.squeeze(x, axes=[-1])
return x
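Because the class above reduces 1-D convolution to ``Conv2D``/``Pool2D`` over a trailing singleton dimension, several instances with different filter sizes can be combined TextCNN-style. The following is a minimal sketch under stated assumptions, not part of hapi.text: ``MultiFilterConv1d`` is a hypothetical helper, and odd filter sizes with ``conv_padding=k // 2`` are assumed so every branch keeps the same output length and can be concatenated along the channel axis.
```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable

class MultiFilterConv1d(fluid.dygraph.Layer):  # hypothetical helper
    def __init__(self, num_channels, num_filters, filter_sizes, pool_size):
        super(MultiFilterConv1d, self).__init__()
        self._branches = []
        for i, k in enumerate(filter_sizes):
            # conv_padding=k // 2 keeps H unchanged for odd k, so all
            # branches produce the same H_out after pooling
            branch = Conv1dPoolLayer(num_channels, num_filters, k, pool_size,
                                     conv_padding=k // 2)
            self.add_sublayer("branch_%d" % i, branch)
            self._branches.append(branch)

    def forward(self, x):
        # each branch: [N, C, H] -> [N, num_filters, H_out]
        return fluid.layers.concat([b(x) for b in self._branches], axis=1)

with fluid.dygraph.guard():
    x = to_variable(np.random.uniform(-1, 1, [2, 3, 9]).astype('float32'))
    model = MultiFilterConv1d(3, 4, [3, 5, 7], pool_size=2)
    print(model(x).shape)  # [2, 12, 8] under these assumptions
```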
class SimCNNEncoder(Layer):
class CNNEncoder(Layer):
"""
simple CNNEncoder for simnet
This interface is used to construct a callable object of the ``CNNEncoder`` class. The ``CNNEncoder`` is composed of an ``Embedding`` and a ``Conv1dPoolLayer``.
For more details, refer to the code examples. The ``CNNEncoder`` layer calculates the output based on the input and the dict_size, emb_size, filter_size, num_filters,
use_cuda, is_sparse and param_attr parameters. The input data is a Tensor or a LoDTensor with data type 'int64'. The output data is in NCH
format, where N is the batch size, C is the number of feature maps, and H is the height of the feature map. The data type of the output is 'float32' or 'float64'.
Args:
dict_size(int): The size of the embedding dictionary.
emb_size(int): The size of each embedding vector.
num_channels(int): The number of channels in the input data.
num_filters(int): The number of filters. It is the same as the output channels.
filter_size(int): The filter size of the Conv1dPoolLayer in CNNEncoder.
pool_size(int): The pooling size of the Conv1dPoolLayer in CNNEncoder.
use_cuda (bool): Whether to use the cuDNN kernel in the conv layer; valid only when the cuDNN library is installed. Default: False
is_sparse(bool): The flag indicating whether to use sparse update. This parameter only affects the performance of the backward gradient update. It is recommended
to set it to True because sparse update is faster. But some optimizers do not support sparse update, such as :ref:`api_fluid_optimizer_AdadeltaOptimizer` ,
:ref:`api_fluid_optimizer_AdamaxOptimizer` , :ref:`api_fluid_optimizer_DecayedAdagradOptimizer` , :ref:`api_fluid_optimizer_FtrlOptimizer` ,
:ref:`api_fluid_optimizer_LambOptimizer` and :ref:`api_fluid_optimizer_LarsMomentumOptimizer` .
In these cases, is_sparse must be False. Default: True.
param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the default weight parameter property is used. See usage for details in
:ref:`api_fluid_ParamAttr` . In addition, user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter. The local word vectors
need to be transformed into numpy format, and the shape of the local word vectors should be consistent with :attr:`size` .
Then :ref:`api_fluid_initializer_NumpyArrayInitializer` is used to load custom or pre-trained word vectors. Default: None.
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted to :math:`vocab\_size + padding\_idx` . It will output all-zero padding
data whenever lookup encounters :math:`padding\_idx` in an id. The padding data will not be updated during training. If set to None, it has no effect on the
output. Default: None.
act (str): Activation type for the ``Conv1dPoolLayer`` layer. If it is set to None, no activation is appended. Default: None.
Return:
3-D Tensor, the result of the input after the embedding and Conv1dPoolLayer.
Return Type:
Variable
Example:
```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from hapi.text import CNNEncoder

test = np.random.uniform(1, 5, [2, 3, 4]).astype('int64')
with fluid.dygraph.guard():
    paddle_input = to_variable(test)
    print(paddle_input.shape)  # [2, 3, 4]
    encoder = CNNEncoder(128, 4, 3, 4, 2, 2)
    paddle_out = encoder(paddle_input)
    print(paddle_out.shape)  # [8, 4, 2]
```
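As a further illustration of the ``param_attr`` note above (not part of the original docstring; all values are assumed), pre-trained word vectors can be plugged in through ``fluid.initializer.NumpyArrayInitializer``:
```python
import numpy as np
import paddle.fluid as fluid
from hapi.text import CNNEncoder

# Assumed pre-trained table; its shape must match [dict_size, emb_size].
pretrained = np.random.uniform(-0.1, 0.1, [128, 4]).astype('float32')
emb_attr = fluid.ParamAttr(
    name='emb',
    initializer=fluid.initializer.NumpyArrayInitializer(pretrained))
with fluid.dygraph.guard():
    encoder = CNNEncoder(128, 4, 3, 4, 2, 2,
                         param_attr=emb_attr, padding_idx=0)
```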
"""
def __init__(self,
dict_size,
emb_dim,
filter_size,
emb_size,
num_channels,
num_filters,
hidden_dim,
seq_len,
padding_idx,
act
filter_size,
pool_size,
use_cuda=False,
is_sparse=True,
param_attr=None,
padding_idx=None,
act=None
):
super(SimCNNEncoder, self).__init__()
super(CNNEncoder, self).__init__()
self.dict_size = dict_size
self.emb_dim = emb_dim
self.emb_size = emb_size
self.filter_size = filter_size
self.num_filters = num_filters
self.hidden_dim = hidden_dim
self.seq_len = seq_len
self.padding_idx = padding_idx
self.act = act
self.channels = 1
self.emb_layer = Embedding(size=[self.dict_size, self.emb_dim],
is_sparse=True,
padding_idx=self.padding_idx,
param_attr=fluid.ParamAttr(name='emb', initializer=fluid.initializer.Xavier()))
self.cnn_layer = SimpleConvPoolLayer(
self.pool_size = pool_size
self.channels = num_channels
self._emb_layer = Embedding(size=[self.dict_size, self.emb_size],
is_sparse=is_sparse,
padding_idx=padding_idx,
param_attr=param_attr)
self._cnn_layer = Conv1dPoolLayer(
self.channels,
self.num_filters,
self.filter_size,
use_cudnn=False,
act=self.act
)
def forward(self, input):
emb = self.emb_layer(input)
emb_reshape = fluid.layers.reshape(
emb, shape=[-1, self.channels, self.seq_len, self.hidden_dim])
emb_out=self.cnn_layer(emb_reshape)
return emb_out
class SimBOWEncoder(Layer):
"""
simple BOWEncoder for simnet
"""
def __init__(self,
dict_size,
emb_dim,
bow_dim,
seq_len,
padding_idx
):
super(SimBOWEncoder, self).__init__()
self.dict_size = dict_size
self.bow_dim = bow_dim
self.seq_len = seq_len
self.emb_dim = emb_dim
self.padding_idx=padding_idx
self.emb_layer = Embedding(size=[self.dict_size, self.emb_dim],
is_sparse=True,
padding_idx=self.padding_idx,
param_attr=fluid.ParamAttr(name='emb', initializer=fluid.initializer.Xavier()))
def forward(self, input):
emb = self.emb_layer(input)
emb_reshape = fluid.layers.reshape(
emb, shape=[-1, self.seq_len, self.bow_dim])
bow_emb = fluid.layers.reduce_sum(emb_reshape, dim=1)
return bow_emb
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
h_0=None,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
init_size=None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[:, i:i + 1, :]
input_ = fluid.layers.reshape(
input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(
hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class SimGRUEncoder(Layer):
"""
simple GRUEncoder for simnet
"""
def __init__(self,
dict_size,
emb_dim,
gru_dim,
hidden_dim,
padding_idx,
seq_len
):
super(SimGRUEncoder, self).__init__()
self.dict_size = dict_size
self.emb_dim = emb_dim
self.gru_dim = gru_dim
self.seq_len=seq_len
self.hidden_dim = hidden_dim
self.padding_idx=self.padding_idx
self.emb_layer = Embedding(size=[self.dict_size, self.emb_dim],
is_sparse=True,
padding_idx=self.padding_idx,
param_attr=fluid.ParamAttr(name='emb',
initializer=fluid.initializer.Xavier()))
self.gru_layer = DynamicGRU(self.gru_dim)
self.proj_layer = Linear(input_dim=self.hidden_dim, output_dim=self.gru_dim * 3)
def forward(self, input):
emb = self.emb_layer(input)
emb_proj = self.proj_layer(emb)
h_0 = np.zeros((emb_proj.shape[0], self.hidden_dim), dtype="float32")
h_0 = to_variable(h_0)
gru = self.gru_layer(emb_proj, h_0=h_0)
gru = fluid.layers.reduce_max(gru, dim=1)
gru = fluid.layers.tanh(gru)
return gru
class SimLSTMEncoder(Layer):
"""
simple LSTMEncoder for simnet
"""
def __init__(self,
dict_size,
emb_dim,
lstm_dim,
hidden_dim,
seq_len,
padding_idx,
is_reverse
):
"""
initialize
"""
super(SimLSTMEncoder, self).__init__()
self.dict_size = dict_size
self.emb_dim = emb_dim
self.lstm_dim = lstm_dim
self.hidden_dim = hidden_dim
self.seq_len = seq_len
self.is_reverse = False
self.padding_idx=padding_idx
self.emb_layer = Embedding(size=[self.dict_size, self.emb_dim],
is_sparse=True,
padding_idx=self.padding_idx,
param_attr=fluid.ParamAttr(name='emb', initializer=fluid.initializer.Xavier()))
self.lstm_cell = BasicLSTMCell(
hidden_size=self.lstm_dim, input_size=self.lstm_dim * 4
)
self.lstm_layer = RNN(
cell=self.lstm_cell, time_major=True, is_reverse=self.is_reverse
self.pool_size,
use_cudnn=use_cuda,
act=act
)
self.proj_layer = Linear(input_dim=self.hidden_dim, output_dim=self.lstm_dim * 4)
def forward(self, input):
emb = self.emb_layer(input)
emb_proj = self.proj_layer(emb)
emb_lstm, _ = self.lstm_layer(emb_proj)
emb_reduce = fluid.layers.reduce_max(emb_lstm, dim=1)
emb = self._emb_layer(input)
emb_reshape = fluid.layers.reshape(
emb_reduce, shape=[-1, self.seq_len, self.hidden_dim])
emb_lstm = fluid.layers.reduce_sum(emb_reshape, dim=1)
emb_last = fluid.layers.tanh(emb_lstm)
return emb_last
emb, shape=[-1, self.channels, self.emb_size])
emb_out=self._cnn_layer(emb_reshape)
return emb_out
\ No newline at end of file
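A hedged walk-through (assumed shapes, not part of the commit) of how the reshape in ``CNNEncoder.forward`` folds the sequence dimension into the batch before the convolution:
```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from hapi.text import CNNEncoder

# ids [2, 6] --Embedding(emb_size=4)--> [2, 6, 4]
# --reshape([-1, num_channels=3, emb_size=4])--> [4, 3, 4]
# --Conv1dPoolLayer--> [4, num_filters=4, H_out=2]
with fluid.dygraph.guard():
    ids = to_variable(np.random.randint(1, 128, size=(2, 6)).astype('int64'))
    encoder = CNNEncoder(dict_size=128, emb_size=4, num_channels=3,
                         num_filters=4, filter_size=2, pool_size=2)
    print(encoder(ids).shape)  # [4, 4, 2] under these assumptions
```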