未验证 提交 516d84b2 编写于 作者: L Li Fuchen 提交者: GitHub

fix tests warpctc (#27639)

上级 c9a88013
...@@ -394,8 +394,7 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -394,8 +394,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4) py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4)
# disable test_warpctc_op py_test_modules(test_warpctc_op MODULES test_warpctc_op)
# py_test_modules(test_warpctc_op MODULES test_warpctc_op)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS}) py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS})
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS}) py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS})
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
......
...@@ -24,6 +24,8 @@ from paddle.fluid import Program, program_guard ...@@ -24,6 +24,8 @@ from paddle.fluid import Program, program_guard
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
paddle.enable_static()
CUDA_BLOCK_SIZE = 32 CUDA_BLOCK_SIZE = 32
...@@ -490,8 +492,8 @@ class TestWarpCTCOpError(unittest.TestCase): ...@@ -490,8 +492,8 @@ class TestWarpCTCOpError(unittest.TestCase):
logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32") logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32")
# labels should not be blank # labels should not be blank
labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32") labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
softmax = paddle.to_variable(logits) softmax = paddle.to_tensor(logits)
labels = paddle.to_variable(labels) labels = paddle.to_tensor(labels)
fluid.layers.warpctc(input=softmax, label=labels) fluid.layers.warpctc(input=softmax, label=labels)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册