未验证 提交 df2eee71 编写于 作者: H Hongyu Liu 提交者: GitHub

Sequence mask support tensor (#18249)

* sequnce mask support max length tensor input; test=develop

* add rnn_impl.py; test=develop

* add basic gru lstm unittest; test=develop

* fix api spec; test=develop

* fix sequence_mask op bug;
test=develop
test=document_preview

* change +-*x to elmentwise_op; test=develop

* add mkl flag; test=develop

* fix rnn impl bug; test=develop

* update api spec; test=develop

* fix doc bug; test=develop

* fix lstm bugs; test=develop
上级 9cb799be
......@@ -430,6 +430,38 @@ paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_p
paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4'))
paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'init_loss_scaling', 'incr_every_n_steps', 'decr_every_n_nan_or_inf', 'incr_ratio', 'decr_ratio', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(1.0, 1000, 2, 2.0, 0.8, False)), ('document', 'bdb8f9dbb0d94b3957272c53eeee9818'))
paddle.fluid.contrib.fused_elemwise_activation (ArgSpec(args=['x', 'y', 'functor_list', 'axis', 'scale', 'save_intermediate_out'], varargs=None, keywords=None, defaults=(-1, 0.0, True)), ('document', '1c4b247a2858cea8d9d8750693688270'))
paddle.fluid.contrib.BasicGRUUnit.__init__ (ArgSpec(args=['self', 'name_scope', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'dtype'], varargs=None, keywords=None, defaults=(None, None, None, None, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
paddle.fluid.contrib.BasicGRUUnit.add_sublayer (ArgSpec(args=['self', 'name', 'sublayer'], varargs=None, keywords=None, defaults=None), ('document', '839ff3c0534677ba6ad8735c3fd4e995'))
paddle.fluid.contrib.BasicGRUUnit.backward (ArgSpec(args=['self'], varargs='inputs', keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.clear_gradients (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.create_parameter (ArgSpec(args=['self', 'attr', 'shape', 'dtype', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'a6420ca1455366eaaf972191612de0b6'))
paddle.fluid.contrib.BasicGRUUnit.create_variable (ArgSpec(args=['self', 'name', 'persistable', 'dtype', 'type'], varargs=None, keywords=None, defaults=(None, None, None, VarType.LOD_TENSOR)), ('document', '171cccfceba636d5bbf7bbae672945d8'))
paddle.fluid.contrib.BasicGRUUnit.eval (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.forward (ArgSpec(args=['self', 'input', 'pre_hidden'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.full_name (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '23ce4f961f48ed0f79cadf93a3938ed2'))
paddle.fluid.contrib.BasicGRUUnit.load_dict (ArgSpec(args=['self', 'stat_dict', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.parameters (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '5aec25a854eb57abc798dccccbb507d5'))
paddle.fluid.contrib.BasicGRUUnit.state_dict (ArgSpec(args=['self', 'destination', 'include_sublayers'], varargs=None, keywords=None, defaults=(None, True)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.sublayers (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '00a881005ecbc96578faf94513bf0d62'))
paddle.fluid.contrib.BasicGRUUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.basic_gru (ArgSpec(args=['input', 'init_hidden', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 'float32', 'basic_gru')), ('document', '0afcbe4fbe1b8c35eda58b4efe48f9fd'))
paddle.fluid.contrib.BasicLSTMUnit.__init__ (ArgSpec(args=['self', 'name_scope', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype'], varargs=None, keywords=None, defaults=(None, None, None, None, 1.0, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
paddle.fluid.contrib.BasicLSTMUnit.add_sublayer (ArgSpec(args=['self', 'name', 'sublayer'], varargs=None, keywords=None, defaults=None), ('document', '839ff3c0534677ba6ad8735c3fd4e995'))
paddle.fluid.contrib.BasicLSTMUnit.backward (ArgSpec(args=['self'], varargs='inputs', keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.clear_gradients (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.create_parameter (ArgSpec(args=['self', 'attr', 'shape', 'dtype', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'a6420ca1455366eaaf972191612de0b6'))
paddle.fluid.contrib.BasicLSTMUnit.create_variable (ArgSpec(args=['self', 'name', 'persistable', 'dtype', 'type'], varargs=None, keywords=None, defaults=(None, None, None, VarType.LOD_TENSOR)), ('document', '171cccfceba636d5bbf7bbae672945d8'))
paddle.fluid.contrib.BasicLSTMUnit.eval (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.forward (ArgSpec(args=['self', 'input', 'pre_hidden', 'pre_cell'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.full_name (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '23ce4f961f48ed0f79cadf93a3938ed2'))
paddle.fluid.contrib.BasicLSTMUnit.load_dict (ArgSpec(args=['self', 'stat_dict', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.parameters (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '5aec25a854eb57abc798dccccbb507d5'))
paddle.fluid.contrib.BasicLSTMUnit.state_dict (ArgSpec(args=['self', 'destination', 'include_sublayers'], varargs=None, keywords=None, defaults=(None, True)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicLSTMUnit.sublayers (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '00a881005ecbc96578faf94513bf0d62'))
paddle.fluid.contrib.BasicLSTMUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.basic_lstm (ArgSpec(args=['input', 'init_hidden', 'init_cell', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 1.0, 'float32', 'basic_lstm')), ('document', 'fe4d0c3c55a162b8cfe10b05fabb7ce4'))
paddle.fluid.dygraph.Layer.__init__ (ArgSpec(args=['self', 'name_scope', 'dtype'], varargs=None, keywords=None, defaults=(VarType.FP32,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.dygraph.Layer.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
paddle.fluid.dygraph.Layer.add_sublayer (ArgSpec(args=['self', 'name', 'sublayer'], varargs=None, keywords=None, defaults=None), ('document', '839ff3c0534677ba6ad8735c3fd4e995'))
......
......@@ -35,7 +35,9 @@ template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>;
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
......
......@@ -13,6 +13,80 @@
// limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
#include <string>
namespace paddle {
namespace operators {
class SequenceMaskOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
int maxlen = ctx->Attrs().Get<int>("maxlen");
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
if (ctx->HasInputs("MaxLenTensor")) {
dim.push_back(-1);
} else {
dim.push_back(maxlen > 0 ? maxlen : -1);
}
ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "depth_tensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor of sequence_mask op.");
AddOutput("Y", "The output mask of sequence_mask op.");
AddInput("MaxLenTensor",
"Max length tensor"
"have higher priority than maxlen attribute")
.AsDispensable();
AddAttr<int>("maxlen",
"The maximum length of the sequence. If maxlen < 0, maxlen "
"= max(Input(X)).")
.SetDefault(-1)
.AddCustomChecker([](const int& v) {
PADDLE_ENFORCE(v < 0 || v >= 1,
"Attr(maxlen) must be less than 0 or larger than 1");
});
AddAttr<int>("out_dtype", "Output data type");
AddComment(R"DOC(
SequenceMask Operator
This operator outputs a Mask according to Input(X) and Attr(maxlen).
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
If maxlen < 0, maxlen = max(X)
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp,
paddle::operators::SequenceMaskOpMaker,
......
......@@ -28,48 +28,8 @@
namespace paddle {
namespace operators {
class SequenceMaskOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
int maxlen = ctx->Attrs().Get<int>("maxlen");
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
dim.push_back(maxlen > 0 ? maxlen : -1);
ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
};
class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor of sequence_mask op.");
AddOutput("Y", "The output mask of sequence_mask op.");
AddAttr<int>("maxlen",
"The maximum length of the sequence. If maxlen < 0, maxlen "
"= max(Input(X)).")
.SetDefault(-1)
.AddCustomChecker([](const int &v) {
PADDLE_ENFORCE(v < 0 || v >= 1,
"Attr(maxlen) must be less than 0 or larger than 1");
});
AddAttr<int>("out_dtype", "Output data type");
AddComment(R"DOC(
SequenceMask Operator
This operator outputs a Mask according to Input(X) and Attr(maxlen).
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
If maxlen < 0, maxlen = max(X)
)DOC");
}
};
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
......@@ -90,8 +50,6 @@ struct SequenceMaskForRangeFunctor {
template <typename DeviceContext, typename Tx>
struct SequenceMaskFunctor {
using Tensor = framework::LoDTensor;
SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y,
int limits, int maxlen)
: ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}
......@@ -119,7 +77,25 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto maxlen = ctx.Attr<int>("maxlen");
int maxlen = ctx.Attr<int>("maxlen");
if (ctx.HasInput("MaxLenTensor")) {
auto max_len_tensor = ctx.Input<Tensor>("MaxLenTensor");
PADDLE_ENFORCE(max_len_tensor != NULL, "MaxLenTensor is NULL");
if (platform::is_gpu_place(max_len_tensor->place())) {
framework::Tensor temp;
TensorCopySync(*max_len_tensor, platform::CPUPlace(), &temp);
maxlen = *temp.data<int32_t>();
} else {
maxlen = *max_len_tensor->data<int32_t>();
}
auto y_dim = framework::vectorize2int(x->dims());
y_dim.push_back(maxlen);
y->Resize(framework::make_ddim(y_dim));
PADDLE_ENFORCE_GT(maxlen, 0,
"MaxLenTensor value should be greater than 0");
}
auto *x_data = x->data<Tx>();
auto x_numel = x->numel();
......
......@@ -55,4 +55,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int32_t>,
ops::ShapeKernel<float>, ops::ShapeKernel<double>);
ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
ops::ShapeKernel<double>);
......@@ -16,5 +16,6 @@ limitations under the License. */
REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel<int>,
paddle::operators::ShapeKernel<int32_t>,
paddle::operators::ShapeKernel<int64_t>,
paddle::operators::ShapeKernel<float>,
paddle::operators::ShapeKernel<double>);
......@@ -92,7 +92,7 @@ class SliceOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace());
ctx.Input<Tensor>("Input")->place());
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include <memory>
#include <string>
#include <vector>
......@@ -289,8 +290,12 @@ REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL(
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -30,11 +30,15 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext,
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
from . import nn
from .nn import *
from .rnn_impl import *
__all__ = []
__all__ += nn.__all__
__all__ += rnn_impl.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
from paddle.fluid.layers.control_flow import StaticRNN
__all__ = ['BasicGRUUnit', 'basic_gru', 'BasicLSTMUnit', 'basic_lstm']
class BasicGRUUnit(Layer):
"""
****
BasicGRUUnit class, using basic operators to build GRU
The algorithm can be described as the equations below.
.. math::
u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
Args:
name_scope(string) : The name scope used to identify parameters and biases
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of GRU unit.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
dtype(string): data type used in this unit
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import BasicGRUUnit
input_size = 128
hidden_size = 256
input = layers.data( name = "input", shape = [-1, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
gru_unit = BasicGRUUnit( "gru_unit", hidden_size )
new_hidden = gru_unit( input, pre_hidden )
"""
def __init__(self,
name_scope,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32'):
super(BasicGRUUnit, self).__init__(name_scope, dtype)
self._name = name_scope
self._hiden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._dtype = dtype
def _build_once(self, input, pre_hidden):
self._input_size = input.shape[-1]
assert (self._input_size > 0)
self._gate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, 2 * self._hiden_size],
dtype=self._dtype)
self._candidate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, self._hiden_size],
dtype=self._dtype)
self._gate_bias = self.create_parameter(
self._bias_attr,
shape=[2 * self._hiden_size],
dtype=self._dtype,
is_bias=True)
self._candidate_bias = self.create_parameter(
self._bias_attr,
shape=[self._hiden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, pre_hidden):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = layers.elementwise_add(gate_input, self._gate_bias)
gate_input = self._gate_activation(gate_input)
r, u = layers.split(gate_input, num_or_sections=2, dim=1)
r_hidden = r * pre_hidden
candidate = layers.matmul(
layers.concat([input, pre_hidden], 1), self._candidate_weight)
candidate = layers.elementwise_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
def basic_gru(input,
init_hidden,
hidden_size,
num_layers=1,
sequence_length=None,
dropout_prob=0.0,
bidirectional=False,
batch_first=True,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32',
name='basic_gru'):
"""
GRU implementation using basic operator, supports multiple layers and bidirection gru.
.. math::
u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
Args:
input (Variable): GRU input tensor,
if batch_first = False, shape should be ( seq_len x batch_size x input_size )
if batch_first = True, shape should be ( batch_size x seq_len x hidden_size )
init_hidden(Variable|None): The initial hidden state of the GRU
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to tensor with ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
hidden_size (int): Hidden size of the GRU
num_layers (int): The total number of layers of the GRU
sequence_length (Variabe|None): A Tensor (shape [batch_size]) stores each real length of each instance,
This tensor will be convert to a mask to mask the padding ids
If it's None means NO padding ids
dropout_prob(float|0.0): Dropout prob, dropout ONLY works after rnn output of earch layers,
NOT between time steps
bidirectional (bool|False): If it is bidirectional
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of GRU unit.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
dtype(string): data type used in this unit
name(string): name used to identify parameters and biases
Returns:
rnn_out(Tensor),last_hidden(Tensor)
- rnn_out is result of GRU hidden, with shape (seq_len x batch_size x hidden_size) \
if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2)
- last_hidden is the hidden state of the last step of GRU \
shape is ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size),
can be reshaped to a tensor with shape( num_layers x 2 x batch_size x hidden_size)
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import basic_gru
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 2
dropout = 0.5
bidirectional = True
batch_first = False
input = layers.data( name = "input", shape = [-1, batch_size, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data( name="sequence_length", shape=[-1], dtype='int32')
rnn_out, last_hidden = basic_gru( input, pre_hidden, hidden_size, num_layers = num_layers, \
sequence_length = sequence_length, dropout_prob=dropout, bidirectional = bidirectional, \
batch_first = batch_first)
"""
fw_unit_list = []
for i in range(num_layers):
new_name = name + "_layers_" + str(i)
fw_unit_list.append(
BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
gate_activation, activation, dtype))
if bidirectional:
bw_unit_list = []
for i in range(num_layers):
new_name = name + "_reverse_layers_" + str(i)
bw_unit_list.append(
BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
gate_activation, activation, dtype))
if batch_first:
input = layers.transpose(input, [1, 0, 2])
mask = None
if sequence_length:
max_seq_len = layers.shape(input)[0]
mask = layers.sequence_mask(
sequence_length, maxlen=max_seq_len, dtype='float32')
mask = layers.transpose(mask, [1, 0])
direc_num = 1
if bidirectional:
direc_num = 2
if init_hidden:
init_hidden = layers.reshape(
init_hidden, shape=[num_layers, direc_num, -1, hidden_size])
def get_single_direction_output(rnn_input,
unit_list,
mask=None,
direc_index=0):
rnn = StaticRNN()
with rnn.step():
step_input = rnn.step_input(rnn_input)
if mask:
step_mask = rnn.step_input(mask)
for i in range(num_layers):
if init_hidden:
pre_hidden = rnn.memory(init=init_hidden[i, direc_index])
else:
pre_hidden = rnn.memory(
batch_ref=rnn_input,
shape=[-1, hidden_size],
ref_batch_dim_idx=1)
new_hidden = unit_list[i](step_input, pre_hidden)
if mask:
new_hidden = layers.elementwise_mul(
new_hidden, step_mask, axis=0) - layers.elementwise_mul(
pre_hidden, (step_mask - 1), axis=0)
rnn.update_memory(pre_hidden, new_hidden)
rnn.step_output(new_hidden)
step_input = new_hidden
if dropout_prob != None and dropout_prob > 0.0:
step_input = layers.dropout(
step_input,
dropout_prob=dropout_prob, )
rnn.step_output(step_input)
rnn_out = rnn()
last_hidden_array = []
rnn_output = rnn_out[-1]
for i in range(num_layers):
last_hidden = rnn_out[i]
last_hidden = last_hidden[-1]
last_hidden_array.append(last_hidden)
last_hidden_output = layers.concat(last_hidden_array, axis=0)
last_hidden_output = layers.reshape(
last_hidden_output, shape=[num_layers, -1, hidden_size])
return rnn_output, last_hidden_output
# seq_len, batch_size, hidden_size
fw_rnn_out, fw_last_hidden = get_single_direction_output(
input, fw_unit_list, mask, direc_index=0)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
last_hidden = layers.reshape(
last_hidden, shape=[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
if batch_first:
rnn_out = fluid.layser.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
def basic_lstm(input,
init_hidden,
init_cell,
hidden_size,
num_layers=1,
sequence_length=None,
dropout_prob=0.0,
bidirectional=False,
batch_first=True,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32',
name='basic_lstm'):
"""
LSTM implementation using basic operators, supports multiple layers and bidirection LSTM.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
Args:
input (Variable): lstm input tensor,
if batch_first = False, shape should be ( seq_len x batch_size x input_size )
if batch_first = True, shape should be ( batch_size x seq_len x hidden_size )
init_hidden(Variable|None): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to a tensor with shape ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
init_cell(Variable|None): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to a tensor with shape ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
hidden_size (int): Hidden size of the LSTM
num_layers (int): The total number of layers of the LSTM
sequence_length (Variabe|None): A tensor (shape [batch_size]) stores each real length of each instance,
This tensor will be convert to a mask to mask the padding ids
If it's None means NO padding ids
dropout_prob(float|0.0): Dropout prob, dropout ONLY work after rnn output of earch layers,
NOT between time steps
bidirectional (bool|False): If it is bidirectional
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
forget_bias (float|1.0) : Forget bias used to compute the forget gate
dtype(string): Data type used in this unit
name(string): Name used to identify parameters and biases
Returns:
rnn_out(Tensor), last_hidden(Tensor), last_cell(Tensor)
- rnn_out is the result of LSTM hidden, shape is (seq_len x batch_size x hidden_size) \
if is_bidirec set to True, it's shape will be ( seq_len x batch_sze x hidden_size*2)
- last_hidden is the hidden state of the last step of LSTM \
with shape ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, it's shape will be ( num_layers*2 x batch_size x hidden_size),
and can be reshaped to a tensor ( num_layers x 2 x batch_size x hidden_size) to use.
- last_cell is the hidden state of the last step of LSTM \
with shape ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, it's shape will be ( num_layers*2 x batch_size x hidden_size),
and can be reshaped to a tensor ( num_layers x 2 x batch_size x hidden_size) to use.
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import basic_lstm
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 2
dropout = 0.5
bidirectional = True
batch_first = False
input = layers.data( name = "input", shape = [-1, batch_size, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
pre_cell = layers.data( name = "pre_cell", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data( name="sequence_length", shape=[-1], dtype='int32')
rnn_out, last_hidden, last_cell = basic_lstm( input, pre_hidden, pre_cell, \
hidden_size, num_layers = num_layers, \
sequence_length = sequence_length, dropout_prob=dropout, bidirectional = bidirectional, \
batch_first = batch_first)
"""
fw_unit_list = []
for i in range(num_layers):
new_name = name + "_layers_" + str(i)
fw_unit_list.append(
BasicLSTMUnit(
new_name,
hidden_size,
param_attr=param_attr,
bias_attr=bias_attr,
gate_activation=gate_activation,
activation=activation,
forget_bias=forget_bias,
dtype=dtype))
if bidirectional:
bw_unit_list = []
for i in range(num_layers):
new_name = name + "_reverse_layers_" + str(i)
bw_unit_list.append(
BasicLSTMUnit(
new_name,
hidden_size,
param_attr=param_attr,
bias_attr=bias_attr,
gate_activation=gate_activation,
activation=activation,
forget_bias=forget_bias,
dtype=dtype))
if batch_first:
input = layers.transpose(input, [1, 0, 2])
mask = None
if sequence_length:
max_seq_len = layers.shape(input)[0]
mask = layers.sequence_mask(
sequence_length, maxlen=max_seq_len, dtype='float32')
mask = layers.transpose(mask, [1, 0])
direc_num = 1
if bidirectional:
direc_num = 2
# convert to [num_layers, 2, batch_size, hidden_size]
if init_hidden:
init_hidden = layers.reshape(
init_hidden, shape=[num_layers, direc_num, -1, hidden_size])
init_cell = layers.reshape(
init_cell, shape=[num_layers, direc_num, -1, hidden_size])
# forward direction
def get_single_direction_output(rnn_input,
unit_list,
mask=None,
direc_index=0):
rnn = StaticRNN()
with rnn.step():
step_input = rnn.step_input(rnn_input)
if mask:
step_mask = rnn.step_input(mask)
for i in range(num_layers):
if init_hidden:
pre_hidden = rnn.memory(init=init_hidden[i, direc_index])
pre_cell = rnn.memory(init=init_cell[i, direc_index])
else:
pre_hidden = rnn.memory(
batch_ref=rnn_input, shape=[-1, hidden_size])
pre_cell = rnn.memory(
batch_ref=rnn_input, shape=[-1, hidden_size])
new_hidden, new_cell = unit_list[i](step_input, pre_hidden,
pre_cell)
if mask:
new_hidden = layers.elementwise_mul(
new_hidden, step_mask, axis=0) - layers.elementwise_mul(
pre_hidden, (step_mask - 1), axis=0)
new_cell = layers.elementwise_mul(
new_cell, step_mask, axis=0) - layers.elementwise_mul(
pre_cell, (step_mask - 1), axis=0)
rnn.update_memory(pre_hidden, new_hidden)
rnn.update_memory(pre_cell, new_cell)
rnn.step_output(new_hidden)
rnn.step_output(new_cell)
step_input = new_hidden
if dropout_prob != None and dropout_prob > 0.0:
step_input = layers.dropout(
step_input,
dropout_prob=dropout_prob,
dropout_implementation='upscale_in_train')
rnn.step_output(step_input)
rnn_out = rnn()
last_hidden_array = []
last_cell_array = []
rnn_output = rnn_out[-1]
for i in range(num_layers):
last_hidden = rnn_out[i * 2]
last_hidden = last_hidden[-1]
last_hidden_array.append(last_hidden)
last_cell = rnn_out[i * 2 + 1]
last_cell = last_cell[-1]
last_cell_array.append(last_cell)
last_hidden_output = layers.concat(last_hidden_array, axis=0)
last_hidden_output = layers.reshape(
last_hidden_output, shape=[num_layers, -1, hidden_size])
last_cell_output = layers.concat(last_cell_array, axis=0)
last_cell_output = layers.reshape(
last_cell_output, shape=[num_layers, -1, hidden_size])
return rnn_output, last_hidden_output, last_cell_output
# seq_len, batch_size, hidden_size
fw_rnn_out, fw_last_hidden, fw_last_cell = get_single_direction_output(
input, fw_unit_list, mask, direc_index=0)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden, bw_last_cell = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
last_hidden = layers.reshape(
last_hidden, shape=[num_layers * direc_num, -1, hidden_size])
last_cell = layers.concat([fw_last_cell, bw_last_cell], axis=1)
last_cell = layers.reshape(
last_cell, shape=[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
last_cell = fw_last_cell
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
class BasicLSTMUnit(Layer):
"""
****
BasicLSTMUnit class, Using basic operator to build LSTM
The algorithm can be described as the code below.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The :math:`\odot` is the element-wise product of the vectors.
- :math:`tanh` is the activation functions.
- :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Args:
name_scope(string) : The name scope used to identify parameter and bias name
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized as zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cells (actNode).
Default: 'fluid.layers.tanh'
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import BasicLSTMUnit
input_size = 128
hidden_size = 256
input = layers.data( name = "input", shape = [-1, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
pre_cell = layers.data( name = "pre_cell", shape=[-1, hidden_size], dtype='float32')
lstm_unit = BasicLSTMUnit( "gru_unit", hidden_size)
new_hidden, new_cell = lstm_unit( input, pre_hidden, pre_cell )
"""
def __init__(self,
name_scope,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(name_scope, dtype)
self._name = name_scope
self._hiden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant(
[1], dtype=dtype, value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
def _build_once(self, input, pre_hidden, pre_cell):
self._input_size = input.shape[-1]
assert (self._input_size > 0)
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, 4 * self._hiden_size],
dtype=self._dtype)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hiden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, pre_hidden, pre_cell):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
pre_cell,
layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
return new_hidden, new_cell
......@@ -449,7 +449,7 @@ class StaticRNN(object):
raise TypeError("step input takes a Variable")
if self.seq_len is None:
self.seq_len = x.shape[0]
elif self.seq_len != x.shape[0]:
elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(
......
......@@ -9244,14 +9244,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
else:
out = helper.create_variable_for_type_inference(dtype=dtype, name=name)
inputs = {'X': [x]}
attrs = {'out_dtype': out.dtype}
if maxlen is not None:
if isinstance(maxlen, Variable):
inputs['MaxLenTensor'] = maxlen
else:
attrs['maxlen'] = maxlen
helper.append_op(
type='sequence_mask',
inputs={'X': [x]},
outputs={'Y': out},
attrs={
'maxlen': maxlen if maxlen is not None else -1,
'out_dtype': out.dtype
})
type='sequence_mask', inputs=inputs, outputs={'Y': out}, attrs=attrs)
out.stop_gradient = True
return out
......
......@@ -118,6 +118,10 @@ list(REMOVE_ITEM TEST_OPS test_layers)
list(REMOVE_ITEM TEST_OPS test_imperative_ocr_attention_model)
list(REMOVE_ITEM TEST_OPS test_async_ssa_graph_executor_mnist)
list(REMOVE_ITEM TEST_OPS test_install_check)
list(REMOVE_ITEM TEST_OPS test_basic_gru_api)
list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
......@@ -161,6 +165,10 @@ py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_
py_test_modules(test_warpctc_op MODULES test_warpctc_op)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS})
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS})
py_test_modules(test_basic_gru_api MODULES test_basic_gru_api ENVS MKL_CBWR=COMPATIBLE)
py_test_modules(test_basic_gru_unit_op MODULES test_basic_gru_unit_op ENVS MKL_CBWR=COMPATIBLE)
py_test_modules(test_basic_lstm_api MODULES test_basic_lstm_api ENVS MKL_CBWR=COMPATIBLE)
py_test_modules(test_basic_lstm_unit_op MODULES test_basic_lstm_unit_op ENVS MKL_CBWR=COMPATIBLE)
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
FLAGS_cudnn_deterministic=1 SERIAL)
set_tests_properties(test_imperative_resnet PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid.contrib.layers import basic_gru
from paddle.fluid.executor import Executor
from paddle.fluid import framework
import numpy as np
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def gru_np(input,
init_h,
hidden_size,
gate_weight,
gate_bias,
candidate_weight,
candidate_bias,
num_layers=1,
batch_first=False,
is_bidirect=False,
sequence_length=None):
def step(step_in, pre_hidden, gate_w, gate_b, candidate_w, candidate_b):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
gate_input = sigmoid(gate_input)
r, u = np.split(gate_input, indices_or_sections=2, axis=1)
r_hidden = r * pre_hidden
candidate = np.matmul(
np.concatenate([step_in, pre_hidden], 1), candidate_w)
candidate += candidate_b
c = tanh(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
if batch_first:
input = np.tranpose(input, [1, 0, 2])
batch_size = input.shape[1]
mask = None
if sequence_length is not None:
max_seq_len = input.shape[0]
mask = np.zeros([batch_size, max_seq_len])
for i, len in enumerate(sequence_length):
mask[i, :len] = 1.0
mask = np.transpose(mask, [1, 0])
direc_num = 1
if is_bidirect:
direc_num = 2
if init_h:
init_h = np.reshape(
init_h, shape=[num_layers, direc_num, -1, hidden_size])
else:
init_h = np.zeros([num_layers, direc_num, batch_size, hidden_size])
def get_single_direction_output(rnn_input, mask=None, direc_index=0):
seq_len = rnn_input.shape[0]
output = []
# init pre hidden
pre_hidden_array = []
for i in range(num_layers):
pre_hidden_array.append(init_h[i, direc_index])
for i in range(seq_len):
step_input = rnn_input[i]
if mask is not None:
step_mask = mask[i]
step_mask = np.reshape(step_mask, [-1, 1])
for i in range(num_layers):
new_hidden = step(
step_input, pre_hidden_array[i],
gate_weight[direc_index * num_layers + i],
gate_bias[direc_index * num_layers + i],
candidate_weight[direc_index * num_layers + i],
candidate_bias[direc_index * num_layers + i])
if mask is not None:
new_hidden = new_hidden * step_mask + (
1 - step_mask) * pre_hidden_array[i]
pre_hidden_array[i] = new_hidden
step_input = new_hidden
output.append(step_input)
rnn_out = np.concatenate(output, 0)
rnn_out = np.reshape(rnn_out, [seq_len, -1, hidden_size])
last_hidden_out = np.concatenate(pre_hidden_array, 0)
last_hidden_out = np.reshape(last_hidden_out,
[num_layers, -1, hidden_size])
return rnn_out, last_hidden_out
fw_rnn_out, fw_last_hidden = get_single_direction_output(
input, mask, direc_index=0)
if is_bidirect:
bw_input = input[::-1]
bw_mask = None
if mask is not None:
bw_mask = mask[::-1]
bw_rnn_out, bw_last_hidden = get_single_direction_output(
bw_input, bw_mask, direc_index=1)
bw_rnn_out = bw_rnn_out[::-1]
rnn_out = np.concatenate([fw_rnn_out, bw_rnn_out], 2)
last_hidden = np.concatenate([fw_last_hidden, bw_last_hidden], 1)
last_hidden = np.reshape(last_hidden,
[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = np.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
if batch_first:
rnn_out = np.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
class TestBasicGRUApi(unittest.TestCase):
def setUp(self):
self.hidden_size = 10
self.batch_size = 5
self.seq_len = 6
self.num_layers = 2
self.is_bidirect = True
self.batch_first = False
def test_run(self):
x = layers.data(
name='x',
shape=[-1, self.batch_size, self.hidden_size],
dtype='float32')
sequence_length = layers.data(
name="sequence_length", shape=[-1], dtype='float32')
rnn_out, last_hidden = basic_gru( x, None, self.hidden_size, num_layers=self.num_layers, \
batch_first = self.batch_first, bidirectional=self.is_bidirect, sequence_length=sequence_length )
last_hidden.persisbale = True
rnn_out.persisbale = True
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
param_list = fluid.default_main_program().block(0).all_parameters()
# process weight and bias
gate_weight = []
gate_bias = []
candidate_weight = []
candidate_bias = []
for i in range(self.num_layers):
gate_w_name = "basic_gru_layers_" + str(i) + "/BasicGRUUnit_0.w_0"
gate_b_name = "basic_gru_layers_" + str(i) + "/BasicGRUUnit_0.b_0"
candidate_w_name = "basic_gru_layers_" + str(
i) + "/BasicGRUUnit_0.w_1"
candidate_b_name = "basic_gru_layers_" + str(
i) + "/BasicGRUUnit_0.b_1"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name)
.get_tensor())
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(gate_w,
place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name)
.get_tensor())
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(gate_b,
place)
candidate_w = np.array(fluid.global_scope().find_var(
candidate_w_name).get_tensor())
candidate_w = np.random.uniform(
-0.1, 0.1, size=candidate_w.shape).astype('float32')
fluid.global_scope().find_var(candidate_w_name).get_tensor().set(
candidate_w, place)
candidate_b = np.array(fluid.global_scope().find_var(
candidate_b_name).get_tensor())
candidate_b = np.random.uniform(
-0.1, 0.1, size=candidate_b.shape).astype('float32')
fluid.global_scope().find_var(candidate_b_name).get_tensor().set(
candidate_b, place)
gate_weight.append(gate_w)
gate_bias.append(gate_b)
candidate_weight.append(candidate_w)
candidate_bias.append(candidate_b)
if self.is_bidirect:
for i in range(self.num_layers):
gate_w_name = "basic_gru_reverse_layers_" + str(
i) + "/BasicGRUUnit_0.w_0"
gate_b_name = "basic_gru_reverse_layers_" + str(
i) + "/BasicGRUUnit_0.b_0"
candidate_w_name = "basic_gru_reverse_layers_" + str(
i) + "/BasicGRUUnit_0.w_1"
candidate_b_name = "basic_gru_reverse_layers_" + str(
i) + "/BasicGRUUnit_0.b_1"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name)
.get_tensor())
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(
gate_w, place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name)
.get_tensor())
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(
gate_b, place)
candidate_w = np.array(fluid.global_scope().find_var(
candidate_w_name).get_tensor())
candidate_w = np.random.uniform(
-0.1, 0.1, size=candidate_w.shape).astype('float32')
fluid.global_scope().find_var(candidate_w_name).get_tensor(
).set(candidate_w, place)
candidate_b = np.array(fluid.global_scope().find_var(
candidate_b_name).get_tensor())
candidate_b = np.random.uniform(
-0.1, 0.1, size=candidate_b.shape).astype('float32')
fluid.global_scope().find_var(candidate_b_name).get_tensor(
).set(candidate_b, place)
gate_weight.append(gate_w)
gate_bias.append(gate_b)
candidate_weight.append(candidate_w)
candidate_bias.append(candidate_b)
step_input_np = np.random.uniform(-0.1, 0.1, (
self.seq_len, self.batch_size, self.hidden_size)).astype('float32')
sequence_length_np = np.random.randint(
self.seq_len // 2, self.seq_len,
size=(self.batch_size)).astype('int64')
out = exe.run(
feed={'x': step_input_np,
'sequence_length': sequence_length_np},
fetch_list=[rnn_out, last_hidden])
api_rnn_out = out[0]
api_last_hidden = out[1]
np_out = gru_np(
step_input_np,
None,
self.hidden_size,
gate_weight,
gate_bias,
candidate_weight,
candidate_bias,
num_layers=self.num_layers,
batch_first=self.batch_first,
is_bidirect=self.is_bidirect,
sequence_length=sequence_length_np)
self.assertTrue(np.allclose(api_rnn_out, np_out[0], rtol=1e-4, atol=0))
self.assertTrue(
np.allclose(
api_last_hidden, np_out[1], rtol=1e-4, atol=0))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid.contrib.layers import BasicGRUUnit
from paddle.fluid.executor import Executor
from paddle.fluid import framework
import numpy as np
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def step(step_in, pre_hidden, gate_w, gate_b, candidate_w, candidate_b):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
gate_input = sigmoid(gate_input)
r, u = np.split(gate_input, indices_or_sections=2, axis=1)
r_hidden = r * pre_hidden
candidate = np.matmul(np.concatenate([step_in, pre_hidden], 1), candidate_w)
candidate += candidate_b
c = tanh(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
class TestBasicGRUUnit(unittest.TestCase):
def setUp(self):
self.hidden_size = 5
self.batch_size = 5
def test_run(self):
x = layers.data(name='x', shape=[-1, self.hidden_size], dtype='float32')
pre_hidden = layers.data(
name="pre_hidden", shape=[-1, self.hidden_size], dtype='float32')
gru_unit = BasicGRUUnit("gru_unit", self.hidden_size)
new_hidden = gru_unit(x, pre_hidden)
new_hidden.persisbale = True
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
param_list = fluid.default_main_program().block(0).all_parameters()
# process weight and bias
gate_w_name = "gru_unit/BasicGRUUnit_0.w_0"
gate_b_name = "gru_unit/BasicGRUUnit_0.b_0"
candidate_w_name = "gru_unit/BasicGRUUnit_0.w_1"
candidate_b_name = "gru_unit/BasicGRUUnit_0.b_1"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name).get_tensor(
))
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(gate_w,
place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name).get_tensor(
))
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(gate_b,
place)
candidate_w = np.array(fluid.global_scope().find_var(candidate_w_name)
.get_tensor())
candidate_w = np.random.uniform(
-0.1, 0.1, size=candidate_w.shape).astype('float32')
fluid.global_scope().find_var(candidate_w_name).get_tensor().set(
candidate_w, place)
candidate_b = np.array(fluid.global_scope().find_var(candidate_b_name)
.get_tensor())
candidate_b = np.random.uniform(
-0.1, 0.1, size=candidate_b.shape).astype('float32')
fluid.global_scope().find_var(candidate_b_name).get_tensor().set(
candidate_b, place)
step_input_np = np.random.uniform(-0.1, 0.1, (
self.batch_size, self.hidden_size)).astype('float32')
pre_hidden_np = np.random.uniform(-0.1, 0.1, (
self.batch_size, self.hidden_size)).astype('float32')
out = exe.run(feed={'x': step_input_np,
'pre_hidden': pre_hidden_np},
fetch_list=[new_hidden])
api_out = out[0]
np_out = step(step_input_np, pre_hidden_np, gate_w, gate_b, candidate_w,
candidate_b)
self.assertTrue(np.allclose(api_out, np_out, rtol=1e-4, atol=0))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid.contrib.layers import basic_lstm
from paddle.fluid.executor import Executor
from paddle.fluid import framework
import numpy as np
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def lstm_np(input,
init_h,
init_c,
hidden_size,
gate_weight,
gate_bias,
num_layers=1,
batch_first=False,
is_bidirect=False,
sequence_length=None,
forget_bias=1.0):
def step(step_in, pre_hidden, pre_cell, gate_w, gate_b):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
i, j, f, o = np.split(gate_input, indices_or_sections=4, axis=1)
new_cell = pre_cell * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
new_hidden = tanh(new_cell) * sigmoid(o)
return new_hidden, new_cell
if batch_first:
input = np.tranpose(input, [1, 0, 2])
if mask is not None:
mask = np.transpose(mask, [1, 0])
batch_size = input.shape[1]
mask = None
if sequence_length is not None:
max_seq_len = input.shape[0]
mask = np.zeros([batch_size, max_seq_len])
for i, len in enumerate(sequence_length):
mask[i, :len] = 1.0
mask = np.transpose(mask, [1, 0])
direc_num = 1
if is_bidirect:
direc_num = 2
if init_h:
init_h = np.reshape(init_h, [num_layers, direc_num, -1, hidden_size])
init_c = np.reshape(init_c, [num_layers, direc_num, -1, hidden_size])
else:
init_h = np.zeros([num_layers, direc_num, batch_size, hidden_size])
init_c = np.zeros([num_layers, direc_num, batch_size, hidden_size])
def get_single_direction_output(rnn_input, mask=None, direc_index=0):
seq_len = rnn_input.shape[0]
output = []
# init pre hidden
pre_hidden_array = []
pre_cell_array = []
for i in range(num_layers):
pre_hidden_array.append(init_h[i, direc_index])
pre_cell_array.append(init_c[i, direc_index])
for i in range(seq_len):
step_input = rnn_input[i]
if mask is not None:
step_mask = mask[i]
step_mask = np.reshape(step_mask, [-1, 1])
#print("np mask", step_mask.shape )
for i in range(num_layers):
new_hidden, new_cell = step(
step_input, pre_hidden_array[i], pre_cell_array[i],
gate_weight[direc_index * num_layers + i],
gate_bias[direc_index * num_layers + i])
if mask is not None:
new_hidden = np.multiply(
new_hidden, step_mask) - np.multiply(
pre_hidden_array[i], (step_mask - 1.0))
#new_hidden = new_hidden * step_mask - pre_hidden_array[i] * ( step_mask -1 )
#new_cell = new_cell * step_mask - pre_cell_array[i] * (step_mask -1)
new_cell = np.multiply(new_cell, step_mask) - np.multiply(
pre_cell_array[i], (step_mask - 1.0))
pre_hidden_array[i] = new_hidden
pre_cell_array[i] = new_cell
step_input = new_hidden
output.append(step_input)
rnn_out = np.concatenate(output, 0)
rnn_out = np.reshape(rnn_out, [seq_len, -1, hidden_size])
last_hidden_out = np.concatenate(pre_hidden_array, 0)
last_hidden_out = np.reshape(last_hidden_out,
[num_layers, -1, hidden_size])
last_cell_out = np.concatenate(pre_cell_array, 0)
last_cell_out = np.reshape(last_cell_out, [num_layers, -1, hidden_size])
return rnn_out, last_hidden_out, last_cell_out
fw_rnn_out, fw_last_hidden, fw_last_cell = get_single_direction_output(
input, mask, direc_index=0)
if is_bidirect:
bw_input = input[::-1]
bw_mask = None
if mask is not None:
bw_mask = mask[::-1]
bw_rnn_out, bw_last_hidden, bw_last_cell = get_single_direction_output(
bw_input, bw_mask, direc_index=1)
bw_rnn_out = bw_rnn_out[::-1]
rnn_out = np.concatenate([fw_rnn_out, bw_rnn_out], 2)
last_hidden = np.concatenate([fw_last_hidden, bw_last_hidden], 1)
last_hidden = np.reshape(last_hidden,
[num_layers * direc_num, -1, hidden_size])
last_cell = np.concatenate([fw_last_cell, bw_last_cell], 1)
last_cell = np.reshape(last_cell,
[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = np.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
last_cell = fw_last_cell
if batch_first:
rnn_out = np.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
class TestBasicLSTMApi(unittest.TestCase):
def setUp(self):
self.hidden_size = 10
self.batch_size = 5
self.seq_len = 6
self.num_layers = 2
self.is_bidirect = True
self.batch_first = False
self.forget_bias = 1.0
def test_run(self):
x = layers.data(
name='x',
shape=[-1, self.batch_size, self.hidden_size],
dtype='float32')
sequence_length = layers.data(
name="sequence_length", shape=[-1], dtype='float32')
rnn_out, last_hidden, last_cell = basic_lstm( x, None, None, self.hidden_size, num_layers=self.num_layers, \
batch_first = self.batch_first, bidirectional=self.is_bidirect, sequence_length=sequence_length, forget_bias = self.forget_bias )
last_hidden.persisbale = True
rnn_out.persisbale = True
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
param_list = fluid.default_main_program().block(0).all_parameters()
# process weight and bias
gate_weight = []
gate_bias = []
for i in range(self.num_layers):
gate_w_name = "basic_lstm_layers_" + str(i) + "/BasicLSTMUnit_0.w_0"
gate_b_name = "basic_lstm_layers_" + str(i) + "/BasicLSTMUnit_0.b_0"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name)
.get_tensor())
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(gate_w,
place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name)
.get_tensor())
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(gate_b,
place)
gate_weight.append(gate_w)
gate_bias.append(gate_b)
if self.is_bidirect:
for i in range(self.num_layers):
gate_w_name = "basic_lstm_reverse_layers_" + str(
i) + "/BasicLSTMUnit_0.w_0"
gate_b_name = "basic_lstm_reverse_layers_" + str(
i) + "/BasicLSTMUnit_0.b_0"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name)
.get_tensor())
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(
gate_w, place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name)
.get_tensor())
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(
gate_b, place)
gate_weight.append(gate_w)
gate_bias.append(gate_b)
step_input_np = np.random.uniform(-0.1, 0.1, (
self.seq_len, self.batch_size, self.hidden_size)).astype('float32')
sequence_length_np = np.random.randint(
self.seq_len // 2, self.seq_len,
size=(self.batch_size)).astype('int64')
out = exe.run(
feed={'x': step_input_np,
'sequence_length': sequence_length_np},
fetch_list=[rnn_out, last_hidden, last_cell])
api_rnn_out = out[0]
api_last_hidden = out[1]
api_last_cell = out[2]
np_out = lstm_np(
step_input_np,
None,
None,
self.hidden_size,
gate_weight,
gate_bias,
num_layers=self.num_layers,
batch_first=self.batch_first,
is_bidirect=self.is_bidirect,
sequence_length=sequence_length_np)
self.assertTrue(np.allclose(api_rnn_out, np_out[0], rtol=1e-4, atol=0))
self.assertTrue(
np.allclose(
api_last_hidden, np_out[1], rtol=1e-4, atol=0))
self.assertTrue(
np.allclose(
api_last_cell, np_out[2], rtol=1e-4, atol=0))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid.contrib.layers import BasicLSTMUnit
from paddle.fluid.executor import Executor
from paddle.fluid import framework
import numpy as np
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0):
concat_1 = np.concatenate([step_in, pre_hidden], 1)
gate_input = np.matmul(concat_1, gate_w)
gate_input += gate_b
i, j, f, o = np.split(gate_input, indices_or_sections=4, axis=1)
new_cell = pre_cell * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
new_hidden = tanh(new_cell) * sigmoid(o)
return new_hidden, new_cell
class TestBasicGRUUnit(unittest.TestCase):
def setUp(self):
self.hidden_size = 5
self.batch_size = 5
def test_run(self):
x = layers.data(name='x', shape=[-1, self.hidden_size], dtype='float32')
pre_hidden = layers.data(
name="pre_hidden", shape=[-1, self.hidden_size], dtype='float32')
pre_cell = layers.data(
name="pre_cell", shape=[-1, self.hidden_size], dtype='float32')
lstm_unit = BasicLSTMUnit("lstm_unit", self.hidden_size)
new_hidden, new_cell = lstm_unit(x, pre_hidden, pre_cell)
new_hidden.persisbale = True
new_cell.persisbale = True
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
param_list = fluid.default_main_program().block(0).all_parameters()
# process weight and bias
gate_w_name = "lstm_unit/BasicLSTMUnit_0.w_0"
gate_b_name = "lstm_unit/BasicLSTMUnit_0.b_0"
gate_w = np.array(fluid.global_scope().find_var(gate_w_name).get_tensor(
))
gate_w = np.random.uniform(
-0.1, 0.1, size=gate_w.shape).astype('float32')
fluid.global_scope().find_var(gate_w_name).get_tensor().set(gate_w,
place)
gate_b = np.array(fluid.global_scope().find_var(gate_b_name).get_tensor(
))
gate_b = np.random.uniform(
-0.1, 0.1, size=gate_b.shape).astype('float32')
fluid.global_scope().find_var(gate_b_name).get_tensor().set(gate_b,
place)
step_input_np = np.random.uniform(-0.1, 0.1, (
self.batch_size, self.hidden_size)).astype('float32')
pre_hidden_np = np.random.uniform(-0.1, 0.1, (
self.batch_size, self.hidden_size)).astype('float32')
pre_cell_np = np.random.uniform(-0.1, 0.1, (
self.batch_size, self.hidden_size)).astype('float32')
out = exe.run( feed={ 'x' : step_input_np, 'pre_hidden' : pre_hidden_np, \
'pre_cell' : pre_cell_np },
fetch_list=[ new_hidden, new_cell])
api_hidden_out = out[0]
api_cell_out = out[1]
np_hidden_out, np_cell_out = step(step_input_np, pre_hidden_np,
pre_cell_np, gate_w, gate_b)
self.assertTrue(
np.allclose(
api_hidden_out, np_hidden_out, rtol=1e-4, atol=0))
self.assertTrue(
np.allclose(
api_cell_out, np_cell_out, rtol=1e-4, atol=0))
if __name__ == '__main__':
unittest.main()
......@@ -90,5 +90,67 @@ class SequenceMaskTest6(SequenceMaskTestBase):
self.maxlen = -1
class SequenceMaskTestBase_tensor_attr(OpTest):
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.maxlen = 10
self.maxlen_tensor = np.ones((1), 'int32') * 10
self.mask_dtype = 'int64'
self.x = [[0, 3, 4], [5, 7, 9]]
def initParameters(self):
pass
def setUp(self):
self.initDefaultParameters()
self.initParameters()
if not isinstance(self.x, np.ndarray):
self.x = np.array(self.x)
self.inputs = {'X': self.x, 'MaxLenTensor': self.maxlen_tensor}
self.outputs = {'Y': self.calc_ground_truth_mask()}
self.attrs = {'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)}
def calc_ground_truth_mask(self):
maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen
shape = self.x.shape + (maxlen, )
index_broadcast = np.broadcast_to(
np.reshape(
range(maxlen), newshape=[1] * self.x.ndim + [-1]),
shape=shape)
x_broadcast = np.broadcast_to(
np.reshape(
self.x, newshape=self.x.shape + (-1, )), shape=shape)
return (index_broadcast < x_broadcast).astype(self.mask_dtype)
def test_check_output(self):
self.check_output()
class SequenceMaskTest1_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'bool'
class SequenceMaskTest2_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'uint8'
class SequenceMaskTest3_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'int32'
class SequenceMaskTest4_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'float32'
class SequenceMaskTest5_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'float64'
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册