diff --git a/python/paddle/fluid/tests/unittests/test_basic_lstm_api.py b/python/paddle/fluid/tests/unittests/test_basic_lstm_api.py index 5383632838d318e935bb079380e4ddf99d5db1c9..bedba672edf95895a0571c4316a026a09c89b418 100644 --- a/python/paddle/fluid/tests/unittests/test_basic_lstm_api.py +++ b/python/paddle/fluid/tests/unittests/test_basic_lstm_api.py @@ -68,13 +68,14 @@ def lstm_np(input, return new_hidden, new_cell + mask = None + if batch_first: input = np.tranpose(input, [1, 0, 2]) if mask is not None: mask = np.transpose(mask, [1, 0]) batch_size = input.shape[1] - mask = None if sequence_length is not None: max_seq_len = input.shape[0]