Unverified commit 4672ea8e, authored by 骑马小猫, committed by GitHub

[FluidAPI] remove fluid rnn apis (#49050)

* remove lstm api

* remove gru_unit api

* remove lstm in all

* remove beam-search

* remove beam_search slot

* remove lstm test code

* remove fluid.layers.nn api

* update gru-unit

* revert gru_unit white list
Parent 0c1cb5e3
@@ -37,11 +37,6 @@ from collections.abc import Sequence
__all__ = [
'dynamic_decode',
'dynamic_lstm',
'dynamic_lstmp',
'dynamic_gru',
'gru_unit',
'lstm',
]
@@ -476,940 +471,3 @@ def dynamic_decode(
return_length,
**kwargs
)
def dynamic_lstm(
input,
size,
h_0=None,
c_0=None,
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
dtype='float32',
name=None,
):
r"""
**Note**:
1. This OP only supports LoDTensor as inputs. If you need to deal with Tensor, please use :ref:`api_fluid_layers_lstm` .
2. In order to improve efficiency, users must first map the input of dimension [T, hidden_size] to input of [T, 4 * hidden_size], and then pass it to this OP.
The implementation of this OP includes diagonal/peephole connections.
Please refer to `Gers, F. A., & Schmidhuber, J. (2000) <ftp://ftp.idsia.ch/pub/juergen/TimeCount-IJCNN2000.pdf>`_ .
If you do not need peephole connections, please set use_peepholes to False .
This OP computes each timestep as follows:
.. math::
i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_{x_i} + b_{h_i})
.. math::
f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_{x_f} + b_{h_f})
.. math::
o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_{x_o} + b_{h_o})
.. math::
\widetilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + b{x_c} + b_{h_c})
.. math::
c_t = f_t \odot c_{t-1} + i_t \odot \widetilde{c_t}
.. math::
h_t = o_t \odot tanh(c_t)
The symbolic meanings in the formula are as follows:
- :math:`x_{t}` represents the input at timestep :math:`t`
- :math:`h_{t}` represents the hidden state at timestep :math:`t`
- :math:`h_{t-1}, c_{t-1}` represent the hidden state and cell state at timestep :math:`t-1` , respectively
- :math:`\widetilde{c_t}` represents the candidate cell state
- :math:`i_t` , :math:`f_t` and :math:`o_t` represent input gate, forget gate, output gate, respectively
- :math:`W` represents weight (e.g., :math:`W_{ix}` is the weight of a linear transformation of input :math:`x_{t}` when calculating input gate :math:`i_t` )
- :math:`b` represents bias (e.g., :math:`b_{i}` is the bias of input gate)
- :math:`\sigma` represents the nonlinear activation function for the gates, sigmoid by default
- :math:`\odot` represents the Hadamard product of matrices, i.e. multiplying elements at the same positions of two matrices with the same dimensions to get another matrix with the same dimensions
Parameters:
input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, multi-dimensional LODTensor of shape :math:`[T, 4*hidden\_size]` . Data type is float32 or float64.
size (int): must be 4 * hidden_size.
h_0( :ref:`api_guide_Variable_en` , optional): The initial hidden state of the LSTM, multi-dimensional Tensor of shape :math:`[batch\_size, hidden\_size]` .
Data type is float32 or float64. If set to None, it will be a vector of all 0. Default: None.
c_0( :ref:`api_guide_Variable_en` , optional): The initial cell state of the LSTM, multi-dimensional Tensor of shape :math:`[batch\_size, hidden\_size]` .
Data type is float32 or float64. If set to None, it will be a vector of all 0. `h_0` and `c_0` can be None but only at the same time. Default: None.
param_attr(ParamAttr, optional): Parameter attribute of weight. If it is None, the default weight parameter attribute is used. Please refer to :ref:`api_fluid_ParamAttr` .
If the user needs to set this parameter, the dimension must be :math:`[hidden\_size, 4*hidden\_size]` . Default: None.
- Weights = :math:`\{ W_{cr},W_{ir},W_{fr},W_{or} \}` , the shape is [hidden_size, 4*hidden_size].
bias_attr (ParamAttr, optional): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
setting `use_peepholes` to `True`.
Please refer to :ref:`api_fluid_ParamAttr` . Default: None.
1. `use_peepholes = False`
- Biases = {:math:`b_c, b_i, b_f, b_o`}.
- The shape is [1, 4*hidden_size].
2. `use_peepholes = True`
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is [1, 7*hidden_size].
use_peepholes (bool, optional): Whether to use peephole connection or not. Default: True.
is_reverse (bool, optional): Whether to calculate reverse LSTM. Default: False.
gate_activation (str, optional): The activation for input gate, forget gate and output gate. Default: "sigmoid".
cell_activation (str, optional): The activation for cell output. Default: "tanh".
candidate_activation (str, optional): The activation for candidate hidden state. Default: "tanh".
dtype (str, optional): Data type, can be "float32" or "float64". Default: "float32".
name (str, optional): A name for this layer. Please refer to :ref:`api_guide_Name` . Default: None.
Returns:
tuple ( :ref:`api_guide_Variable` , :ref:`api_guide_Variable` ) :
The hidden state and cell state of LSTM
- hidden: LoDTensor with shape of :math:`[T, hidden\_size]` , and its lod and dtype are the same as those of the input.
- cell: LoDTensor with shape of :math:`[T, hidden\_size]` , and its lod and dtype are the same as those of the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
emb_dim = 256
vocab_size = 10000
hidden_dim = 512
data = fluid.data(name='x', shape=[None], dtype='int64', lod_level=1)
emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
forward_proj = fluid.layers.fc(input=emb, size=hidden_dim * 4,
bias_attr=False)
forward, cell = fluid.layers.dynamic_lstm(
input=forward_proj, size=hidden_dim * 4, use_peepholes=False)
forward.shape # (-1, 512)
cell.shape # (-1, 512)
"""
assert (
_non_static_mode() is not True
), "please use lstm instead of dynamic_lstm in dygraph mode!"
assert (
bias_attr is not False
), "bias_attr should not be False in dynamic_lstm."
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'dynamic_lstm'
)
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(h_0, Variable):
check_variable_and_dtype(
h_0, 'h_0', ['float32', 'float64'], 'dynamic_lstm'
)
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(c_0, Variable):
check_variable_and_dtype(
c_0, 'c_0', ['float32', 'float64'], 'dynamic_lstm'
)
helper = LayerHelper('lstm', **locals())
size = size // 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype
)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True
)
hidden = helper.create_variable_for_type_inference(dtype)
cell = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_variable_for_type_inference(dtype)
batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
batch_size = input.shape[0]
if h_0:
assert h_0.shape == (batch_size, size), (
'The shape of h0 should be (batch_size, %d)' % size
)
inputs['H0'] = h_0
if c_0:
assert c_0.shape == (batch_size, size), (
'The shape of c0 should be (batch_size, %d)' % size
)
inputs['C0'] = c_0
helper.append_op(
type='lstm',
inputs=inputs,
outputs={
'Hidden': hidden,
'Cell': cell,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act,
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation,
},
)
return hidden, cell
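For readers auditing what this removal drops: the per-timestep equations in the docstring above can be reproduced in a few lines of plain NumPy. The sketch below is illustrative only — it ignores peephole connections, LoD handling and the is_reverse option, and every name in it (lstm_step, W_x, W_h, b) is a hypothetical stand-in rather than part of the fluid API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    # W_x: [input_dim, 4*hidden], W_h: [hidden, 4*hidden], b: [4*hidden].
    # The four gate pre-activations are computed in one matmul and then split,
    # mirroring the [T, 4*hidden_size] projected input this OP expects.
    gates = x_t @ W_x + h_prev @ W_h + b
    i, f, o, c_hat = np.split(gates, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)   # input / forget / output gates
    c_hat = np.tanh(c_hat)                         # candidate cell state
    c_t = f * c_prev + i * c_hat                   # new cell state
    h_t = o * np.tanh(c_t)                         # new hidden state
    return h_t, c_t

batch, input_dim, hidden = 2, 3, 4
rng = np.random.default_rng(0)
h = c = np.zeros((batch, hidden))
W_x = rng.standard_normal((input_dim, 4 * hidden))
W_h = rng.standard_normal((hidden, 4 * hidden))
b = np.zeros(4 * hidden)
h, c = lstm_step(rng.standard_normal((batch, input_dim)), h, c, W_x, W_h, b)
print(h.shape, c.shape)  # (2, 4) (2, 4)
```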
@deprecated(
since='2.0.0',
update_to='paddle.nn.LSTM',
reason="This API may occur CUDNN errors.",
)
def lstm(
input,
init_h,
init_c,
max_len,
hidden_size,
num_layers,
dropout_prob=0.0,
is_bidirec=False,
is_test=False,
name=None,
default_initializer=None,
seed=-1,
):
r"""
**Note**:
This OP only supports running on GPU devices.
This OP implements LSTM operation - `Hochreiter, S., & Schmidhuber, J. (1997) <https://blog.xpgreat.com/file/lstm.pdf>`_ .
The implementation of this OP does not include diagonal/peephole connections.
Please refer to `Gers, F. A., & Schmidhuber, J. (2000) <ftp://ftp.idsia.ch/pub/juergen/TimeCount-IJCNN2000.pdf>`_ .
If you need peephole connections, please use :ref:`api_fluid_layers_dynamic_lstm` .
This OP computes each timestep as follows:
.. math::
i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_{x_i} + b_{h_i})
.. math::
f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_{x_f} + b_{h_f})
.. math::
o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_{x_o} + b_{h_o})
.. math::
\widetilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + b{x_c} + b_{h_c})
.. math::
c_t = f_t \odot c_{t-1} + i_t \odot \widetilde{c_t}
.. math::
h_t = o_t \odot tanh(c_t)
The symbolic meanings in the formula are as follows:
- :math:`x_{t}` represents the input at timestep :math:`t`
- :math:`h_{t}` represents the hidden state at timestep :math:`t`
- :math:`h_{t-1}, c_{t-1}` represent the hidden state and cell state at timestep :math:`t-1` , respectively
- :math:`\widetilde{c_t}` represents the candidate cell state
- :math:`i_t` , :math:`f_t` and :math:`o_t` represent input gate, forget gate, output gate, respectively
- :math:`W` represents weight (e.g., :math:`W_{ix}` is the weight of a linear transformation of input :math:`x_{t}` when calculating input gate :math:`i_t` )
- :math:`b` represents bias (e.g., :math:`b_{i}` is the bias of input gate)
- :math:`\sigma` represents the nonlinear activation function for the gates, sigmoid by default
- :math:`\odot` represents the Hadamard product of matrices, i.e. multiplying elements at the same positions of two matrices with the same dimensions to get another matrix with the same dimensions
Parameters:
input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64
init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
max_len (int): This parameter has no effect and will be discarded.
init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
hidden_size (int): hidden size of the LSTM.
num_layers (int): total layers number of the LSTM.
dropout_prob(float, optional): dropout probability. Dropout works ONLY between RNN layers, NOT between time steps.
There is NO dropout on the output of the last RNN layer.
Default: 0.0.
is_bidirec (bool, optional): If it is bidirectional. Default: False.
is_test (bool, optional): Whether it is in the test phase. Default: False.
name (str, optional): A name for this layer. If set None, the layer
will be named automatically. Default: None.
default_initializer(Initializer, optional): The initializer used to initialize the weight.
If set to None, the default initializer will be used. Default: None.
seed(int, optional): Seed for dropout in LSTM. If it is -1, dropout will use a random seed. Default: -1.
Returns:
tuple ( :ref:`api_guide_Variable_en` , :ref:`api_guide_Variable_en` , :ref:`api_guide_Variable_en` ) :
Three tensors, rnn_out, last_h, last_c:
- rnn_out is the hidden output of the LSTM, with shape :math:`[seq\_len, batch\_size, hidden\_size]` ; \
if is_bidirec is set to True, the shape will be :math:`[seq\_len, batch\_size, hidden\_size*2]`
- last_h is the hidden state of the last step of the LSTM, \
with shape :math:`[num\_layers, batch\_size, hidden\_size]` ; \
if is_bidirec is set to True, the shape will be :math:`[num\_layers*2, batch\_size, hidden\_size]`
- last_c is the cell state of the last step of the LSTM, \
with shape :math:`[num\_layers, batch\_size, hidden\_size]` ; \
if is_bidirec is set to True, the shape will be :math:`[num\_layers*2, batch\_size, hidden\_size]`
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
paddle.enable_static()
emb_dim = 256
vocab_size = 10000
data = fluid.data(name='x', shape=[None, 100], dtype='int64')
emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
batch_size = 100
dropout_prob = 0.2
input_size = 100
hidden_size = 150
num_layers = 1
max_len = 12
init_h = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0 )
init_c = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0 )
rnn_out, last_h, last_c = layers.lstm( emb, init_h, init_c, \
max_len, hidden_size, num_layers, \
dropout_prob=dropout_prob)
rnn_out.shape # (-1, 100, 150)
last_h.shape # (1, 100, 150)
last_c.shape # (1, 100, 150)
"""
helper = LayerHelper('cudnn_lstm', **locals())
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'lstm')
check_variable_and_dtype(init_h, 'init_h', ['float32', 'float64'], 'lstm')
check_variable_and_dtype(init_c, 'init_c', ['float32', 'float64'], 'lstm')
check_type(max_len, 'max_len', (int), 'lstm')
check_type(hidden_size, 'hidden_size', (int), 'lstm')
check_type(num_layers, 'num_layers', (int), 'lstm')
dtype = input.dtype
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
num_direction = 2 if is_bidirec else 1
for i in range(num_layers):
if i == 0:
input_weight_size = (input_size * hidden_size) * 4 * num_direction
else:
input_weight_size = (hidden_size * hidden_size) * 4 * num_direction
hidden_weight_size = (hidden_size * hidden_size) * 4 * num_direction
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8 * num_direction
weight = helper.create_parameter(
attr=helper.param_attr,
shape=[weight_size],
dtype=dtype,
default_initializer=default_initializer,
)
out = helper.create_variable_for_type_inference(dtype)
last_h = helper.create_variable_for_type_inference(dtype)
last_c = helper.create_variable_for_type_inference(dtype)
reserve = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
)
state_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
)
state_out.persistable = True
helper.append_op(
type='cudnn_lstm',
inputs={
'Input': input,
'InitH': init_h,
'InitC': init_c,
'W': weight,
},
outputs={
'Out': out,
'LastH': last_h,
'LastC': last_c,
'Reserve': reserve,
'StateOut': state_out,
},
attrs={
'is_bidirec': is_bidirec,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'is_test': is_test,
'dropout_prob': dropout_prob,
'seed': seed,
},
)
return out, last_h, last_c
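The @deprecated marker above points users to paddle.nn.LSTM. For orientation, here is a minimal dygraph sketch of that replacement; the exact 2.x call signature and output shapes shown are our reading of the paddle 2.x API, not something stated in this PR.

```python
import paddle

# Hypothetical sizes chosen only for the example.
batch_size, seq_len, input_size, hidden_size, num_layers = 4, 10, 16, 32, 1

x = paddle.randn([batch_size, seq_len, input_size])
prev_h = paddle.zeros([num_layers, batch_size, hidden_size])
prev_c = paddle.zeros([num_layers, batch_size, hidden_size])

# paddle.nn.LSTM owns its weights; no flat [weight_size] parameter or manual
# 4 * hidden_size input projection is needed, unlike the removed op.
rnn = paddle.nn.LSTM(input_size, hidden_size, num_layers)
out, (last_h, last_c) = rnn(x, (prev_h, prev_c))
print(out.shape)     # [4, 10, 32]
print(last_h.shape)  # [1, 4, 32]
```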
def dynamic_lstmp(
input,
size,
proj_size,
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
proj_activation='tanh',
dtype='float32',
name=None,
h_0=None,
c_0=None,
cell_clip=None,
proj_clip=None,
):
r"""
**Note**:
1. In order to improve efficiency, users must first map the input of dimension [T, hidden_size] to input of [T, 4 * hidden_size], and then pass it to this OP.
This OP implements the LSTMP (LSTM Projected) layer.
The LSTMP layer has a separate linear mapping layer after the LSTM layer. -- `Sak, H., Senior, A., & Beaufays, F. (2014) <https://ai.google/research/pubs/pub43905.pdf>`_ .
Compared with the standard LSTM layer, LSTMP has an additional linear mapping layer,
which is used to map from the original hidden state :math:`h_t` to the lower dimensional state :math:`r_t` .
This reduces the total number of parameters and computational complexity, especially when the output unit is relatively large.
The default implementation of the OP contains diagonal/peephole connections,
please refer to `Gers, F. A., & Schmidhuber, J. (2000) <ftp://ftp.idsia.ch/pub/juergen/TimeCount-IJCNN2000.pdf>`_ .
If you need to disable the peephole connections, set use_peepholes to False.
This OP computes each timestep as follows:
.. math::
i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i)
.. math::
f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f)
.. math::
o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_{t-1} + b_o)
.. math::
\widetilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c)
.. math::
c_t = f_t \odot c_{t-1} + i_t \odot \widetilde{c_t}
.. math::
h_t = o_t \odot act_h(c_t)
.. math::
r_t = \overline{act_h}(W_{rh}h_t)
The symbolic meanings in the formula are as follows:
- :math:`x_{t}` represents the input at timestep :math:`t`
- :math:`h_{t}` represents the hidden state at timestep :math:`t`
- :math:`r_{t}` : represents the state of the projected output of the hidden state :math:`h_{t}`
- :math:`h_{t-1}, c_{t-1}, r_{t-1}` represent the hidden state, cell state and projected output at timestep :math:`t-1` , respectively
- :math:`\widetilde{c_t}` represents the candidate cell state
- :math:`i_t` , :math:`f_t` and :math:`o_t` represent input gate, forget gate, output gate, respectively
- :math:`W` represents weight (e.g., :math:`W_{ix}` is the weight of a linear transformation of input :math:`x_{t}` when calculating input gate :math:`i_t` )
- :math:`b` represents bias (e.g., :math:`b_{i}` is the bias of input gate)
- :math:`\sigma` represents the nonlinear activation function for the gates, sigmoid by default
- :math:`\odot` represents the Hadamard product of matrices, i.e. multiplying elements at the same positions of two matrices with the same dimensions to get another matrix with the same dimensions
Parameters:
input( :ref:`api_guide_Variable_en` ): The input of dynamic_lstmp layer, which supports
variable-time length input sequence.
It is a multi-dimensional LODTensor of shape :math:`[T, 4*hidden\_size]` . Data type is float32 or float64.
size(int): must be 4 * hidden_size.
proj_size(int): The size of projection output.
param_attr(ParamAttr, optional): Parameter attribute of weight. If it is None, the default weight parameter attribute is used. Please refer to :ref:`api_fluid_ParamAttr` .
If the user needs to set this parameter, the dimension must be :math:`[hidden\_size, 4*hidden\_size]` . Default: None.
- Weights = :math:`\{ W_{cr},W_{ir},W_{fr},W_{or} \}` , the shape is [P, 4*hidden_size] , where P is the projection size.
- Projection weight = :math:`\{ W_{rh} \}` , the shape is [hidden_size, P].
bias_attr (ParamAttr, optional): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
setting `use_peepholes` to `True`.
Please refer to :ref:`api_fluid_ParamAttr` . Default: None.
1. `use_peepholes = False`
- Biases = {:math:`b_c, b_i, b_f, b_o`}.
- The shape is [1, 4*hidden_size].
2. `use_peepholes = True`
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is [1, 7*hidden_size].
use_peepholes (bool, optional): Whether to use peephole connection or not. Default True.
is_reverse (bool, optional): Whether to calculate reverse LSTM. Default False.
gate_activation (str, optional): The activation for input gate, forget gate and output gate. Default "sigmoid".
cell_activation (str, optional): The activation for cell output. Default "tanh".
candidate_activation (str, optional): The activation for candidate hidden state. Default "tanh".
proj_activation(str, optional): The activation for projection output. Default "tanh".
dtype (str, optional): Data type, can be "float32" or "float64". Default "float32".
name (str, optional): A name for this layer. Please refer to :ref:`api_guide_Name` . Default: None.
h_0( :ref:`api_guide_Variable` , optional): The initial hidden state is an optional input, default is zero.
This is a tensor with shape :math:`[batch\_size, P]` , where P is the projection size. Default: None.
c_0( :ref:`api_guide_Variable` , optional): The initial cell state is an optional input, default is zero.
This is a tensor with shape :math:`[batch\_size, P]` , where P is the projection size.
`h_0` and `c_0` can be None but only at the same time. Default: None.
cell_clip(float, optional): If not None, the cell state is clipped
by this value prior to the cell output activation. Default: None.
proj_clip(float, optional): If `proj_clip` is provided, the projected values
are clipped elementwise to within `[-proj_clip, proj_clip]`. Default: None.
Returns:
tuple ( :ref:`api_guide_Variable` , :ref:`api_guide_Variable` ) :
The hidden state and cell state of LSTMP
- hidden: LoDTensor with shape of :math:`[T, P]` , and its lod and dtype are the same as those of the input.
- cell: LoDTensor with shape of :math:`[T, hidden\_size]` , and its lod and dtype are the same as those of the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dict_dim, emb_dim = 128, 64
data = fluid.data(name='sequence', shape=[None], dtype='int64', lod_level=1)
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim, proj_dim = 512, 256
fc_out = fluid.layers.fc(input=emb, size=hidden_dim * 4,
act=None, bias_attr=None)
proj_out, last_c = fluid.layers.dynamic_lstmp(input=fc_out,
size=hidden_dim * 4,
proj_size=proj_dim,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh")
proj_out.shape # (-1, 256)
last_c.shape # (-1, 512)
"""
assert (
_non_static_mode() is not True
), "please use lstm instead of dynamic_lstmp in dygraph mode!"
assert (
bias_attr is not False
), "bias_attr should not be False in dynamic_lstmp."
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'dynamic_lstmp'
)
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(h_0, Variable):
check_variable_and_dtype(
h_0, 'h_0', ['float32', 'float64'], 'dynamic_lstmp'
)
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(c_0, Variable):
check_variable_and_dtype(
c_0, 'c_0', ['float32', 'float64'], 'dynamic_lstmp'
)
helper = LayerHelper('lstmp', **locals())
size = size // 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[proj_size, 4 * size], dtype=dtype
)
proj_weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, proj_size], dtype=dtype
)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True
)
projection = helper.create_variable_for_type_inference(dtype)
cell = helper.create_variable_for_type_inference(dtype)
ordered_proj0 = helper.create_variable_for_type_inference(dtype)
batch_hidden = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_variable_for_type_inference(dtype)
batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
inputs = {
'Input': input,
'Weight': weight,
'ProjWeight': proj_weight,
'Bias': bias,
}
batch_size = input.shape[0]
if h_0:
assert h_0.shape == (batch_size, proj_size), (
'The shape of h0 should be (batch_size, %d)' % proj_size
)
inputs['H0'] = h_0
if c_0:
assert c_0.shape == (batch_size, size), (
'The shape of c0 should be (batch_size, %d)' % size
)
inputs['C0'] = c_0
if cell_clip:
assert cell_clip >= 0, "cell_clip should not be negative."
if proj_clip:
assert proj_clip >= 0, "proj_clip should not be negative."
helper.append_op(
type='lstmp',
inputs=inputs,
outputs={
'Projection': projection,
'Cell': cell,
'BatchHidden': batch_hidden,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act,
},
attrs={
'use_peepholes': use_peepholes,
'cell_clip': cell_clip,
'proj_clip': proj_clip,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation,
'proj_activation': proj_activation,
},
)
return projection, cell
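What distinguishes dynamic_lstmp from dynamic_lstm is the projection step :math:`r_t = \overline{act_h}(W_{rh}h_t)` plus the optional cell_clip/proj_clip clamping described above. A tiny NumPy sketch of just that step (all names hypothetical, not the removed op itself) makes the dimensionality saving concrete.

```python
import numpy as np

def project(h_t, W_rh, proj_clip=None):
    # Map the hidden state (hidden_size) down to the projected state (proj_size),
    # then optionally clip elementwise to [-proj_clip, proj_clip].
    r_t = np.tanh(h_t @ W_rh)
    if proj_clip is not None:
        r_t = np.clip(r_t, -proj_clip, proj_clip)
    return r_t

hidden_size, proj_size, batch = 8, 4, 2
h_t = np.random.randn(batch, hidden_size)
W_rh = np.random.randn(hidden_size, proj_size)
print(project(h_t, W_rh, proj_clip=1.0).shape)  # (2, 4)
```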
def dynamic_gru(
input,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
):
r"""
**Note: The input of this OP must be a LoDTensor. If you need to deal with
Tensor input, please use** :ref:`api_fluid_layers_StaticRNN` .
This operator is used to perform the calculations for a single layer of
Gated Recurrent Unit (GRU) on full sequences step by step. The calculations
in one time step support these two modes:
If ``origin_mode`` is True, then the formula used is from paper
`Learning Phrase Representations using RNN Encoder Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_ .
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \tilde{h_t}
if ``origin_mode`` is False, then the formula used is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \tilde{h_t}
:math:`x_t` is the input of the current time step, but it is not taken directly from ``input`` .
**Note**: this operator does not include the calculations :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` ,
thus a fully-connected layer whose size is 3 times ``size`` should
be used before this operator, and its output should be used as ``input`` here.
:math:`h_{t-1}` is the hidden state from previous time step.
:math:`u_t` , :math:`r_t` , :math:`\tilde{h_t}` and :math:`h_t` stand for the
update gate, reset gate, candidate hidden state and hidden output, respectively.
:math:`W_{uh}, b_u` , :math:`W_{rh}, b_r` and :math:`W_{ch}, b_c` stand for
the weight matrices and biases used in the update gate, reset gate and candidate
hidden calculations. For the implementation, the three weight matrices are merged
into a tensor shaped :math:`[D, D \times 3]` , and the three biases are concatenated
as a tensor shaped :math:`[1, D \times 3]` , where :math:`D` stands for the
hidden size. The data layout of the weight tensor is: :math:`W_{uh}` and :math:`W_{rh}`
are concatenated with shape :math:`[D, D \times 2]` lying on the first part,
and :math:`W_{ch}` lying on the latter part with shape :math:`[D, D]` .
Args:
input(Variable): A LoDTensor whose lod level is 1, representing the input
after linear projection. Its shape should be :math:`[T, D \times 3]` ,
where :math:`T` stands for the total sequence lengths in this mini-batch,
:math:`D` for the hidden size. The data type should be float32 or float64.
size(int): Indicate the hidden size.
param_attr(ParamAttr, optional): To specify the weight parameter property.
Default: None, which means the default weight parameter property is used.
See usage for details in :ref:`api_fluid_ParamAttr` .
bias_attr (ParamAttr, optional): To specify the bias parameter property.
Default: None, which means the default bias parameter property is used.
See usage for details in :ref:`api_fluid_ParamAttr` .
is_reverse(bool, optional): Whether to compute in the reversed order of
input sequences. Default False.
gate_activation(str, optional): The activation function corresponding to
:math:`act_g` in the formula. "sigmoid", "tanh", "relu" and "identity"
are supported. Default "sigmoid".
candidate_activation(str, optional): The activation function corresponding to
:math:`act_c` in the formula. "sigmoid", "tanh", "relu" and "identity"
are supported. Default "tanh".
h_0 (Variable, optional): A Tensor representing the initial hidden state.
If not provided, the default initial hidden state is 0. The shape is
:math:`[N, D]` , where :math:`N` is the number of sequences in the
mini-batch, :math:`D` for the hidden size. The data type should be
same as ``input`` . Default None.
Returns:
Variable: A LoDTensor whose lod level is 1 and shape is :math:`[T, D]` , \
where :math:`T` stands for the total sequence lengths in this mini-batch \
:math:`D` for the hidden size. It represents GRU transformed sequence output, \
and has the same lod and data type with ``input`` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
dict_dim, emb_dim = 128, 64
data = fluid.data(name='sequence',
shape=[None],
dtype='int64',
lod_level=1)
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim = 512
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim)
"""
assert (
_non_static_mode() is not True
), "please use gru instead of dynamic_gru in dygraph mode!"
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'dynamic_gru'
)
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_gru')
if isinstance(h_0, Variable):
check_variable_and_dtype(
h_0, 'h_0', ['float32', 'float64'], 'dynamic_gru'
)
helper = LayerHelper('gru', **locals())
dtype = helper.input_dtype()
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype
)
bias = helper.create_parameter(
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True
)
batch_size = input.shape[0]
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0:
assert h_0.shape == (batch_size, size), (
'The shape of h0 should be(batch_size, %d)' % size
)
inputs['H0'] = h_0
hidden = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_variable_for_type_inference(dtype)
batch_reset_hidden_prev = helper.create_variable_for_type_inference(dtype)
batch_hidden = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='gru',
inputs=inputs,
outputs={
'Hidden': hidden,
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden,
},
attrs={
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'activation': candidate_activation,
'origin_mode': origin_mode,
},
)
return hidden
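Because the origin_mode flag above flips the interpolation between :math:`h_{t-1}` and the candidate state, a single-step NumPy sketch helps show exactly what changes. Everything here (gru_step and its weights) is a hypothetical illustration, not the removed fluid op.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_proj, h_prev, W_uh, W_rh, W_ch, b_u, b_r, b_c, origin_mode=False):
    # x_proj already contains W_ux x_t, W_rx x_t, W_cx x_t concatenated,
    # matching the [T, D*3] input this OP expects from a preceding fc layer.
    xu, xr, xc = np.split(x_proj, 3, axis=-1)
    u = sigmoid(xu + h_prev @ W_uh + b_u)            # update gate
    r = sigmoid(xr + h_prev @ W_rh + b_r)            # reset gate
    c = np.tanh(xc + (r * h_prev) @ W_ch + b_c)      # candidate hidden state
    if origin_mode:                                  # Cho et al. (2014) rule
        return u * h_prev + (1.0 - u) * c
    return (1.0 - u) * h_prev + u * c                # default rule

D = 4
h = np.zeros((2, D))
x_proj = np.random.randn(2, 3 * D)
W_uh, W_rh, W_ch = (np.random.randn(D, D) for _ in range(3))
b_u = b_r = b_c = np.zeros(D)
print(gru_step(x_proj, h, W_uh, W_rh, W_ch, b_u, b_r, b_c).shape)  # (2, 4)
```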
def gru_unit(
input,
hidden,
size,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid',
origin_mode=False,
):
r"""
Gated Recurrent Unit (GRU) RNN cell. This operator performs GRU calculations for
one time step and it supports these two modes:
If ``origin_mode`` is True, then the formula used is from paper
`Learning Phrase Representations using RNN Encoder Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_ .
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \tilde{h_t}
if ``origin_mode`` is False, then the formula used is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \tilde{h_t}
:math:`x_t` is the input of the current time step, but it is not the same as ``input`` .
**Note**: this operator does not include the calculations :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` ,
thus a fully-connected layer whose size is 3 times the GRU hidden size should
be used before this operator, and its output should be used as ``input`` here.
:math:`h_{t-1}` is the hidden state from previous time step.
:math:`u_t` , :math:`r_t` , :math:`\tilde{h_t}` and :math:`h_t` stand for the
update gate, reset gate, candidate hidden state and hidden output, respectively.
:math:`W_{uh}, b_u` , :math:`W_{rh}, b_r` and :math:`W_{ch}, b_c` stand for
the weight matrices and biases used in the update gate, reset gate and candidate
hidden calculations. For the implementation, the three weight matrices are merged
into a tensor shaped :math:`[D, D \times 3]` , and the three biases are concatenated
as a tensor shaped :math:`[1, D \times 3]` , where :math:`D` stands for the
hidden size. The data layout of the weight tensor is: :math:`W_{uh}` and :math:`W_{rh}`
are concatenated with shape :math:`[D, D \times 2]` lying on the first part,
and :math:`W_{ch}` lying on the latter part with shape :math:`[D, D]` .
Args:
input(Variable): A 2D Tensor representing the input after linear projection.
Its shape should be :math:`[N, D \times 3]` ,
where :math:`N` stands for batch size, :math:`D` for the hidden size.
The data type should be float32 or float64.
hidden(Variable): A 2D Tensor representing the hidden state from previous step.
Its shape should be :math:`[N, D]` , where :math:`N` stands for batch size,
:math:`D` for the hidden size. The data type should be same as ``input`` .
size(int): Indicate the hidden size.
param_attr(ParamAttr, optional): To specify the weight parameter property.
Default: None, which means the default weight parameter property is used.
See usage for details in :ref:`api_fluid_ParamAttr` .
bias_attr (ParamAttr, optional): To specify the bias parameter property.
Default: None, which means the default bias parameter property is used.
See usage for details in :ref:`api_fluid_ParamAttr` .
activation(str, optional): The activation function corresponding to
:math:`act_c` in the formula. "sigmoid", "tanh", "relu" and "identity"
are supported. Default "tanh".
gate_activation(str, optional): The activation function corresponding to
:math:`act_g` in the formula. "sigmoid", "tanh", "relu" and "identity"
are supported. Default "sigmoid".
Returns:
tuple: The tuple contains three Tensor variables with the same data type \
as ``input`` . They represent the hidden state for next time step ( :math:`h_t` ), \
reset previous hidden state ( :math:`r_t \odot h_{t-1}` ), and the \
concatenation of :math:`h_t, r_t, \tilde{h_t}` . And they have shape \
:math:`[N, D]` , :math:`[N, D]` , :math:`[N, D \times 3]` separately. \
Usually only the hidden state for next time step ( :math:`h_t` ) is used \
as output and state, the other two are intermediate results of calculations.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dict_dim, emb_dim = 128, 64
data = fluid.data(name='step_data', shape=[None], dtype='int64')
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim = 512
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
pre_hidden = fluid.data(
name='pre_hidden', shape=[None, hidden_dim], dtype='float32')
hidden, reset_hidden, gate = fluid.layers.gru_unit(
input=x, hidden=pre_hidden, size=hidden_dim * 3)
"""
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'gru_unit')
check_variable_and_dtype(
hidden, 'hidden', ['float32', 'float64'], 'gru_unit'
)
check_type(size, 'size', (int), 'gru_unit')
activation_dict = dict(
identity=0,
sigmoid=1,
tanh=2,
relu=3,
)
activation = activation_dict[activation]
gate_activation = activation_dict[gate_activation]
helper = LayerHelper('gru_unit', **locals())
dtype = helper.input_dtype()
size = size // 3
# create weight
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype
)
gate = helper.create_variable_for_type_inference(dtype)
reset_hidden_pre = helper.create_variable_for_type_inference(dtype)
updated_hidden = helper.create_variable_for_type_inference(dtype)
inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': weight}
# create bias
if helper.bias_attr:
bias_size = [1, 3 * size]
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True
)
inputs['Bias'] = bias
helper.append_op(
type='gru_unit',
inputs=inputs,
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': 2, # tanh
'gate_activation': 1, # sigmoid
'origin_mode': origin_mode,
},
)
return updated_hidden, reset_hidden_pre, gate
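For completeness, a hedged sketch of the closest 2.x counterpart of this single-step op. Our assumption is that paddle.nn.GRUCell covers the same role; note that, unlike gru_unit, the cell applies the input projection itself, so the separate 3 * hidden_dim fc layer from the example above is not needed.

```python
import paddle

batch, input_size, hidden_size = 4, 16, 32
x_t = paddle.randn([batch, input_size])          # raw step input, no pre-projection
prev_h = paddle.zeros([batch, hidden_size])

# paddle.nn.GRUCell(input_size, hidden_size) returns (outputs, new_states);
# for a GRU cell both are the new hidden state h_t.
cell = paddle.nn.GRUCell(input_size, hidden_size)
y, h = cell(x_t, prev_h)
print(h.shape)  # [4, 32]
```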
@@ -12,5 +12,4 @@ endforeach()
set_tests_properties(test_word2vec_book PROPERTIES TIMEOUT 120)
set_tests_properties(test_recognize_digits PROPERTIES TIMEOUT 120)
set_tests_properties(test_image_classification PROPERTIES TIMEOUT 200)
set_tests_properties(test_label_semantic_roles PROPERTIES TIMEOUT 240)
set_tests_properties(test_fit_a_line PROPERTIES TIMEOUT 120)
@@ -55,43 +55,6 @@ def convolution_net(
return avg_cost, accuracy, prediction
def stacked_lstm_net(
data, label, input_dim, class_dim=2, emb_dim=128, hid_dim=512, stacked_num=3
):
assert stacked_num % 2 == 1
emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True
)
# add bias attr
# TODO(qijun) linear act
fc1 = fluid.layers.fc(input=emb, size=hid_dim)
lstm1, cell1 = fluid.layers.dynamic_lstm(input=fc1, size=hid_dim)
inputs = [fc1, lstm1]
for i in range(2, stacked_num + 1):
fc = fluid.layers.fc(input=inputs, size=hid_dim)
lstm, cell = fluid.layers.dynamic_lstm(
input=fc, size=hid_dim, is_reverse=(i % 2) == 0
)
inputs = [fc, lstm]
fc_last = fluid.layers.sequence_pool(input=inputs[0], pool_type='max')
lstm_last = fluid.layers.sequence_pool(input=inputs[1], pool_type='max')
prediction = fluid.layers.fc(
input=[fc_last, lstm_last], size=class_dim, act='softmax'
)
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(cost)
accuracy = paddle.static.accuracy(input=prediction, label=label)
return avg_cost, accuracy, prediction
def train(
word_dict,
net_method,
@@ -278,25 +241,6 @@ class TestUnderstandSentiment(unittest.TestCase):
parallel=True,
)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_cpu(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=False,
save_dirname="understand_sentiment_stacked_lstm.inference.model",
)
def test_stacked_lstm_cpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=False,
parallel=True,
)
def test_conv_gpu(self):
with self.new_program_scope():
main(
@@ -315,25 +259,6 @@ class TestUnderstandSentiment(unittest.TestCase):
parallel=True,
)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_gpu(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=True,
save_dirname="understand_sentiment_stacked_lstm.inference.model",
)
def test_stacked_lstm_gpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=True,
parallel=True,
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import tempfile
import time
import unittest
import numpy as np
import paddle
import paddle.dataset.conll05 as conll05
import paddle.fluid as fluid
paddle.enable_static()
word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict)
label_dict_len = len(label_dict)
pred_dict_len = len(verb_dict)
mark_dict_len = 2
word_dim = 32
mark_dim = 5
hidden_dim = 512
depth = 8
mix_hidden_lr = 1e-3
IS_SPARSE = True
PASS_NUM = 2
BATCH_SIZE = 10
embedding_name = 'emb'
def load_parameter(file_name, h, w):
with open(file_name, 'rb') as f:
f.read(16) # skip header.
return np.fromfile(f, dtype=np.float32).reshape(h, w)
def db_lstm(
word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, **ignored
):
# 8 features
predicate_embedding = fluid.layers.embedding(
input=predicate,
size=[pred_dict_len, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr='vemb',
)
mark_embedding = fluid.layers.embedding(
input=mark,
size=[mark_dict_len, mark_dim],
dtype='float32',
is_sparse=IS_SPARSE,
)
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
emb_layers = [
fluid.layers.embedding(
size=[word_dict_len, word_dim],
input=x,
param_attr=fluid.ParamAttr(name=embedding_name, trainable=False),
)
for x in word_input
]
emb_layers.append(predicate_embedding)
emb_layers.append(mark_embedding)
hidden_0_layers = [
fluid.layers.fc(input=emb, size=hidden_dim) for emb in emb_layers
]
hidden_0 = fluid.layers.sums(input=hidden_0_layers)
lstm_0 = fluid.layers.dynamic_lstm(
input=hidden_0,
size=hidden_dim,
candidate_activation='relu',
gate_activation='sigmoid',
cell_activation='sigmoid',
)
# stack L-LSTM and R-LSTM with direct edges
input_tmp = [hidden_0, lstm_0]
for i in range(1, depth):
mix_hidden = fluid.layers.sums(
input=[
fluid.layers.fc(input=input_tmp[0], size=hidden_dim),
fluid.layers.fc(input=input_tmp[1], size=hidden_dim),
]
)
lstm = fluid.layers.dynamic_lstm(
input=mix_hidden,
size=hidden_dim,
candidate_activation='relu',
gate_activation='sigmoid',
cell_activation='sigmoid',
is_reverse=((i % 2) == 1),
)
input_tmp = [mix_hidden, lstm]
feature_out = fluid.layers.sums(
input=[
fluid.layers.fc(
input=input_tmp[0], size=label_dict_len, act='tanh'
),
fluid.layers.fc(
input=input_tmp[1], size=label_dict_len, act='tanh'
),
]
)
return feature_out
def train(use_cuda, save_dirname=None, is_local=True):
# define network topology
word = fluid.layers.data(
name='word_data', shape=[1], dtype='int64', lod_level=1
)
predicate = fluid.layers.data(
name='verb_data', shape=[1], dtype='int64', lod_level=1
)
ctx_n2 = fluid.layers.data(
name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1
)
ctx_n1 = fluid.layers.data(
name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1
)
ctx_0 = fluid.layers.data(
name='ctx_0_data', shape=[1], dtype='int64', lod_level=1
)
ctx_p1 = fluid.layers.data(
name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1
)
ctx_p2 = fluid.layers.data(
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1
)
mark = fluid.layers.data(
name='mark_data', shape=[1], dtype='int64', lod_level=1
)
feature_out = db_lstm(**locals())
target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1
)
cost = fluid.layers.softmax_with_cross_entropy(feature_out, target)
avg_cost = paddle.mean(cost)
# TODO(qiao)
# check other optimizers and check why out will be NAN
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=0.01,
decay_steps=100000,
decay_rate=0.5,
staircase=True,
)
)
sgd_optimizer.minimize(avg_cost)
# TODO(qiao)
# add dependency track and move this config before optimizer
train_data = paddle.batch(
paddle.reader.shuffle(paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE,
)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(
feed_list=[
word,
ctx_n2,
ctx_n1,
ctx_0,
ctx_p1,
ctx_p2,
predicate,
mark,
target,
],
place=place,
)
exe = fluid.Executor(place)
def train_loop(main_program):
exe.run(fluid.default_startup_program())
embedding_param = (
fluid.global_scope().find_var(embedding_name).get_tensor()
)
embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim),
place,
)
start_time = time.time()
batch_id = 0
for pass_id in range(PASS_NUM):
for data in train_data():
cost = exe.run(
main_program, feed=feeder.feed(data), fetch_list=[avg_cost]
)
cost = cost[0]
if batch_id % 10 == 0:
print("avg_cost:" + str(cost))
if batch_id != 0:
print(
"second per batch: "
+ str((time.time() - start_time) / batch_id)
)
# Set the threshold low to speed up the CI test
if float(cost) < 80.0:
if save_dirname is not None:
# TODO(liuyiqun): Change the target to crf_decode
fluid.io.save_inference_model(
save_dirname,
[
'word_data',
'verb_data',
'ctx_n2_data',
'ctx_n1_data',
'ctx_0_data',
'ctx_p1_data',
'ctx_p2_data',
'mark_data',
],
[feature_out],
exe,
)
return
batch_id = batch_id + 1
raise RuntimeError(
"This model should save_inference_model and return, but not reach here, please check!"
)
if is_local:
train_loop(fluid.default_main_program())
else:
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS") # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("PADDLE_TRAINERS"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
t = fluid.DistributeTranspiler()
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(
current_endpoint, pserver_prog
)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
train_loop(t.get_trainer_program())
def infer(use_cuda, save_dirname=None):
if save_dirname is None:
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be fed
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(save_dirname, exe)
# Setup input by creating LoDTensor to represent sequence of words.
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
word = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
pred = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=pred_dict_len - 1
)
ctx_n2 = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
ctx_n1 = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
ctx_0 = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
ctx_p1 = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
ctx_p2 = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=word_dict_len - 1
)
mark = fluid.create_random_int_lodtensor(
recursive_seq_lens, base_shape, place, low=0, high=mark_dict_len - 1
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
assert feed_target_names[0] == 'word_data'
assert feed_target_names[1] == 'verb_data'
assert feed_target_names[2] == 'ctx_n2_data'
assert feed_target_names[3] == 'ctx_n1_data'
assert feed_target_names[4] == 'ctx_0_data'
assert feed_target_names[5] == 'ctx_p1_data'
assert feed_target_names[6] == 'ctx_p2_data'
assert feed_target_names[7] == 'mark_data'
results = exe.run(
inference_program,
feed={
feed_target_names[0]: word,
feed_target_names[1]: pred,
feed_target_names[2]: ctx_n2,
feed_target_names[3]: ctx_n1,
feed_target_names[4]: ctx_0,
feed_target_names[5]: ctx_p1,
feed_target_names[6]: ctx_p2,
feed_target_names[7]: mark,
},
fetch_list=fetch_targets,
return_numpy=False,
)
print(results[0].recursive_sequence_lengths())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
def main(use_cuda, is_local=True):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
temp_dir = tempfile.TemporaryDirectory()
# Directory for saving the trained model
save_dirname = os.path.join(
temp_dir.name, "label_semantic_roles.inference.model"
)
train(use_cuda, save_dirname, is_local)
infer(use_cuda, save_dirname)
temp_dir.cleanup()
class TestLabelSemanticRoles(unittest.TestCase):
def test_cuda(self):
with self.scope_prog_guard():
main(use_cuda=True)
def test_cpu(self):
with self.scope_prog_guard():
main(use_cuda=False)
@contextlib.contextmanager
def scope_prog_guard(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
if __name__ == '__main__':
unittest.main()
@@ -113,7 +113,6 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_3)
list(REMOVE_ITEM TEST_OPS test_fleet_unitaccessor)
list(REMOVE_ITEM TEST_OPS test_ps_dispatcher)
list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_nlp)
list(REMOVE_ITEM TEST_OPS test_nvprof)
# TODO: Fix these unittests failed on Windows
@@ -997,13 +996,6 @@ set_tests_properties(test_parallel_executor_transformer PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_div_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_regularizer_api PROPERTIES TIMEOUT 150)
set_tests_properties(test_multiclass_nms_op PROPERTIES TIMEOUT 120)
if(NOT WIN32)
if(WITH_NV_JETSON)
set_tests_properties(test_ir_memory_optimize_nlp PROPERTIES TIMEOUT 1200)
else()
set_tests_properties(test_ir_memory_optimize_nlp PROPERTIES TIMEOUT 120)
endif()
endif()
set_tests_properties(test_add_reader_dependency PROPERTIES TIMEOUT 120)
set_tests_properties(test_bilateral_slice_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_buffer_shared_memory_reuse_pass PROPERTIES TIMEOUT
@@ -1080,7 +1072,6 @@ set_tests_properties(test_weight_decay PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_ptb_rnn_sorted_gradient PROPERTIES TIMEOUT
120)
set_tests_properties(test_crop_tensor_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_eager_deletion_lstm_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_executor_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_ptb_rnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_save_load_v2 PROPERTIES TIMEOUT 120)
@@ -1124,7 +1115,6 @@ set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_transpose_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_eager_deletion_gru_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_activation_op PROPERTIES TIMEOUT 270)
set_tests_properties(test_normal PROPERTIES TIMEOUT 120)
set_tests_properties(test_lstmp_op PROPERTIES TIMEOUT 120)
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
from paddle.fluid.core import PassVersionChecker
class FcGruFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
dict_dim, emb_dim = 128, 64
data = fluid.data(
name='step_data', shape=[None], dtype='int64', lod_level=1
)
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim = 512
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
hidden = fluid.layers.dynamic_gru(
input=x,
size=hidden_dim,
bias_attr=True,
origin_mode=False,
is_reverse=True,
)
batch = 16
lod_tensor = fluid.LoDTensor()
lod_tensor.set(
np.random.randint(0, dict_dim, size=[batch]).astype("int64"),
fluid.CPUPlace(),
)
lod_tensor.set_lod([[0, batch]])
self.feeds = {"step_data": lod_tensor}
self.fetch_list = [hidden]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(PassVersionChecker.IsCompatible('fc_gru_fuse_pass'))
class MulGruFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
dict_dim, emb_dim = 128, 64
data = fluid.data(
name='step_data', shape=[None], dtype='int64', lod_level=1
)
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim = 512
x = fluid.layers.fc(input=emb, size=hidden_dim * 3, bias_attr=False)
hidden = fluid.layers.dynamic_gru(
input=x,
size=hidden_dim,
bias_attr=True,
origin_mode=False,
is_reverse=True,
)
batch = 16
lod_tensor = fluid.LoDTensor()
lod_tensor.set(
np.random.randint(0, dict_dim, size=[batch]).astype("int64"),
fluid.CPUPlace(),
)
lod_tensor.set_lod([[0, batch]])
self.feeds = {"step_data": lod_tensor}
self.fetch_list = [hidden]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(PassVersionChecker.IsCompatible('mul_gru_fuse_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
from paddle.fluid.core import PassVersionChecker
class MulLstmFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
dict_dim, emb_dim = 128, 64
hidden_dim = 512
data = fluid.data(
name='data', shape=[1], dtype='int64', lod_level=1
)
emb = fluid.embedding(input=data, size=[dict_dim, emb_dim])
x = fluid.layers.fc(input=emb, size=hidden_dim * 4, bias_attr=False)
forward, cell = fluid.layers.dynamic_lstm(
input=x, size=hidden_dim * 4
)
batch = 16
lod_tensor = fluid.LoDTensor()
lod_tensor.set(
np.random.randint(0, dict_dim, size=[batch]).astype("int64"),
fluid.CPUPlace(),
)
lod_tensor.set_lod([[0, batch]])
self.feeds = {"data": lod_tensor}
self.fetch_list = [forward, cell]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(PassVersionChecker.IsCompatible('mul_lstm_fuse_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_eager_deletion_dynamic_rnn_base import TestBase
import paddle
import paddle.fluid as fluid
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
def gru_net(
data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=400.0,
):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr),
)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
gru_max_tanh = paddle.tanh(gru_max)
fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(x=cost)
return avg_cost
class GRUTest(TestBase):
def setUp(self):
self.net = gru_net
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_eager_deletion_dynamic_rnn_base import TestBase
import paddle
import paddle.fluid as fluid
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
def lstm_net(
data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=30.0,
):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr),
)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False
)
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = paddle.tanh(lstm_max)
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(x=cost)
return avg_cost
class LSTMTest(TestBase):
def setUp(self):
self.net = lstm_net
if __name__ == "__main__":
unittest.main()
...@@ -404,21 +404,6 @@ def lm_model(
init_hidden=init_hidden_reshape,
init_cell=init_cell_reshape,
)
elif rnn_model == "cudnn":
x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
rnn_out, last_hidden, last_cell = layers.lstm(
x_emb,
init_hidden_reshape,
init_cell_reshape,
num_steps,
hidden_size,
num_layers,
is_bidirec=False,
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale
),
)
rnn_out = paddle.transpose(rnn_out, perm=[1, 0, 2])
elif rnn_model == "basic_lstm":
rnn_out, last_hidden, last_cell = basic_lstm(
x_emb,
...
...@@ -18,8 +18,6 @@ import unittest
import numpy as np
from op_test import OpTest
from paddle import fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION
...@@ -267,25 +265,5 @@ class TestGRUOpInference(TestGRUOp):
pass
class TestGruOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 1536)).astype("float32")
fluid.layers.dynamic_gru(input=input_data, size=512)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 1536], dtype="float32"
)
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_gru(input=in_data, size=512, h_0=h)
self.assertRaises(TypeError, test_h_0)
if __name__ == "__main__":
unittest.main()
...@@ -19,8 +19,6 @@ import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.layers import gru_unit
class GRUActivationType(OpTest):
...@@ -46,55 +44,6 @@ def relu(x):
return np.maximum(x, 0)
class TestGRUUnitOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
batch_size = 5
hidden_dim = 40
input = fluid.data(
name='input', shape=[None, hidden_dim * 3], dtype='float32'
)
pre_hidden = fluid.data(
name='pre_hidden', shape=[None, hidden_dim], dtype='float32'
)
np_input = np.random.uniform(
-0.1, 0.1, (batch_size, hidden_dim * 3)
).astype('float64')
np_pre_hidden = np.random.uniform(
-0.1, 0.1, (batch_size, hidden_dim)
).astype('float64')
def test_input_Variable():
gru_unit(np_input, pre_hidden, hidden_dim * 3)
self.assertRaises(TypeError, test_input_Variable)
def test_pre_hidden_Variable():
gru_unit(input, np_pre_hidden, hidden_dim * 3)
self.assertRaises(TypeError, test_pre_hidden_Variable)
def test_input_type():
error_input = fluid.data(
name='error_input',
shape=[None, hidden_dim * 3],
dtype='int32',
)
gru_unit(error_input, pre_hidden, hidden_dim * 3)
self.assertRaises(TypeError, test_input_type)
def test_pre_hidden_type():
error_pre_hidden = fluid.data(
name='error_pre_hidden',
shape=[None, hidden_dim],
dtype='int32',
)
gru_unit(input, error_pre_hidden, hidden_dim * 3)
self.assertRaises(TypeError, test_pre_hidden_type)
class TestGRUUnitOp(OpTest):
batch_size = 5
frame_size = 40
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NLP model: a stack of ops operating on LoD. It's a classical test case for the optimization passes.
import unittest
from ir_memory_optimize_net_base import TestIrMemOptBase
import paddle
import paddle.fluid as fluid
def lstm_net(
data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=30.0,
):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr),
)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False
)
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = paddle.tanh(lstm_max)
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(x=cost)
return avg_cost
class TestIrMemOptRNN(TestIrMemOptBase):
def setUp(self):
self.network = lstm_net
if __name__ == "__main__":
unittest.main()
...@@ -2593,20 +2593,6 @@ class TestBook(LayerTest):
out = paddle.nn.functional.square_error_cost(input=x, label=y)
return out
def test_dynamic_lstmp(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
hidden_dim, proj_dim = 16, 8
seq_data = layers.data(
name='seq_data', shape=[10, 10], dtype='float32', lod_level=1
)
fc_out = layers.fc(input=seq_data, size=4 * hidden_dim)
self.assertIsNotNone(
layers.dynamic_lstmp(
input=fc_out, size=4 * hidden_dim, proj_size=proj_dim
)
)
def test_lod_reset(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
...
...@@ -20,9 +20,7 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
random.seed(2)
np.set_printoptions(threshold=np.inf)
...@@ -539,90 +537,5 @@ class TestCUDNNLstmOp(OpTest):
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestCUDNNlstmAPI(unittest.TestCase):
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
dtype = 'float32' if core.is_compiled_with_rocm() else 'float64'
input = fluid.data(
name='input', shape=[seq_len, batch_size, hidden_size], dtype=dtype
)
init_h = layers.fill_constant(
[num_layers, batch_size, hidden_size], dtype, 0.0
)
init_c = layers.fill_constant(
[num_layers, batch_size, hidden_size], dtype, 0.0
)
rnn_out, last_h, last_c = layers.lstm(
input,
init_h,
init_c,
seq_len,
hidden_size,
num_layers,
dropout_prob,
False,
)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size, hidden_size)
).astype("float64")
out = exe.run(
fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'],
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestCUDNNlstmAPI(unittest.TestCase): # noqa: F811
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 2
dtype = 'float32' if core.is_compiled_with_rocm() else 'float64'
input = fluid.data(
name='input', shape=[seq_len, batch_size, hidden_size], dtype=dtype
)
init_h = layers.fill_constant(
[num_layers, batch_size, hidden_size], dtype, 0.0
)
init_c = layers.fill_constant(
[num_layers, batch_size, hidden_size], dtype, 0.0
)
rnn_out, last_h, last_c = layers.lstm(
input,
init_h,
init_c,
seq_len,
hidden_size,
num_layers,
dropout_prob,
False,
True,
)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size, hidden_size)
).astype(dtype)
out = exe.run(
fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'],
)
if __name__ == '__main__':
unittest.main()
...@@ -17,11 +17,6 @@ import unittest
import numpy as np
from op_test import OpTest
from paddle import fluid
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.layers import fill_constant
from paddle.fluid.layers import lstm as LSTM
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
...@@ -132,130 +127,6 @@ def lstm(
return hidden, cell
class LstmUnitTestError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
batch_size = 20
seq_len = 100
dropout_prob = 0.2
hidden_size = 150
num_layers = 1
input = fluid.data(
name='input',
shape=[batch_size, seq_len, hidden_size],
dtype='float32',
)
pre_hidden = fill_constant(
[num_layers, batch_size, hidden_size], 'float32', 0.0
)
pre_cell = fill_constant(
[num_layers, batch_size, hidden_size], 'float32', 0.0
)
np_input = np.random.uniform(
-0.1, 0.1, (batch_size, seq_len, hidden_size)
).astype('float64')
np_pre_hidden = np.random.uniform(
-0.1, 0.1, (num_layers, batch_size, hidden_size)
).astype('float64')
np_pre_cell = np.random.uniform(
-0.1, 0.1, (num_layers, batch_size, hidden_size)
).astype('float64')
def test_input_Variable():
LSTM(
np_input,
pre_hidden,
pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_input_Variable)
def test_pre_hidden_Variable():
LSTM(
np_input,
np_pre_hidden,
pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_pre_hidden_Variable)
def test_pre_cell_Variable():
LSTM(
np_input,
pre_hidden,
np_pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_pre_cell_Variable)
def test_input_type():
error_input = fluid.data(
name='error_input',
shape=[None, hidden_size * 3],
dtype='int32',
)
LSTM(
error_input,
pre_hidden,
pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_input_type)
def test_pre_hidden_type():
error_pre_hidden = fluid.data(
name='error_pre_hidden',
shape=[None, hidden_size],
dtype='int32',
)
LSTM(
input,
error_pre_hidden,
pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_pre_hidden_type)
def test_pre_cell_type():
error_pre_cell = fluid.data(
name='error_pre_cell',
shape=[None, hidden_size],
dtype='int32',
)
LSTM(
input,
pre_hidden,
error_pre_cell,
seq_len,
hidden_size,
num_layers,
dropout_prob=dropout_prob,
)
self.assertRaises(TypeError, test_pre_cell_type)
class TestLstmOp(OpTest):
def set_is_test(self):
self.is_test = False
...@@ -374,47 +245,6 @@ class TestLstmOpInference(TestLstmOp):
pass
class TestLstmOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstm(
input=input_data, size=2048, use_peepholes=False
)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32"
)
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstm(
input=in_data, size=2048, use_peepholes=False, h_0=h, c_0=c
)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32"
)
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstm(
input=in_data_,
size=2048,
use_peepholes=False,
h_0=h_,
c_0=c_,
)
self.assertRaises(TypeError, test_c_0)
# class TestLstmOpHasInitial(TestLstmOp):
# def set_argument(self):
# self.lod = [[2, 3, 2]]
...
...@@ -17,9 +17,6 @@ import unittest
import numpy as np
import test_lstm_op as LstmTest
from paddle import fluid
from paddle.fluid import Program, program_guard
ACTIVATION = {
'identity': LstmTest.identity,
'sigmoid': LstmTest.sigmoid,
...@@ -378,64 +375,5 @@ class TestLstmpOpLen0Case2(TestLstmpOp):
self.lod = [[2, 0, 3]]
class TestLstmpOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstmp(
input=input_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32"
)
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstmp(
input=in_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h,
c_0=c,
)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32"
)
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstmp(
input=in_data_,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h_,
c_0=c_,
)
self.assertRaises(TypeError, test_c_0)
if __name__ == '__main__':
unittest.main()
...@@ -17,7 +17,6 @@ import unittest
import numpy as np
import seresnext_net
from fake_reader import fake_imdb_reader
from simple_nets import fc_with_batchnorm, init_data, simple_fc_net
from test_parallel_executor_transformer import (
DeviceType,
...@@ -30,37 +29,6 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
def lstm_net(use_feed):
dict_dim = 5147
emb_dim = 128
hid_dim = 128
hid_dim2 = 96
class_dim = 2
emb_lr = 30.0
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1
)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr),
)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False
)
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = paddle.tanh(lstm_max)
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(x=cost)
return avg_cost
def simple_fc_net_with_accuracy(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -268,29 +236,6 @@ class TestProgramPruneBackward(unittest.TestCase):
method=transformer, feed_dict=feed_dict, optimizer=optimizer
)
def test_lstm(self):
def optimizer():
optimizer = fluid.optimizer.Adagrad(
learning_rate=0.001,
regularization=fluid.regularizer.L2Decay(1e-4),
)
return optimizer
with self.program_scope_guard():
word_dict_size = 5147
reader = fake_imdb_reader(word_dict_size, 1)
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1
)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
feeder = fluid.DataFeeder(
feed_list=[data, label], place=core.CPUPlace()
)
feed_data = feeder.feed(reader())
self.check_prune_correctness(
method=lstm_net, feed_dict=feed_data, optimizer=optimizer
)
def test_cond(self):
def optimizer():
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
...
...@@ -91,7 +91,6 @@ HIGH_PARALLEL_JOB_NEW = [
'test_seqpool_concat_fuse_pass',
'test_analyzer_save_model',
'test_exception',
'test_fc_lstm_fuse_pass',
'test_similarity_focus_op',
'test_conv_batch_norm_mkldnn_fuse_pass',
'test_sequence_last_step',
...@@ -457,7 +456,6 @@ HIGH_PARALLEL_JOB_NEW = [
'test_spawn_and_init_parallel_env',
'test_fleet_gradient_scale',
'unroll_array_ops_test',
'test_fc_gru_fuse_pass',
'op_version_registry_test',
'test_cudnn_placement_pass',
'cipher_utils_test',
...@@ -1188,7 +1186,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_sigmoid_focal_loss',
'test_manual_seed',
'test_lrn_op',
'test_ir_memory_optimize_nlp',
'test_dataset_dataloader',
'test_complex_variable',
'test_lite_engine',
...@@ -1199,7 +1196,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_elementwise_sub_op',
'test_compare_op',
'test_simnet',
'test_label_semantic_roles',
'test_normal',
'test_tensor_scalar_type_promotion_static',
'test_trt_group_norm_op',
...@@ -1249,7 +1245,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_input_spec',
'test_adam_op',
'test_elementwise_floordiv_op',
'test_eager_deletion_gru_net',
'test_diagonal_op',
'test_imperative_static_runner_mnist',
'test_nearest_interp_op',
...@@ -1468,7 +1463,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_nearest_interp_v2_op',
'test_sequence_slice_op',
'test_program_translator',
'test_eager_deletion_lstm_net',
'malloc_test',
'test_size_op',
'test_analysis_predictor',
...@@ -1906,8 +1900,6 @@ CPU_PARALLEL_JOB = [
'test_fetch_handler',
'test_feed_fetch_method',
'test_fc_mkldnn_op',
'test_fc_lstm_fuse_pass',
'test_fc_gru_fuse_pass',
'test_fc_elementwise_layernorm_fuse_pass_cc',
'test_fc_bf16_mkldnn_op',
'test_executor_feed_non_tensor',
...
...@@ -162,8 +162,6 @@ STATIC_MODE_TESTING_LIST = [
'test_dynrnn_static_input',
'test_eager_deletion_conditional_block',
'test_eager_deletion_delete_vars',
'test_eager_deletion_gru_net',
'test_eager_deletion_lstm_net',
'test_eager_deletion_padding_rnn',
'test_eager_deletion_recurrent_op',
'test_eager_deletion_while_op',
...@@ -586,8 +584,6 @@ STATIC_MODE_TESTING_LIST = [
'test_conv_elementwise_add_act_fuse_pass',
'test_conv_elementwise_add_fuse_pass',
'test_fc_fuse_pass',
'test_fc_gru_fuse_pass',
'test_fc_lstm_fuse_pass',
'test_repeated_fc_relu_fuse_pass',
'test_seqconv_eltadd_relu_fuse_pass',
'test_squared_mat_sub_fuse_pass',
...@@ -683,7 +679,6 @@ STATIC_MODE_TESTING_LIST = [
'test_fleet_rolemaker_new',
'test_fused_fc_elementwise_layernorm_op',
'test_fusion_transpose_flatten_concat_op',
'test_ir_memory_optimize_nlp',
'test_nvprof',
'test_pipeline',
'test_weight_decay',
...