未验证 提交 9107dc67 编写于 作者: 0 0x45f 提交者: GitHub

Switch test_transformer to eager mode and fix roll error (#41548)

上级 795d7121
...@@ -6,7 +6,7 @@ set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS} FLAGS_enable_eager_mode=1) ...@@ -6,7 +6,7 @@ set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS} FLAGS_enable_eager_mode=1)
set(TEST_EAGER_OPS test_bmn test_break_continue test_ifelse test_loop test_mnist_amp set(TEST_EAGER_OPS test_bmn test_break_continue test_ifelse test_loop test_mnist_amp
test_mnist_pure_fp16 test_mobile_net test_program_translator test_ptb_lm test_reinforcement_learning test_mnist_pure_fp16 test_mobile_net test_program_translator test_ptb_lm test_reinforcement_learning
test_resnet test_resnet_amp test_resnet_pure_fp16 test_se_resnet test_sentiment test_seq2seq test_resnet test_resnet_amp test_resnet_pure_fp16 test_se_resnet test_sentiment test_seq2seq
test_tsm test_word2vec test_yolov3 test_bert test_cycle_gan test_lstm test_simnet) test_tsm test_word2vec test_yolov3 test_bert test_cycle_gan test_lstm test_simnet test_transformer)
list(REMOVE_ITEM TEST_OPS test_lac) list(REMOVE_ITEM TEST_OPS test_lac)
# NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will # NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will
# be removed and will cause some random failed in multi-thread. # be removed and will cause some random failed in multi-thread.
......
...@@ -784,6 +784,8 @@ def roll(x, shifts, axis=None, name=None): ...@@ -784,6 +784,8 @@ def roll(x, shifts, axis=None, name=None):
axis = [] axis = []
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(shifts, paddle.Tensor):
shifts = shifts.cpu()
return _C_ops.final_state_roll(x, shifts, axis) return _C_ops.final_state_roll(x, shifts, axis)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册