diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index ecc58768522831b55f620cb6dc911630e2c2ad68..39d25a9c7ffd47212153a345a98781ed034c7020 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -127,7 +127,8 @@ class RNNCell(object): else: integer_types = (int, ) check_variable_and_dtype(batch_ref, 'batch_ref', - ['float32', 'float64'], 'RNNCell') + ['float32', 'float64', 'int32', 'int64'], + 'RNNCell') check_type(shape, 'shape', (list, tuple, type(None), integer_types), 'RNNCell') if isinstance(shape, (list, tuple)):