diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index 2d2b70e3f73f2b8e38eb789a88e65d68e0b3b2b2..6681b423415251f5fb75e5bff61abb80fb63e417 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -22,6 +22,7 @@ from . import layers
 from ..framework import Variable, OpProtoHolder
 from ..param_attr import ParamAttr
 from ..initializer import Normal, Constant
+
 __all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit']


@@ -548,7 +549,7 @@ class GRUUnit(layers.Layer):
     """

     def __init__(self,
-                 hidden,
+                 name_scope,
                  size,
                  param_attr=None,
                  bias_attr=None,
@@ -556,8 +557,8 @@
                  gate_activation='sigmoid',
                  origin_mode=False,
                  dtype='float32'):
+        super(GRUUnit, self).__init__(name_scope)
-        super(GRUUnit, self).__init__()

         activation_dict = dict(
             identity=0,
             sigmoid=1,
@@ -566,29 +567,27 @@
         activation = activation_dict[activation]
         gate_activation = activation_dict[gate_activation]

-        helper = LayerHelper('gru_unit', **locals())
-        dtype = helper.input_dtype()
+        self._dtype = dtype
         size = size // 3

-        # create weight
-        weight = helper.create_parameter(
-            attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
+        self._weight = self.create_parameter(
+            attr=param_attr, shape=[size, 3 * size], dtype=dtype)

-        gate = helper.create_variable_for_type_inference(dtype)
-        reset_hidden_pre = helper.create_variable_for_type_inference(dtype)
-        updated_hidden = helper.create_variable_for_type_inference(dtype)
-        inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': weight}
         # create bias
-        if helper.bias_attr:
-            bias_size = [1, 3 * size]
-            bias = helper.create_parameter(
-                attr=helper.bias_attr,
-                shape=bias_size,
-                dtype=dtype,
-                is_bias=True)
-            inputs['Bias'] = bias
+        bias_size = [1, 3 * size]
+        self._bias = self.create_parameter(
+            attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True)

-    def forward(self, input):
+    def forward(self, input, hidden):
+        inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self._weight}
+        if self._bias:
+            inputs['Bias'] = self._bias
+
+        gate = self._helper.create_variable_for_type_inference(self._dtype)
+        reset_hidden_pre = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        updated_hidden = self._helper.create_variable_for_type_inference(
+            self._dtype)
         self._helper.append_op(
             type='gru_unit',
             inputs=inputs,
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 823445724302dbde47bc36122c62ef44a7e2394f..9fa62a692ee9d012df23bcc0d1d21c4c085fd9e0 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -22,6 +22,7 @@ import six
 import time
 import itertools
 import collections
+from collections import defaultdict

 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -257,8 +258,65 @@ class OpTest(unittest.TestCase):
         outs, _ = self._calc_output(place)
         return outs

-    def _calc_output(self, place, parallel=False, no_check_set=None):
+    def _create_var_from_numpy(self, value):
+        if isinstance(value, tuple):
+            data = value[0]
+            lod = value[1]
+            v = fluid.imperative.base.to_variable(value=data)
+            v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod)
+            return v
+        else:
+            return fluid.imperative.base.to_variable(value)
+
+    def _calc_imperative_output(self, place, parallel=False, no_check_set=None):
+        with fluid.imperative.base.guard(place=place):
+            block = fluid.default_main_program().global_block()
+
+            # prepare input variable
+            inputs = defaultdict(list)
+            for name, np_value in six.iteritems(self.inputs):
+                if not isinstance(np_value, list):
+                    np_value = [np_value]
+
+                for i in range(len(np_value)):
+                    inputs[name].append(
+                        self._create_var_from_numpy(np_value[i]))
+
+            # prepare output variable
+            outputs = defaultdict(list)
+            for name, np_value in six.iteritems(self.outputs):
+                if not isinstance(np_value, list):
+                    np_value = [np_value]
+
+                for i in range(len(np_value)):
+                    value = np_value[i]
+                    if isinstance(value, tuple):
+                        v = block.create_var(
+                            name="%s_out%d" % (name, i),
+                            dtype=value[0].dtype,
+                            type=core.VarDesc.VarType.LOD_TENSOR,
+                            persistable=False,
+                            stop_gradient=False)
+                        v._ivar.value().get_tensor(
+                        ).set_recursive_sequence_lengths(value[1])
+                    else:
+                        v = block.create_var(
+                            name="%s_out%d" % (name, i),
+                            dtype=value.dtype,
+                            type=core.VarDesc.VarType.LOD_TENSOR,
+                            persistable=False,
+                            stop_gradient=False)
+                    outputs[name].append(v)
+
+            block.append_op(
+                type=self.op_type,
+                inputs=inputs,
+                outputs=outputs,
+                attrs=self.attrs)
+
+            return outputs
+
+    def _calc_output(self, place, parallel=False, no_check_set=None):
         program = Program()
         block = program.global_block()
         self._append_ops(block)
@@ -305,8 +363,13 @@
                                 place,
                                 atol,
                                 no_check_set=None,
-                                equal_nan=False):
+                                equal_nan=False,
+                                check_imperative=False):
+        if check_imperative:
+            imperative_outs = self._calc_imperative_output(
+                place, no_check_set=no_check_set)
         outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
+
         for out_name, out_dup in Operator.get_op_outputs(self.op_type):
             if out_name not in self.outputs:
                 continue
@@ -330,6 +393,10 @@
                     type(sub_out))
                 for item in sub_out:
                     sub_out_name, expect = item[0], item[1]
+                    if check_imperative:
+                        imperative_actual = imperative_outs[sub_out_name][0]
+                        imperative_actual_t = np.array(
+                            imperative_actual._ivar.value().get_tensor())
                     idx = find_actual(sub_out_name, fetch_list)
                     actual = outs[idx]
                     actual_t = np.array(actual)
@@ -340,12 +407,24 @@
                             actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                         "Output (" + sub_out_name + ") has diff at " +
                         str(place))
+                    self.assertTrue(
+                        np.allclose(
+                            imperative_actual_t,
+                            expect_t,
+                            atol=atol,
+                            equal_nan=equal_nan),
+                        "Output (" + sub_out_name + ") has diff at " +
+                        str(place) + " in imperative mode")
                     if isinstance(expect, tuple):
                         self.assertListEqual(
                             actual.recursive_sequence_lengths(), expect[1],
                             "Output (" + sub_out_name +
                             ") has different lod at " + str(place))
             else:
+                if check_imperative:
+                    imperative_actual = imperative_outs[out_name][0]
+                    imperative_actual_t = np.array(
+                        imperative_actual._ivar.value().get_tensor())
                 idx = find_actual(out_name, fetch_list)
                 actual = outs[idx]
                 actual_t = np.array(actual)
@@ -357,10 +436,25 @@
                     "Output (" + out_name + ") has diff at " + str(place) +
                     "\nExpect " + str(expect_t) + "\n" + "But Got" +
                     str(actual_t) + " in class " + self.__class__.__name__)
+                self.assertTrue(
+                    np.allclose(
+                        imperative_actual_t,
+                        expect_t,
+                        atol=atol,
+                        equal_nan=equal_nan),
+                    "Output (" + out_name + ") has diff at " + str(place) +
+                    "\nExpect " + str(expect_t) + "\n" + "But Got" +
+                    str(imperative_actual_t) + " in class " +
+                    self.__class__.__name__)
                 if isinstance(expect, tuple):
                     self.assertListEqual(actual.recursive_sequence_lengths(),
                                          expect[1], "Output (" + out_name +
                                          ") has different lod at " + str(place))
+                    if check_imperative:
+                        self.assertListEqual(
+                            imperative_actual._ivar.value().get_tensor()
+                            .recursive_sequence_lengths(), expect[1], "Output (" +
+                            out_name + ") has different lod at " + str(place))

     def _get_places(self):
         if self.dtype == np.float16:
@@ -383,10 +477,15 @@
             places.append(core.CUDAPlace(0))
         return places

-    def check_output(self, atol=1e-5, no_check_set=None, equal_nan=False):
+    def check_output(self,
+                     atol=1e-5,
+                     no_check_set=None,
+                     equal_nan=False,
+                     check_imperative=False):
         places = self._get_places()
         for place in places:
-            self.check_output_with_place(place, atol, no_check_set, equal_nan)
+            self.check_output_with_place(place, atol, no_check_set, equal_nan,
+                                         check_imperative)

     def check_output_customized(self, checker):
         places = self._get_places()
diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py
index 6606162733487b15ef55f1a4677fb382e6e7e0ac..848c9a4952aebcf93fd7bf12f7bc4cd15c7a8b28 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -156,7 +156,7 @@
         }

     def test_check_output(self):
-        self.check_output(atol=1e-8)
+        self.check_output(atol=1e-8, check_imperative=True)

     def test_check_grad(self):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index b29ad2587016a04ecd2d875538d9d5c437793dfd..5b186ae0384e3d365303c25861138a3c7e4c189f 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -112,6 +112,47 @@
         self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
         self.assertTrue(np.allclose(static_ret, static_ret2))

+    def test_gru_unit(self):
+        lod = [[2, 4, 3]]
+        D = 5
+        T = sum(lod[0])
+        N = len(lod[0])
+
+        input = np.random.rand(T, 3 * D).astype('float32')
+        hidden_input = np.random.rand(T, D).astype('float32')
+
+        with self.static_graph():
+            x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
+            hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
+            updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
+                input=x, hidden=hidden, size=D * 3)
+            static_ret = self.get_static_graph_result(
+                feed={'x': input,
+                      'hidden': hidden_input},
+                fetch_list=[updated_hidden, reset_hidden_pre, gate])
+
+        with self.static_graph():
+            x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
+            hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
+            updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
+                input=x, hidden=hidden, size=D * 3)
+            gru = nn.GRUUnit('gru', size=D * 3)
+            updated_hidden, reset_hidden_pre, gate = gru(x, hidden)
+
+            static_ret2 = self.get_static_graph_result(
+                feed={'x': input,
+                      'hidden': hidden_input},
+                fetch_list=[updated_hidden, reset_hidden_pre, gate])
+
+        with self.dynamic_graph():
+            gru = nn.GRUUnit('gru', size=D * 3)
+            dy_ret = gru(
+                base.to_variable(input), base.to_variable(hidden_input))
+
+        for i in range(len(static_ret)):
+            self.assertTrue(np.allclose(static_ret[i], static_ret2[i]))
+            self.assertTrue(np.allclose(static_ret[i], dy_ret[i]._numpy()))
+

 class TestBook(unittest.TestCase):
     def test_fit_a_line(self):
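Note: below is a minimal standalone sketch of how the refactored GRUUnit is meant to be driven in dynamic-graph mode, mirroring the test_gru_unit case added above. It is not part of the patch; it assumes the imperative API exercised in this diff (fluid.imperative.base.guard, base.to_variable, and a forward that returns updated_hidden, reset_hidden_pre, gate), with shapes taken from the test: input is [T, 3 * D], hidden is [T, D].

# Usage sketch only; names input_np / hidden_np and the default guard() place
# are assumptions, not part of the patch.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative import base, nn

D = 5  # hidden size
T = 9  # total rows in the batch (sum of the LoD lengths in the test)

input_np = np.random.rand(T, 3 * D).astype('float32')
hidden_np = np.random.rand(T, D).astype('float32')

with fluid.imperative.base.guard():
    # parameters are created in __init__, the gru_unit op is appended in forward()
    gru = nn.GRUUnit('gru', size=D * 3)
    updated_hidden, reset_hidden_pre, gate = gru(
        base.to_variable(input_np), base.to_variable(hidden_np))
    print(updated_hidden._numpy().shape)  # expected: (T, D)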