Unverified commit 8624f3b1, authored by caozhou, committed by GitHub

[Auto Parallel] Update reshard for auto search (#45002)

* update reshard for auto search

* fix unittest bug

* update dist tensor

* update reshard output

* fix unittest bugs

* merge develop
Parent 236ad4fc
......@@ -276,8 +276,8 @@ class OperatorDistributedAttribute:
         dist_attr_object.init(dist_attr)
         self._inputs_dist_attrs[name] = dist_attr_object
 
-    # def del_input_dist_attr(self, name):
-    #     del self._inputs_dist_attrs[name]
+    def del_input_dist_attr(self, name):
+        del self._inputs_dist_attrs[name]
 
     def get_output_dist_attr(self, name):
         return self._outputs_dist_attrs.get(name, None)
......@@ -287,8 +287,8 @@ class OperatorDistributedAttribute:
         dist_attr_object.init(dist_attr)
         self._outputs_dist_attrs[name] = dist_attr_object
 
-    # def del_output_dist_attr(self, name):
-    #     del self._inputs_dist_attrs[name]
+    def del_output_dist_attr(self, name):
+        del self._outputs_dist_attrs[name]
 
     def get_input_dims_mapping(self, name):
         input_dist_attr = self.get_input_dist_attr(name)
......
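The two deletion helpers enabled above complete the get/set pair on OperatorDistributedAttribute. A minimal usage sketch, assuming setter and getter names symmetric to the ones visible in these hunks; op_dist_attr and out_dist_attr are illustrative variables, not taken from the patch:

    # Drop a stale output dist_attr after an op has been rewritten, so later
    # passes no longer see an attribute for an output that is gone.
    op_dist_attr.set_output_dist_attr("Out", out_dist_attr)
    op_dist_attr.del_output_dist_attr("Out")
    assert op_dist_attr.get_output_dist_attr("Out") is None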
......@@ -163,7 +163,6 @@ class DistributedTensor:
         self._batch_dim = 0
         # Reuse the dist_attr setter to initialize _dist_attr
         self.dist_attr = dist_attr
-        self._local_sizes_map = {}
         self._local_offsets_map = {}
         self._local_shard_map = {}
         self._local_tensor_map = {}
......@@ -223,20 +222,17 @@ class DistributedTensor:
         return True
 
     def local_sizes(self, rank=None):
         """Get local sizes of the given rank."""
         rank = paddle.distributed.get_rank() if rank is None else rank
-        local_sizes = None
-        if rank in self._local_sizes_map.keys():
-            local_sizes = self._local_sizes_map[rank]
-        else:
-            global_sizes = self.serial_tensor.shape
-            dims_mapping = self.dist_attr.dims_mapping
-            shard_sizes = self.dist_attr.shard_sizes
-            processes = self.dist_attr.process_mesh.processes
-            topology = self.dist_attr.process_mesh.topology
-            local_sizes = DistributedTensor.get_local_sizes(
-                global_sizes, dims_mapping, topology, processes, rank,
-                shard_sizes)
-            self._local_sizes_map[rank] = local_sizes
+        global_sizes = self.serial_tensor.shape
+        dims_mapping = self.dist_attr.dims_mapping
+        shard_sizes = self.dist_attr.shard_sizes
+        processes = self.dist_attr.process_mesh.processes
+        topology = self.dist_attr.process_mesh.topology
+        local_sizes = DistributedTensor.get_local_sizes(global_sizes,
+                                                        dims_mapping, topology,
+                                                        processes, rank,
+                                                        shard_sizes)
 
         return local_sizes
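With the per-rank cache gone, local_sizes recomputes the shard shape on every call from the serial tensor's global shape, the dims_mapping, and the process-mesh topology. As a rough mental model, each tensor dimension mapped to a mesh axis is split evenly across that axis, and a -1 entry in dims_mapping means the dimension is replicated. A simplified sketch of that rule, not the real DistributedTensor.get_local_sizes, which also honors shard_sizes and the process list:

    # Simplified stand-in for get_local_sizes: even sharding only, no
    # shard_sizes override; rank is irrelevant because even shards are equal.
    def naive_local_sizes(global_sizes, dims_mapping, topology):
        local_sizes = []
        for dim_size, mesh_axis in zip(global_sizes, dims_mapping):
            if mesh_axis == -1:
                # not sharded: every rank holds the full dimension
                local_sizes.append(dim_size)
            else:
                # sharded: split evenly across the mapped mesh axis
                local_sizes.append(dim_size // topology[mesh_axis])
        return local_sizes

    # Example: a [48, 32] tensor on a 2 x 3 mesh with dim 0 mapped to mesh axis 1:
    # naive_local_sizes([48, 32], [1, -1], [2, 3]) -> [16, 32]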
......@@ -282,7 +278,6 @@ class DistributedTensor:
     def new_local_tensor(self, block=None, rank=None, name=None):
         """
         Create a new local tensor of serial tensor corresponding to rank.
-
         Args:
             block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None.
             rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None.
......
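A minimal hypothetical call relying on the defaults described in that docstring; dist_tensor stands in for a DistributedTensor instance and is not a name from the patch:

    # Let block, rank and name all fall back to their documented defaults: the
    # new local tensor is created for the current rank in the matching block of
    # the dist main program.
    local_tensor = dist_tensor.new_local_tensor()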
......@@ -1419,7 +1419,10 @@ def get_standalone_cost_data(distributed_programs):
     }
 
     standalone_cost_data = []
-    not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    # skip ops
+    not_enum_ops = [
+        "create_py_reader", "create_double_buffer_reader", "read", "assign"
+    ]
     for distributed_program in distributed_programs:
         cost_data = {}
         vars = distributed_program.global_block().vars
......
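The intent of not_enum_ops, and of adding "assign" to it, is that reader and bookkeeping ops contribute no meaningful compute time, so the standalone cost pass should not try to estimate them. A hypothetical sketch of that filtering, with illustrative key and value choices rather than the function's real bookkeeping:

    # Hypothetical filtering loop: ops listed in not_enum_ops are assigned zero
    # cost instead of being profiled or estimated.
    for op in distributed_program.global_block().ops:
        if op.type in not_enum_ops:
            cost_data[op.desc.id()] = 0.0   # illustrative: treat as free
            continue
        # ... estimate or measure the runtime of every other op here ...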
......@@ -27,7 +27,10 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
 OpRole = core.op_proto_and_checker_maker.OpRole
 OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
 
-_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice']
+_skip_ops = [
+    'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
+    'assign'
+]
 # update here to support new optimizers
 _supported_optimizer_type = [
     "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
......
......@@ -107,7 +107,9 @@ class TestDataUnshard(unittest.TestCase):
         input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
         label_data = np.random.randint(0, 10, [2, 8]).astype("float32")
 
-        fetchs = [loss.name, 'input@RESHARD_0']
+        fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [
+            loss.name, 'split@RESHARD.tmp_1'
+        ]
         loss_np, shard_data_np = exe.run(distributed_main_program,
                                          feed={
                                              "input": input_data,
......
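The fetch list now names the reshard-produced variables directly ('split@RESHARD.tmp_0' or 'tmp_1' depending on the worker), replacing the old 'input@RESHARD_0' convention. Purely as a sketch, a name-agnostic alternative that leans only on the '@RESHARD' marker rather than a hard-coded temporary suffix; this is not what the patch does, and the variable names are illustrative:

    # Fetch the loss plus every variable the reshard pass inserted, found by
    # its '@RESHARD' marker instead of a hard-coded temporary name.
    reshard_var_names = [
        name for name in distributed_main_program.global_block().vars
        if '@RESHARD' in name
    ]
    fetchs = [loss.name] + reshard_var_names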
......@@ -28,7 +28,7 @@ from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.reshard import Resharder
-from paddle.distributed.auto_parallel.process_group import _g_process_group_map
+from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 
 paddle.enable_static()
......@@ -307,6 +307,7 @@ class TestMLPReshard(unittest.TestCase):
             train_program, startup_program, dist_context, rank_id)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
+        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()
......@@ -326,10 +327,10 @@ class TestMLPReshard(unittest.TestCase):
             train_program, startup_program, dist_context, rank_id, True)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
+        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))
......
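Both reshard test cases now re-seed the cleared group map with ProcessGroup(0, []) before building the Resharder, which appears to restore the module's initial state in which group id 0 is reserved for the default group, so groups created during reshard get fresh ids. A sketch of that reset as a small helper; the helper itself is hypothetical, while the two statements inside it come straight from the patch:

    # Hypothetical test helper capturing the reset pattern used above.
    def reset_process_group_map():
        # Drop any process groups left over from a previous test case ...
        for key in list(_g_process_group_map.keys()):
            del _g_process_group_map[key]
        # ... then re-seed slot 0 with an empty default group before the
        # Resharder registers new groups for its send/recv communication.
        _g_process_group_map[0] = ProcessGroup(0, [])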