未验证 提交 c1f2c52c 编写于 作者: Y yuehuayingxueluo 提交者: GitHub

fix bugs about ParallelEnv (#50405)

上级 291c55a2
...@@ -21,7 +21,6 @@ from get_gpt_model import FakeDataset, generate_model ...@@ -21,7 +21,6 @@ from get_gpt_model import FakeDataset, generate_model
import paddle import paddle
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
sys.path.append("..") sys.path.append("..")
from test_sparse_addmm_op import get_cuda_version from test_sparse_addmm_op import get_cuda_version
...@@ -55,7 +54,7 @@ class TestFusedPassBaseList(unittest.TestCase): ...@@ -55,7 +54,7 @@ class TestFusedPassBaseList(unittest.TestCase):
paddle.seed(2021) paddle.seed(2021)
np.random.seed(2021) np.random.seed(2021)
random.seed(2021) random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place) engine._executor = paddle.static.Executor(place)
def get_engine(self, use_fused_passes=False, fused_passes_list=[]): def get_engine(self, use_fused_passes=False, fused_passes_list=[]):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册