From 516d84b22a1ebffb1fc2b32d5be21053eb58fb78 Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Mon, 28 Sep 2020 11:37:06 +0800 Subject: [PATCH] fix tests warpctc (#27639) --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 3 +-- python/paddle/fluid/tests/unittests/test_warpctc_op.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 0fa79f02ab..23aaa90d68 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -394,8 +394,7 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4) -# disable test_warpctc_op -# py_test_modules(test_warpctc_op MODULES test_warpctc_op) +py_test_modules(test_warpctc_op MODULES test_warpctc_op) py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS}) py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS}) py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index b82ab04c98..6310a76d8d 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -24,6 +24,8 @@ from paddle.fluid import Program, program_guard import paddle import paddle.nn.functional as F +paddle.enable_static() + CUDA_BLOCK_SIZE = 32 @@ -490,8 +492,8 @@ class TestWarpCTCOpError(unittest.TestCase): logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32") # labels should not be blank labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32") - softmax = paddle.to_variable(logits) - labels = paddle.to_variable(labels) + softmax = paddle.to_tensor(logits) + labels = paddle.to_tensor(labels) fluid.layers.warpctc(input=softmax, label=labels) -- GitLab