未验证 提交 d4e8c99f 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #16095 from velconia/transfer_gru_unit

Imperative transfer gru unit
......@@ -22,7 +22,8 @@ from . import layers
from ..framework import Variable, OpProtoHolder
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding']
__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit']
class Conv2D(layers.Layer):
......@@ -468,3 +469,137 @@ class Embedding(layers.Layer):
})
return out
class GRUUnit(layers.Layer):
"""
**GRU unit layer**
if origin_mode is True, then the equation of a gru step is from paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t)
The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts -
:math:`xu_t`, :math:`xr_t` and :math:`xm_t`. This means that in order to
implement a full GRU unit operator for an input, a fully
connected layer has to be applied, such that :math:`z_t = W_{fc}x_t`.
The terms :math:`u_t` and :math:`r_t` represent the update and reset gates
of the GRU cell. Unlike LSTM, GRU has one lesser gate. However, there is
an intermediate candidate hidden output, which is denoted by :math:`m_t`.
This layer has three outputs :math:`h_t`, :math:`dot(r_t, h_{t-1})`
and concatenation of :math:`u_t`, :math:`r_t` and :math:`m_t`.
Args:
input (Variable): The fc transformed input value of current step.
name_scope (str): See base class.
hidden (Variable): The hidden value of gru unit from previous step.
size (integer): The input dimension value.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:
- The shape of the weight matrix is :math:`(T \\times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part are weights of the update gate and reset gate with
shape :math:`(D \\times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D \\times D)`.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU.Note that the bias with :math:`(1 \\times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, gru_unit will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
activation (string): The activation type for cell (actNode).
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
Default: 'sigmoid'
Returns:
tuple: The hidden value, reset-hidden value and gate values.
"""
def __init__(self,
name_scope,
size,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid',
origin_mode=False,
dtype='float32'):
super(GRUUnit, self).__init__(name_scope)
activation_dict = dict(
identity=0,
sigmoid=1,
tanh=2,
relu=3, )
activation = activation_dict[activation]
gate_activation = activation_dict[gate_activation]
self._dtype = dtype
size = size // 3
# create weight
self._weight = self.create_parameter(
attr=param_attr, shape=[size, 3 * size], dtype=dtype)
# create bias
bias_size = [1, 3 * size]
self._bias = self.create_parameter(
attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
def forward(self, input, hidden):
inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self._weight}
if self._bias:
inputs['Bias'] = self._bias
gate = self._helper.create_variable_for_type_inference(self._dtype)
reset_hidden_pre = self._helper.create_variable_for_type_inference(
self._dtype)
updated_hidden = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type='gru_unit',
inputs=inputs,
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': 2, # tanh
'gate_activation': 1, # sigmoid
})
return updated_hidden, reset_hidden_pre, gate
......@@ -22,6 +22,7 @@ import six
import time
import itertools
import collections
from collections import defaultdict
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -257,8 +258,65 @@ class OpTest(unittest.TestCase):
outs, _ = self._calc_output(place)
return outs
def _calc_output(self, place, parallel=False, no_check_set=None):
def _create_var_from_numpy(self, value):
if isinstance(value, tuple):
data = value[0]
lod = value[1]
v = fluid.imperative.base.to_variable(value=data)
v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod)
return v
else:
return fluid.imperative.base.to_variable(value)
def _calc_imperative_output(self, place, parallel=False, no_check_set=None):
with fluid.imperative.base.guard(place=place):
block = fluid.default_main_program().global_block()
# prepare input variable
inputs = defaultdict(list)
for name, np_value in six.iteritems(self.inputs):
if not isinstance(np_value, list):
np_value = [np_value]
for i in range(len(np_value)):
inputs[name].append(
self._create_var_from_numpy(np_value[i]))
# prepare output variable
outputs = defaultdict(list)
for name, np_value in six.iteritems(self.outputs):
if not isinstance(np_value, list):
np_value = [np_value]
for i in range(len(np_value)):
value = np_value[i]
if isinstance(value, tuple):
v = block.create_var(
name="%s_out%d" % (name, i),
dtype=value[0].dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
v._ivar.value().get_tensor(
).set_recursive_sequence_lengths(value[1])
else:
v = block.create_var(
name="%s_out%d" % (name, i),
dtype=value.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
outputs[name].append(v)
block.append_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=self.attrs)
return outputs
def _calc_output(self, place, parallel=False, no_check_set=None):
program = Program()
block = program.global_block()
self._append_ops(block)
......@@ -305,8 +363,13 @@ class OpTest(unittest.TestCase):
place,
atol,
no_check_set=None,
equal_nan=False):
equal_nan=False,
check_imperative=False):
if check_imperative:
imperative_outs = self._calc_imperative_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
......@@ -330,6 +393,10 @@ class OpTest(unittest.TestCase):
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
if check_imperative:
imperative_actual = imperative_outs[sub_out_name][0]
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
......@@ -340,12 +407,31 @@ class OpTest(unittest.TestCase):
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in imperative mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")
else:
if check_imperative:
imperative_actual = imperative_outs[out_name][0]
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
......@@ -357,10 +443,27 @@ class OpTest(unittest.TestCase):
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__)
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(imperative_actual_t) + " in class " +
self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")
def _get_places(self):
if self.dtype == np.float16:
......@@ -383,10 +486,15 @@ class OpTest(unittest.TestCase):
places.append(core.CUDAPlace(0))
return places
def check_output(self, atol=1e-5, no_check_set=None, equal_nan=False):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_imperative=False):
places = self._get_places()
for place in places:
self.check_output_with_place(place, atol, no_check_set, equal_nan)
self.check_output_with_place(place, atol, no_check_set, equal_nan,
check_imperative)
def check_output_customized(self, checker):
places = self._get_places()
......
......@@ -156,7 +156,7 @@ class TestGRUOp(OpTest):
}
def test_check_output(self):
self.check_output(atol=1e-8)
self.check_output(atol=1e-8, check_imperative=True)
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
......
......@@ -112,6 +112,47 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
self.assertTrue(np.allclose(static_ret, static_ret2))
def test_gru_unit(self):
lod = [[2, 4, 3]]
D = 5
T = sum(lod[0])
N = len(lod[0])
input = np.random.rand(T, 3 * D).astype('float32')
hidden_input = np.random.rand(T, D).astype('float32')
with self.static_graph():
x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
input=x, hidden=hidden, size=D * 3)
static_ret = self.get_static_graph_result(
feed={'x': input,
'hidden': hidden_input},
fetch_list=[updated_hidden, reset_hidden_pre, gate])
with self.static_graph():
x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
input=x, hidden=hidden, size=D * 3)
gru = nn.GRUUnit('gru', size=D * 3)
updated_hidden, reset_hidden_pre, gate = gru(x, hidden)
static_ret2 = self.get_static_graph_result(
feed={'x': input,
'hidden': hidden_input},
fetch_list=[updated_hidden, reset_hidden_pre, gate])
with self.dynamic_graph():
gru = nn.GRUUnit('gru', size=D * 3)
dy_ret = gru(
base.to_variable(input), base.to_variable(hidden_input))
for i in range(len(static_ret)):
self.assertTrue(np.allclose(static_ret[i], static_ret2[i]))
self.assertTrue(np.allclose(static_ret[i], dy_ret[i]._numpy()))
class TestBook(unittest.TestCase):
def test_fit_a_line(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册