diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index d4277daebdc9e9c5dbf2f1c8a3fdbe26653a3659..bbe9d9d31974c5cf9f92eca5b21f4750c84773c7 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -1911,6 +1911,10 @@ class GRUUnit(layers.Layer):
                 self.activation, 'gate_activation', self.gate_activation)
             return updated_hidden, reset_hidden_pre, gate
 
+        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                                 'GRUUnit')
+        check_variable_and_dtype(hidden, 'hidden', ['float32', 'float64'],
+                                 'GRUUnit')
         inputs = {
             'Input': [input],
             'HiddenPrev': [hidden],
@@ -1918,10 +1922,6 @@
         }
         if self.bias is not None:
             inputs['Bias'] = [self.bias]
-        attrs = {
-            'activation': self.activation,
-            'gate_activation': self.gate_activation,
-        }
         gate = self._helper.create_variable_for_type_inference(self._dtype)
         reset_hidden_pre = self._helper.create_variable_for_type_inference(
             self._dtype)
diff --git a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
index 6164591992b7acf73faee2132a5e10e37514c9d3..4143619d981f1ff167c4a3b2825ab497ce6dec86 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
@@ -17,9 +17,25 @@ from __future__ import print_function
 import math
 import unittest
 import numpy as np
+import paddle.fluid as fluid
 from op_test import OpTest
 
 
+class TestGRUUnitAPIError(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            D = 5
+            layer = fluid.dygraph.nn.GRUUnit(size=D * 3)
+            # the input must be Variable.
+            x0 = fluid.create_lod_tensor(
+                np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, layer, x0)
+            # the input dtype must be float32 or float64
+            x = fluid.data(name='x', shape=[-1, D * 3], dtype='float16')
+            hidden = fluid.data(name='hidden', shape=[-1, D], dtype='float32')
+            self.assertRaises(TypeError, layer, x, hidden)
+
+
 class GRUActivationType(OpTest):
     identity = 0
     sigmoid = 1
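
For reference, a minimal sketch of what the new `check_variable_and_dtype` guards look like from the user side. It reuses the same `fluid.data` and `fluid.dygraph.nn.GRUUnit` calls as the test above; the printed message is illustrative, not the exact error text.

    import numpy as np
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        D = 5
        layer = fluid.dygraph.nn.GRUUnit(size=D * 3)
        # float16 is not in ['float32', 'float64'], so the check added to
        # GRUUnit.forward raises TypeError at graph-construction time.
        x = fluid.data(name='x', shape=[-1, D * 3], dtype='float16')
        hidden = fluid.data(name='hidden', shape=[-1, D], dtype='float32')
        try:
            layer(x, hidden)
        except TypeError as e:
            print('GRUUnit rejected float16 input:', e)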