diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py index 732f0681c4e65006628d51e083a400c0b5bd3d92..89ae3c6a39d6277f590c8f2e02f7b0ae62a1cd4a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py @@ -302,8 +302,11 @@ use_py_reader = False # if we run sync mode sync = False -# how many batches we use -batch_num = 50 +if not core.is_compiled_with_cuda(): + # how many batches we use + batch_num = 50 +else: + batch_num = 5 np.random.seed = 1 src_word_np = np.random.randint(