未验证 提交 2bd63c0c 编写于 作者: K kangguangli 提交者: GitHub

remove with_data_parallel (#51240)

上级 151ec311
...@@ -82,9 +82,7 @@ class AutoCheckpointBase(unittest.TestCase): ...@@ -82,9 +82,7 @@ class AutoCheckpointBase(unittest.TestCase):
sgd, loss, image, label = simple_net() sgd, loss, image, label = simple_net()
if minimize: if minimize:
compiled = fluid.CompiledProgram(main_prog).with_data_parallel( compiled = fluid.CompiledProgram(main_prog)
loss_name=loss.name
)
else: else:
compiled = None compiled = None
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
......
...@@ -80,9 +80,8 @@ class BuildIrMemOptBase(unittest.TestCase): ...@@ -80,9 +80,8 @@ class BuildIrMemOptBase(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
train_cp = compiler.CompiledProgram(fluid.default_main_program()) train_cp = compiler.CompiledProgram(
train_cp = train_cp.with_data_parallel( fluid.default_main_program(), build_strategy=build_strategy
loss_name=cost.name, build_strategy=build_strategy
) )
fetch_list = [cost.name] fetch_list = [cost.name]
......
...@@ -52,9 +52,7 @@ class TestCUDAGraphInFirstBatch(unittest.TestCase): ...@@ -52,9 +52,7 @@ class TestCUDAGraphInFirstBatch(unittest.TestCase):
build_strategy = paddle.static.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
build_strategy.allow_cuda_graph_capture = True build_strategy.allow_cuda_graph_capture = True
compiled_program = paddle.static.CompiledProgram( compiled_program = paddle.static.CompiledProgram(
main main, build_strategy=build_strategy
).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy, places=place
) )
cuda_graph = None cuda_graph = None
......
...@@ -1678,11 +1678,8 @@ class TracedLayer: ...@@ -1678,11 +1678,8 @@ class TracedLayer:
@switch_to_static_graph @switch_to_static_graph
def _compile(self): def _compile(self):
self._compiled_program = CompiledProgram( self._compiled_program = CompiledProgram(
self._program self._program,
).with_data_parallel(
build_strategy=self._build_strategy, build_strategy=self._build_strategy,
exec_strategy=self._exec_strategy,
places=self._place,
) )
def _build_feed(self, inputs): def _build_feed(self, inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册