提交 bdd957b4 编写于 作者: J JiabinYang

fix test_parallel_executor_transformer

上级 08cfe27c
......@@ -20,6 +20,7 @@ import numpy as np
from parallel_executor_test_base import TestParallelExecutorBase
import unittest
import paddle
import paddle.fluid.core as core
import paddle.dataset.wmt16 as wmt16
import os
......@@ -170,6 +171,7 @@ class TestTransformer(TestParallelExecutorBase):
writer.complete_append_tensor()
def test_main(self):
if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册