Unverified · Commit cac6f408 · Authored by caozhou · Committed by GitHub

【Auto Parallel】update dist param grad for pass (#38941)

* update dist param grad for pass

* update unit test

* update unit tests

* fix conflict
Parent: f080e8d5
@@ -209,7 +209,8 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
 
-        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
+                dist_params_grads)
 
         self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
......
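Before the reshard.py hunks, a quick orientation: the essential change is that the (param, grad) list produced by partitioning is now threaded through reshard ahead of the post-optimization passes. The stub sketch below only illustrates that ordering; the function bodies and the parallelize wrapper are placeholders, not the Paddle implementation:

```python
# Stub pipeline sketch (hypothetical; not the Paddle implementation).
def reshard(main_prog, startup_prog, rank, dist_context, dist_params_grads):
    """The real reshard may prune or re-point (param, grad) pairs in place."""
    pass


def apply_post_optimization_passes(main_prog, startup_prog, rank,
                                   dist_params_grads):
    """Later passes consume whatever reshard left in dist_params_grads."""
    pass


def parallelize(main_prog, startup_prog, rank, dist_context, dist_params_grads):
    # Order established by this commit: reshard first, passes second,
    # both sharing the same dist_params_grads list object.
    reshard(main_prog, startup_prog, rank, dist_context, dist_params_grads)
    apply_post_optimization_passes(main_prog, startup_prog, rank,
                                   dist_params_grads)


parallelize(None, None, 0, None, [])  # runs: every argument is a stand-in
```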
@@ -20,6 +20,7 @@ import paddle.fluid.core as core
 from paddle.utils import unique_name
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.framework import Program, OpProtoHolder
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 import paddle.fluid.layers.utils as utils
 from ..collective import _get_global_env
 from .dist_context import DistributedContext
@@ -862,7 +863,7 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
         block._remove_op(idx)
 
 
-def _remove_no_need_vars(auto_parallel_main_prog):
+def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
     """Remove no need vars in the main program"""
     remove_vars = set()
     block = auto_parallel_main_prog.global_block()
@@ -879,14 +880,42 @@ def _remove_no_need_vars(auto_parallel_main_prog):
     for var in vars:
         if var not in need_vars:
             remove_vars.add(var)
+
+    # change dist_params_grads
+    param_grad_map = {}
+    for op in ops:
+        if int(op.attr('op_role')) == int(OpRole.Optimize):
+            if "Param" in op.input_names and "Grad" in op.input_names:
+                param_name = op.input("Param")[0]
+                grad_name = op.input("Grad")[0]
+                param_grad_map[param_name] = grad_name
+
+    need_remove_idx = []
+    for idx, item in enumerate(dist_params_grads):
+        if item[0].name not in param_grad_map.keys():
+            need_remove_idx.append(idx)
+
+    for idx in need_remove_idx[::-1]:
+        dist_params_grads.pop(idx)
+
+    idx = 0
+    while idx < len(dist_params_grads):
+        param_name = dist_params_grads[idx][0].name
+        grad_name = dist_params_grads[idx][1].name
+        if grad_name != param_grad_map[param_name]:
+            dist_params_grads[idx] = (vars[param_name],
+                                      vars[param_grad_map[param_name]])
+        idx += 1
+
     for var in remove_vars:
         block._remove_var(var)
 
 
-def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id):
+def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+                           dist_params_grads):
     """Remove no need vars and ops in the main program."""
     _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
-    _remove_no_need_vars(auto_parallel_main_prog)
+    _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
 
 
 def remove_no_need_in_startup(auto_parallel_main_prog,
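To make the added block above concrete, here is a small self-contained sketch of the same prune-and-remap idea on toy data. It replaces Paddle variables with a trivial ToyVar class and hand-writes the param_grad_map that the real code derives from optimizer ops, so every name in it (ToyVar, sync_params_grads, the @GRAD@reshard suffix) is illustrative only:

```python
# A toy stand-in for a Paddle variable: only the .name attribute matters here.
class ToyVar:
    def __init__(self, name):
        self.name = name


def sync_params_grads(dist_params_grads, param_grad_map, vars_by_name):
    """Mirror the added block: drop pairs whose param has no optimizer op
    left, then re-point pairs whose gradient variable was renamed."""
    # 1) Remove (param, grad) pairs that no optimizer op consumes any more.
    for idx in range(len(dist_params_grads) - 1, -1, -1):
        if dist_params_grads[idx][0].name not in param_grad_map:
            dist_params_grads.pop(idx)
    # 2) Rewrite pairs whose gradient was replaced during resharding.
    for idx, (param, grad) in enumerate(dist_params_grads):
        new_grad_name = param_grad_map[param.name]
        if grad.name != new_grad_name:
            dist_params_grads[idx] = (vars_by_name[param.name],
                                      vars_by_name[new_grad_name])


# Toy scenario: w1 lost its optimizer op; w0's gradient was renamed.
vars_by_name = {n: ToyVar(n) for n in ("w0", "w0@GRAD@reshard")}
pairs = [(ToyVar("w0"), ToyVar("w0@GRAD")), (ToyVar("w1"), ToyVar("w1@GRAD"))]
sync_params_grads(pairs, {"w0": "w0@GRAD@reshard"}, vars_by_name)
assert [(p.name, g.name) for p, g in pairs] == [("w0", "w0@GRAD@reshard")]
```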
@@ -964,7 +993,7 @@ remove_no_need_in_startup(auto_parallel_main_prog,
 def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
-            dist_context):
+            dist_context, dist_params_grads):
     """
     Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
@@ -973,6 +1002,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
         auto_parallel_startup_prog (Program): An auto parallel startup program.
         rank_id (int): The process id.
         dist_context (DistributedContext): The distributed context of this rank.
+        dist_params_grads (list): The list contains the tuple of param and grad.
     """
     assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
         "but got {}.".format(type(auto_parallel_main_prog))
@@ -1049,7 +1079,8 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
         idx += 1
 
     # remove no need vars and ops in the main program
-    remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
+    remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+                           dist_params_grads)
 
     # remove no need vars and ops in the startip program
     remove_no_need_in_startup(auto_parallel_main_prog,
......
@@ -175,7 +175,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_runtime_estimation(cost):
@@ -229,10 +229,10 @@ class TestCostModel(unittest.TestCase):
             train_program = paddle.static.Program()
             startup_program = paddle.static.Program()
             dist_context = DistributedContext()
-            distributed_program, dist_startup_prog = get_dist_prog(
+            distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
                 train_program, startup_program, dist_context, rank_id)
             reshard(distributed_program, dist_startup_prog, rank_id,
-                    dist_context)
+                    dist_context, dist_params_grads)
             dist_program.append(distributed_program)
         cluster = None
         cost = estimate_cost(
......
@@ -502,7 +502,8 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         dist_train_program, dist_startup_prog, dist_params_grads)
 
-    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
+    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context,
+            dist_params_grads)
     return dist_train_program, dist_startup_prog
......
@@ -183,7 +183,7 @@ def get_dist_prog(train_program,
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
@@ -277,7 +277,7 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 0
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, 0)
         op_need_check = None
@@ -306,11 +306,12 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -326,11 +327,12 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id, True)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
         print_program_with_dist_attr(dist_main_prog, dist_context)
 
         # check send and recv result
@@ -347,9 +349,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
        dist_context = DistributedContext()
         rank_id = 0
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # send and recv should not exist in dp scene.
         self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -137,7 +137,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_send_recv_result(dist_main_prog, rank_id):
@@ -177,9 +177,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 2
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
         # print_program_with_dist_attr(dist_main_prog, dist_context)
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -152,7 +152,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
 
-    return auto_parallel_main_prog, auto_parallel_startup_prog
+    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
 
 
 def check_send_recv_result(dist_main_prog, rank_id):
@@ -211,9 +211,10 @@ class TestMLPReshard(unittest.TestCase):
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 2
-        dist_main_prog, dist_startup_prog = get_dist_prog(
+        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+                dist_params_grads)
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -271,7 +272,7 @@ class TestMLPReshard(unittest.TestCase):
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
-                dist_context)
+                dist_context, partitioned_params_grads)
 
         # the x should not be slice
         self.assertTrue(check_allgather(partitioned_main_prog))
......