Unverified commit cac6f408 authored by caozhou, committed by GitHub

[Auto Parallel] Update dist param grad for pass (#38941)

* update dist param grad for pass

* update unit test

* update unit tests

* fix conflict
Parent f080e8d5
@@ -209,7 +209,8 @@ class AutoParallelizer:
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
-reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
+dist_params_grads)
self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
......
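For orientation: dist_params_grads is the list of (parameter, gradient) variable pairs produced during partitioning. The hunk above now threads that same list object through reshard before handing it to the post-optimization passes, so any remapping reshard performs in place is visible downstream without an extra return value. A minimal plain-Python sketch of that in-place contract (the Var stand-in and the @RESHARD suffix are invented for illustration; the real entries are fluid Variables):

```python
from collections import namedtuple

Var = namedtuple("Var", ["name"])  # stand-in for a fluid Variable

# One (param, grad) pair, as collected during partitioning.
dist_params_grads = [(Var("fc_0.w_0"), Var("fc_0.w_0@GRAD"))]


def fake_reshard(params_grads):
    # Mimic reshard's in-place update: re-point a gradient that was renamed.
    params_grads[0] = (params_grads[0][0], Var("fc_0.w_0@GRAD@RESHARD"))


fake_reshard(dist_params_grads)
# The caller (here, the parallelizer) sees the updated pair afterwards.
assert dist_params_grads[0][1].name == "fc_0.w_0@GRAD@RESHARD"
```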
@@ -20,6 +20,7 @@ import paddle.fluid.core as core
from paddle.utils import unique_name
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Program, OpProtoHolder
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
import paddle.fluid.layers.utils as utils
from ..collective import _get_global_env
from .dist_context import DistributedContext
@@ -862,7 +863,7 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
block._remove_op(idx)
-def _remove_no_need_vars(auto_parallel_main_prog):
+def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
"""Remove no need vars in the main program"""
remove_vars = set()
block = auto_parallel_main_prog.global_block()
@@ -879,14 +880,42 @@ def _remove_no_need_vars(auto_parallel_main_prog):
for var in vars:
if var not in need_vars:
remove_vars.add(var)
+# change dist_params_grads
+param_grad_map = {}
+for op in ops:
+    if int(op.attr('op_role')) == int(OpRole.Optimize):
+        if "Param" in op.input_names and "Grad" in op.input_names:
+            param_name = op.input("Param")[0]
+            grad_name = op.input("Grad")[0]
+            param_grad_map[param_name] = grad_name
+need_remove_idx = []
+for idx, item in enumerate(dist_params_grads):
+    if item[0].name not in param_grad_map.keys():
+        need_remove_idx.append(idx)
+for idx in need_remove_idx[::-1]:
+    dist_params_grads.pop(idx)
+idx = 0
+while idx < len(dist_params_grads):
+    param_name = dist_params_grads[idx][0].name
+    grad_name = dist_params_grads[idx][1].name
+    if grad_name != param_grad_map[param_name]:
+        dist_params_grads[idx] = (vars[param_name],
+                                  vars[param_grad_map[param_name]])
+    idx += 1
for var in remove_vars:
block._remove_var(var)
-def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id):
+def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+dist_params_grads):
"""Remove no need vars and ops in the main program."""
_remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
-_remove_no_need_vars(auto_parallel_main_prog)
+_remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
def remove_no_need_in_startup(auto_parallel_main_prog,
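The block added to _remove_no_need_vars above is the core of this change: after reshard has rewritten the program, it rebuilds the Param-to-Grad mapping from the optimizer ops that remain on this rank, drops (param, grad) pairs whose parameter no longer reaches any optimizer op, and re-points pairs whose gradient was renamed. Below is a self-contained sketch of that logic, with plain dicts and namedtuples standing in for Paddle ops and Variables (every name in the demo is invented):

```python
from collections import namedtuple

Var = namedtuple("Var", ["name"])  # stand-in for a fluid Variable


def update_params_grads(optimizer_ops, all_vars, dist_params_grads):
    # 1) Param -> Grad names actually consumed by optimizer ops on this rank.
    param_grad_map = {op["Param"]: op["Grad"] for op in optimizer_ops}
    # 2) Drop pairs whose parameter no longer feeds any optimizer op.
    dist_params_grads[:] = [(p, g) for p, g in dist_params_grads
                            if p.name in param_grad_map]
    # 3) Re-point pairs whose gradient was renamed during reshard to the
    #    variables the optimizer really reads.
    for i, (p, g) in enumerate(dist_params_grads):
        if g.name != param_grad_map[p.name]:
            dist_params_grads[i] = (all_vars[p.name],
                                    all_vars[param_grad_map[p.name]])


# Demo: w_1's gradient was renamed, and w_2 no longer reaches the optimizer.
all_vars = {n: Var(n) for n in
            ("w_1", "w_1@GRAD", "w_1@GRAD@NEW", "w_2", "w_2@GRAD")}
optimizer_ops = [{"Param": "w_1", "Grad": "w_1@GRAD@NEW"}]
params_grads = [(all_vars["w_1"], all_vars["w_1@GRAD"]),
                (all_vars["w_2"], all_vars["w_2@GRAD"])]
update_params_grads(optimizer_ops, all_vars, params_grads)
assert params_grads == [(all_vars["w_1"], all_vars["w_1@GRAD@NEW"])]
```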
@@ -964,7 +993,7 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
-dist_context):
+dist_context, dist_params_grads):
"""
Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
@@ -973,6 +1002,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
auto_parallel_startup_prog (Program): An auto parallel startup program.
rank_id (int): The process id.
dist_context (DistributedContext): The distributed context of this rank.
+dist_params_grads (list): The list contains the tuple of param and grad.
"""
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
@@ -1049,7 +1079,8 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
idx += 1
# remove no need vars and ops in the main program
-remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
+remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
+dist_params_grads)
# remove no need vars and ops in the startip program
remove_no_need_in_startup(auto_parallel_main_prog,
......
@@ -175,7 +175,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
-return auto_parallel_main_prog, auto_parallel_startup_prog
+return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def check_runtime_estimation(cost):
@@ -229,10 +229,10 @@ class TestCostModel(unittest.TestCase):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
-distributed_program, dist_startup_prog = get_dist_prog(
+distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
reshard(distributed_program, dist_startup_prog, rank_id,
-dist_context)
+dist_context, dist_params_grads)
dist_program.append(distributed_program)
cluster = None
cost = estimate_cost(
......
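The unit-test updates in this commit all follow one pattern: the local get_dist_prog helper now returns dist_params_grads as a third value, and every reshard call forwards it. A condensed sketch of that pattern (get_dist_prog is each test file's own helper, so it is passed in rather than imported; the import paths are those of the modules touched in this commit):

```python
import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.reshard import reshard


def build_resharded_prog(get_dist_prog, rank_id):
    # Mirrors the updated test flow; get_dist_prog is the test module's helper.
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    # The helper now also hands back the (param, grad) list ...
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    # ... which reshard needs so it can drop and remap stale gradients in place.
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
            dist_params_grads)
    return dist_main_prog, dist_startup_prog, dist_params_grads
```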
@@ -502,7 +502,8 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads)
-reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
+reshard(dist_train_program, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
return dist_train_program, dist_startup_prog
......
@@ -183,7 +183,7 @@ def get_dist_prog(train_program,
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
-return auto_parallel_main_prog, auto_parallel_startup_prog
+return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
@@ -277,7 +277,7 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 0
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, 0)
op_need_check = None
@@ -306,11 +306,12 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 1
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
-reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -326,11 +327,12 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 1
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
-reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
@@ -347,9 +349,10 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 0
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
-reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
# send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -137,7 +137,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
-return auto_parallel_main_prog, auto_parallel_startup_prog
+return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def check_send_recv_result(dist_main_prog, rank_id):
@@ -177,9 +177,10 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
-reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -152,7 +152,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
-return auto_parallel_main_prog, auto_parallel_startup_prog
+return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def check_send_recv_result(dist_main_prog, rank_id):
@@ -211,9 +211,10 @@ class TestMLPReshard(unittest.TestCase):
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
-dist_main_prog, dist_startup_prog = get_dist_prog(
+dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
-reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
+dist_params_grads)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -271,7 +272,7 @@ class TestMLPReshard(unittest.TestCase):
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
-dist_context)
+dist_context, partitioned_params_grads)
# the x should not be slice
self.assertTrue(check_allgather(partitioned_main_prog))
......
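The final hunk covers the direct Partitioner path used in one test: partition also returns the per-rank (param, grad) list, which must now be forwarded to reshard. A condensed sketch, assuming the partitioner, the serial programs, and the dist_context are already built as in that test (their construction is omitted here):

```python
from paddle.distributed.auto_parallel.reshard import reshard


def partition_then_reshard(partitioner, complete_train_program,
                           startup_program, dist_context, rank_id):
    # Partitioner.partition also returns the per-rank (param, grad) pairs.
    (partitioned_main_prog, partitioned_startup_prog,
     partitioned_params_grads) = partitioner.partition(
         complete_train_program, startup_program, [])
    # Forward them so reshard can drop/remap entries for this rank in place.
    reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
            dist_context, partitioned_params_grads)
    return (partitioned_main_prog, partitioned_startup_prog,
            partitioned_params_grads)
```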