提交 bdd957b4 编写于 作者: J JiabinYang

fix test_parallel_executor_transformer

上级 08cfe27c
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from parallel_executor_test_base import TestParallelExecutorBase from parallel_executor_test_base import TestParallelExecutorBase
import unittest import unittest
import paddle import paddle
import paddle.fluid.core as core
import paddle.dataset.wmt16 as wmt16 import paddle.dataset.wmt16 as wmt16
import os import os
...@@ -170,6 +171,7 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -170,6 +171,7 @@ class TestTransformer(TestParallelExecutorBase):
writer.complete_append_tensor() writer.complete_append_tensor()
def test_main(self): def test_main(self):
if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True) self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5) self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册