未验证 提交 d3ba1895 编写于 作者: C caozhou 提交者: GitHub

【Auto Parallel】New local tensor (#38747)

* update dist tensor

* add unitest

* update unitest

* refactor dist tensor

* update dist tensor and unitest
上级 fbb40281
...@@ -62,6 +62,10 @@ class DistributedContext: ...@@ -62,6 +62,10 @@ class DistributedContext:
self._dist_op_context = DistributedOperatorContext() self._dist_op_context = DistributedOperatorContext()
self._process_meshes = [] self._process_meshes = []
# Distributed programs
self._dist_main_programs = {}
self._dist_startup_programs = {}
@property @property
def serial_program(self): def serial_program(self):
return self._serial_program return self._serial_program
...@@ -84,6 +88,14 @@ class DistributedContext: ...@@ -84,6 +88,14 @@ class DistributedContext:
def dist_op_context(self): def dist_op_context(self):
return self._dist_op_context return self._dist_op_context
@property
def dist_main_programs(self):
return self._dist_main_programs
@property
def dist_startup_programs(self):
return self._dist_startup_programs
def add_process_mesh(self, process_mesh): def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \ assert isinstance(process_mesh, ProcessMesh), \
'The type of dim_mapping must be ProcessMesh.' 'The type of dim_mapping must be ProcessMesh.'
...@@ -371,10 +383,14 @@ class DistributedContext: ...@@ -371,10 +383,14 @@ class DistributedContext:
result = cls.__new__(cls) result = cls.__new__(cls)
memo[id(self)] = result memo[id(self)] = result
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph": if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
setattr(result, k, v) setattr(result, k, v)
else: else:
setattr(result, k, copy.deepcopy(v, memo)) setattr(result, k, copy.deepcopy(v, memo))
# update dist tensor's dist_context
for key in result._dist_tensors_for_program.keys():
result._dist_tensors_for_program[key]._dist_context = result
return result return result
......
...@@ -13,18 +13,155 @@ ...@@ -13,18 +13,155 @@
# limitations under the License # limitations under the License
import copy import copy
import inspect
import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Parameter, Block, Variable
from .dist_attribute import TensorDistributedAttribute from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys from .dist_attribute import get_tensor_dist_attr_field_keys
from .utils import _linear_idx2coordinate
class DistributedTensor: class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None): """
DistributedTensor represents the distribution of tensor on the process group and
local tensors can be created by DistributedTensor.
Only support even sharding now and uneven sharding will be supported in the future.
Local tensor information can be obtained from the DistributedTensor instance object,
or obtained by the static methods provided by DistributedTensor,
including shard (i.e. the index in the serial tensor), offsets, and sizes.
"""
@staticmethod
def _validate_sizes_and_dist_attr(sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, sizes))):
raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".
format(sizes))
if not (isinstance(dims_mapping, (list, tuple)) and all(
map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))):
raise ValueError(
"The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}".
format(dims_mapping))
if not (isinstance(processes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x >= 0, processes))):
raise ValueError(
"The processes must be list or tuple and item in processes must be integer, but got {}".
format(processes))
if not (isinstance(topology, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, topology))):
raise ValueError(
"The topology must be list or tuple and item in topology must be non-negative integer, but got {}".
format(topology))
if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank))
# NOTE: Only support even sharding now
if shard_sizes is not None:
raise ValueError("Only support even sharding now.")
@staticmethod
def get_local_sizes(global_sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_sizes = []
# for even sharding, the local sizes of every rank are equal
for idx, item in enumerate(global_sizes):
if dims_mapping[idx] == -1:
local_sizes.append(item)
else:
local_sizes.append(item // topology[dims_mapping[idx]])
return local_sizes
@staticmethod
def get_local_offsets(global_sizes,
dims_mapping,
topology,
processes,
rank,
shard_sizes=None):
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_offsets = []
rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(topology, rank_relatvie)
for i in range(len(global_sizes)):
if dims_mapping[i] == -1:
local_offsets.append(0)
else:
local_offsets.append(coordinate[dims_mapping[i]] *
local_sizes[i])
return local_offsets
@staticmethod
def get_global_sizes(local_sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr(
local_sizes, dims_mapping, topology, processes, rank, shard_sizes)
global_sizes = []
for idx, item in enumerate(local_sizes):
if dims_mapping[idx] == -1:
global_sizes.append(item)
else:
global_sizes.append(item * topology[dims_mapping[idx]])
return global_sizes
@staticmethod
def get_local_shard(global_sizes,
dims_mapping,
topology,
processes,
rank,
shard_sizes=None):
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
assert len(local_sizes) == len(
local_offsets
), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format(
len(local_sizes), len(local_offsets))
local_end_offsets = list(
map(lambda x: x[0] + x[1], zip(local_offsets, local_sizes)))
local_shard = list(zip(local_offsets, local_end_offsets))
return local_shard
def __init__(self, serial_tensor, dist_attr=None, dist_context=None):
self._serial_tensor = serial_tensor self._serial_tensor = serial_tensor
self._dist_attr = None self._dist_attr = None
self._batch_dim = 0 self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr # Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr self.dist_attr = dist_attr
self._local_sizes_map = {}
self._local_offsets_map = {}
self._local_shard_map = {}
self._local_tensor_map = {}
from .dist_context import get_default_distributed_context
self._dist_context = dist_context if dist_context is not None else get_default_distributed_context(
)
# TODO: Add Automatically to dist_context after initialized and it will be adapted in the future.
# self._dist_context.add_dist_tensor_for_program(self)
@property @property
def serial_tensor(self): def serial_tensor(self):
...@@ -34,6 +171,10 @@ class DistributedTensor: ...@@ -34,6 +171,10 @@ class DistributedTensor:
def dist_attr(self): def dist_attr(self):
return self._dist_attr return self._dist_attr
@property
def dist_context(self):
return self._dist_context
@dist_attr.setter @dist_attr.setter
def dist_attr(self, dist_attr): def dist_attr(self, dist_attr):
if self._dist_attr is None: if self._dist_attr is None:
...@@ -66,12 +207,150 @@ class DistributedTensor: ...@@ -66,12 +207,150 @@ class DistributedTensor:
return False return False
return True return True
def local_sizes(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_sizes = None
if rank in self._local_sizes_map.keys():
local_sizes = self._local_sizes_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_sizes_map[rank] = local_sizes
return local_sizes
def local_offsets(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_offsets = None
if rank in self._local_offsets_map.keys():
local_offsets = self._local_offsets_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_offsets_map[rank] = local_offsets
return local_offsets
def global_sizes(self):
return self.serial_tensor.shape
def local_shard(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_shard = None
if rank in self._local_shard_map.keys():
local_shard = self._local_shard_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_shard = DistributedTensor.get_local_shard(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_shard_map[rank] = local_shard
return local_shard
def new_local_tensor(self, block=None, rank=None, name=None):
"""
Create a new local tensor of serial tensor corresponding to rank.
Args:
block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None.
rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None.
"""
def _copy_kwargs(serial_tensor):
kwargs = {}
no_need_copy_args = ["self", "block", "shape", "name"]
arg_spec = inspect.getargspec(Variable.__init__)
for key in arg_spec.args:
# TODO: Check the copied attribute from serial tensor whether valid
if key in no_need_copy_args:
continue
elif key not in kwargs:
if key == "type":
kwargs[key] = serial_tensor.desc.type()
elif key == "dtype":
kwargs[key] = serial_tensor.desc.dtype()
elif key == "lod_level":
kwargs[key] = serial_tensor.desc.lod_level()
elif key == "persistable":
kwargs[key] = serial_tensor.desc.persistable()
elif key == "stop_gradient":
kwargs[key] = serial_tensor.desc.stop_gradient()
elif key == "need_check_feed":
kwargs[key] = serial_tensor.desc.need_check_feed()
# TODO: Get capacity by framework
elif key == "capacity":
continue
else:
kwargs[key] = self.serial_tensor.__dict__[key]
if isinstance(serial_tensor, Parameter):
kwargs["trainable"] = serial_tensor.trainable
kwargs["optimize_attr"] = serial_tensor.trainable
kwargs["regularizer"] = serial_tensor.regularizer
kwargs["do_model_average"] = serial_tensor.do_model_average
kwargs["need_clip"] = serial_tensor.need_clip
kwargs["is_distributed"] = serial_tensor.is_distributed
kwargs["is_parameter"] = serial_tensor.is_parameter
return kwargs
if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank))
if block is not None and not isinstance(block, Block):
raise TypeError("The block must be Block, but got {}.".format(
type(block)))
rank = paddle.distributed.get_rank() if rank is None else rank
if block is None:
block_id = self.serial_tensor.block.idx
block = self.dist_context.dist_main_programs[rank].block(block_id)
# copy serial tensor attribute
kwargs = _copy_kwargs(self.serial_tensor)
kwargs["name"] = name
kwargs["shape"] = self.local_sizes(rank)
if isinstance(self.serial_tensor, Parameter):
kwargs.pop("persistable")
local_tensor = Parameter(block=block, **kwargs)
else:
local_tensor = block.create_var(**kwargs)
# TODO: Set original id when set original_id is approved
local_tensor.desc.set_original_id(self.serial_tensor.desc.id())
self._local_tensor_map[rank] = local_tensor
return local_tensor
def local_tensor(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
assert rank in self._local_tensor_map, "The rank {} local tensor has not been created.".format(
rank)
return self._local_tensor_map[rank]
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
cls = self.__class__ cls = self.__class__
result = cls.__new__(cls) result = cls.__new__(cls)
memo[id(self)] = result memo[id(self)] = result
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k == "_serial_tensor": if k == "_serial_tensor" or k == "_local_tensor_map":
setattr(result, k, v) setattr(result, k, v)
else: else:
setattr(result, k, copy.deepcopy(v, memo)) setattr(result, k, copy.deepcopy(v, memo))
......
...@@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) ...@@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_dist_tensor)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
...@@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_dist_tensor)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
...@@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE) ...@@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_dist_tensor MODULES test_auto_parallel_dist_tensor ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import paddle
from paddle.fluid import core
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
import test_auto_parallel_reshard
from test_auto_parallel_reshard import mlp_forward
def get_dist_prog(train_program,
startup_program,
dist_context,
rank_id,
complete_train_program=None):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(
train_program, dist_context
) if complete_train_program is None else complete_train_program
# parallelizer._apply_serial_forward_pass(complete_train_program,
# startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program
class TestDistributedTensor(unittest.TestCase):
def test_new_local_tensor(self):
test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
mesh=[0, 1])
test_auto_parallel_reshard._global_parallel_strategy = "dp"
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 0
dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
dist_context.dist_main_programs[rank_id] = dist_main_prog
dist_context.dist_startup_programs[rank_id] = dist_startup_prog
name = "layer_norm_1.tmp_2"
dist_tensor = dist_context.get_dist_tensor_for_program(
complete_train_program.global_block().vars[name])
dist_tensor._dist_context = dist_context
intermediate_var_0 = dist_tensor.new_local_tensor(
name="intermediate_var_0")
self.assertEqual(intermediate_var_0.shape, (2, 1024))
self.assertEqual(intermediate_var_0.name, "intermediate_var_0")
rank_id = 1
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_main_prog, dist_startup_prog, _ = get_dist_prog(
train_program, startup_program, dist_context, rank_id,
complete_train_program)
dist_context.dist_main_programs[rank_id] = dist_main_prog
dist_context.dist_startup_programs[rank_id] = dist_startup_prog
name = "layer_norm_1.tmp_2"
dist_tensor = dist_context.get_dist_tensor_for_program(
complete_train_program.global_block().vars[name])
dist_tensor._dist_context = dist_context
intermediate_var_1 = dist_tensor.new_local_tensor(
rank=rank_id, name="intermediate_var_1")
self.assertEqual(intermediate_var_0.shape, (2, 1024))
self.assertEqual(intermediate_var_1.name, "intermediate_var_1")
name = "linear_0.w_0"
dist_tensor = dist_context.get_dist_tensor_for_program(
complete_train_program.global_block().vars[name])
dist_tensor._dist_context = dist_context
intermediate_var_1 = dist_tensor.new_local_tensor(
rank=rank_id, name="linear_0.w_0_intermediate")
self.assertEqual(intermediate_var_1.shape, (1024, 4096))
self.assertEqual(intermediate_var_1.name, "linear_0.w_0_intermediate")
copied_dist_context = copy.deepcopy(dist_context)
self.assertIsNotNone(copied_dist_context)
self.assertEqual(
id(copied_dist_context),
id(
copied_dist_context.get_dist_tensor_for_program(
dist_tensor.serial_tensor).dist_context))
def test_static_method(self):
dims_mapping = [1, 0]
processes = [0, 1, 2, 3, 4, 5, 6]
topology = [2, 3]
global_sizes = [6, 6]
# rank 0 [(0, 2), (0, 3)]
# rank 1 [(2, 4), (0, 3)]
# rank 4 [(2, 4), (3, 6)]
rank = 0
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes)
self.assertEqual(local_sizes, [2, 3])
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_offsets, [0, 0])
local_shard = DistributedTensor.get_local_shard(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_shard, [(0, 2), (0, 3)])
rank = 1
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes)
self.assertEqual(local_sizes, [2, 3])
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_offsets, [2, 0])
local_shard = DistributedTensor.get_local_shard(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_shard, [(2, 4), (0, 3)])
rank = 4
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes)
self.assertEqual(local_sizes, [2, 3])
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_offsets, [2, 3])
local_shard = DistributedTensor.get_local_shard(
global_sizes, dims_mapping, topology, processes, rank)
self.assertEqual(local_shard, [(2, 4), (3, 6)])
# global sizes
local_sizes = [2, 3]
global_sizes = DistributedTensor.get_global_sizes(
local_sizes, dims_mapping, topology, processes)
self.assertEqual(global_sizes, [6, 6])
def test_instance_method(self):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = [1, 0]
tensor_dist_attr.process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2], [3, 4, 5]])
serial_tensor = paddle.static.data(
name="data", shape=[6, 6], dtype='float32')
dist_tensor = DistributedTensor(serial_tensor, tensor_dist_attr)
# rank 0 [(0, 2), (0, 3)]
# rank 1 [(2, 4), (0, 3)]
# rank 4 [(2, 4), (3, 6)]
rank = 0
local_sizes = dist_tensor.local_sizes(rank)
self.assertEqual(local_sizes, [2, 3])
local_offsets = dist_tensor.local_offsets(rank)
self.assertEqual(local_offsets, [0, 0])
local_shard = dist_tensor.local_shard(rank)
self.assertEqual(local_shard, [(0, 2), (0, 3)])
self.assertEqual(local_sizes, dist_tensor.local_sizes(rank))
self.assertEqual(local_offsets, dist_tensor.local_offsets(rank))
self.assertEqual(local_shard, dist_tensor.local_shard(rank))
self.assertEqual(local_sizes, dist_tensor.local_sizes())
self.assertEqual(local_offsets, dist_tensor.local_offsets())
self.assertEqual(local_shard, dist_tensor.local_shard())
rank = 1
local_sizes = dist_tensor.local_sizes(rank)
self.assertEqual(local_sizes, [2, 3])
local_offsets = dist_tensor.local_offsets(rank)
self.assertEqual(local_offsets, [2, 0])
local_shard = dist_tensor.local_shard(rank)
self.assertEqual(local_shard, [(2, 4), (0, 3)])
rank = 4
local_sizes = dist_tensor.local_sizes(rank)
self.assertEqual(local_sizes, [2, 3])
local_offsets = dist_tensor.local_offsets(rank)
self.assertEqual(local_offsets, [2, 3])
local_shard = dist_tensor.local_shard(rank)
self.assertEqual(local_shard, [(2, 4), (3, 6)])
global_sizes = dist_tensor.global_sizes()
self.assertEqual(global_sizes, (6, 6))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册