未验证 提交 25a77450 编写于 作者: K kangguangli 提交者: GitHub

remove with_data_parallel in OpTest (#51237)

上级 2bd63c0c
...@@ -943,11 +943,7 @@ class OpTest(unittest.TestCase): ...@@ -943,11 +943,7 @@ class OpTest(unittest.TestCase):
use_cuda = False use_cuda = False
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
use_cuda = True use_cuda = True
compiled_prog = fluid.CompiledProgram( compiled_prog = fluid.CompiledProgram(program)
program
).with_data_parallel(
loss_name=loss.name if loss else None, places=place
)
program = compiled_prog program = compiled_prog
fetch_list = getattr(self, "fetch_list", []) fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly. # if the fetch_list is customized by user, we use it directly.
...@@ -971,9 +967,7 @@ class OpTest(unittest.TestCase): ...@@ -971,9 +967,7 @@ class OpTest(unittest.TestCase):
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
compiled_prog = fluid.CompiledProgram( compiled_prog = fluid.CompiledProgram(
program program, build_strategy=build_strategy
).with_data_parallel(
build_strategy=build_strategy, places=place
) )
program = compiled_prog program = compiled_prog
...@@ -1273,9 +1267,7 @@ class OpTest(unittest.TestCase): ...@@ -1273,9 +1267,7 @@ class OpTest(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
compiled_program = fluid.CompiledProgram( compiled_program = fluid.CompiledProgram(
grad_program grad_program, build_strategy=build_strategy
).with_data_parallel(
loss_name="", build_strategy=build_strategy, places=place
) )
program = compiled_program program = compiled_program
...@@ -2426,9 +2418,7 @@ class OpTest(unittest.TestCase): ...@@ -2426,9 +2418,7 @@ class OpTest(unittest.TestCase):
use_cuda = False use_cuda = False
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
use_cuda = True use_cuda = True
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel( compiled_prog = fluid.CompiledProgram(prog)
loss_name=loss.name, places=place
)
prog = compiled_prog prog = compiled_prog
executor = fluid.Executor(place) executor = fluid.Executor(place)
res = list( res = list(
......
...@@ -1041,11 +1041,7 @@ class OpTest(unittest.TestCase): ...@@ -1041,11 +1041,7 @@ class OpTest(unittest.TestCase):
use_cuda = False use_cuda = False
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
use_cuda = True use_cuda = True
compiled_prog = fluid.CompiledProgram( compiled_prog = fluid.CompiledProgram(program)
program
).with_data_parallel(
loss_name=loss.name if loss else None, places=place
)
program = compiled_prog program = compiled_prog
fetch_list = getattr(self, "fetch_list", []) fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly. # if the fetch_list is customized by user, we use it directly.
...@@ -1069,9 +1065,7 @@ class OpTest(unittest.TestCase): ...@@ -1069,9 +1065,7 @@ class OpTest(unittest.TestCase):
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
compiled_prog = fluid.CompiledProgram( compiled_prog = fluid.CompiledProgram(
program program, build_strategy=build_strategy
).with_data_parallel(
build_strategy=build_strategy, places=place
) )
program = compiled_prog program = compiled_prog
...@@ -1371,9 +1365,7 @@ class OpTest(unittest.TestCase): ...@@ -1371,9 +1365,7 @@ class OpTest(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
compiled_program = fluid.CompiledProgram( compiled_program = fluid.CompiledProgram(
grad_program grad_program, build_strategy=build_strategy
).with_data_parallel(
loss_name="", build_strategy=build_strategy, places=place
) )
program = compiled_program program = compiled_program
...@@ -2736,9 +2728,7 @@ class OpTest(unittest.TestCase): ...@@ -2736,9 +2728,7 @@ class OpTest(unittest.TestCase):
use_cuda = False use_cuda = False
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
use_cuda = True use_cuda = True
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel( compiled_prog = fluid.CompiledProgram(prog)
loss_name=loss.name, places=place
)
prog = compiled_prog prog = compiled_prog
executor = fluid.Executor(place) executor = fluid.Executor(place)
res = list( res = list(
......
...@@ -104,10 +104,9 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -104,10 +104,9 @@ class TestParallelExecutorBase(unittest.TestCase):
) )
if use_parallel_executor: if use_parallel_executor:
binary = compiler.CompiledProgram(main).with_data_parallel( binary = compiler.CompiledProgram(
loss_name=loss.name, main,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy,
) )
else: else:
binary = main binary = main
...@@ -204,10 +203,9 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -204,10 +203,9 @@ class TestParallelExecutorBase(unittest.TestCase):
use_device, use_device,
) )
binary = compiler.CompiledProgram(main).with_data_parallel( binary = compiler.CompiledProgram(
loss_name=loss.name, main,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy,
) )
exe.run(binary, feed=feed_dict, fetch_list=[loss.name]) exe.run(binary, feed=feed_dict, fetch_list=[loss.name])
......
...@@ -121,7 +121,6 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -121,7 +121,6 @@ class TestDygraphDataLoader(unittest.TestCase):
return ret return ret
def test_main(self): def test_main(self):
# dynamic graph do not run with_data_parallel
for p in prepare_places(): for p in prepare_places():
for persistent_workers in [False, True]: for persistent_workers in [False, True]:
results = [] results = []
......
...@@ -121,7 +121,6 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -121,7 +121,6 @@ class TestDygraphDataLoader(unittest.TestCase):
return ret return ret
def test_main(self): def test_main(self):
# dynamic graph do not run with_data_parallel
for p in prepare_places(): for p in prepare_places():
for persistent_workers in [False, True]: for persistent_workers in [False, True]:
results = [] results = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册