# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import paddle import paddle.distributed as dist import paddle.distributed.communication.stream as stream import paddle.framework as framework from .serialization_utils import ( convert_object_to_tensor, convert_tensor_to_object, ) def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True): """ Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely. .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png :width: 800 :alt: scatter :align: center Args: tensor (Tensor): The output Tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None. src (int): The source rank id. Default value is 0. group (Group, optional): The group instance return by new_group or None for global default group. sync_op (bool, optional): Whether this op is a sync op. The default value is True. Returns: None. Examples: .. code-block:: python # required: distributed import paddle import paddle.distributed as dist dist.init_parallel_env() if dist.get_rank() == 0: data1 = paddle.to_tensor([7, 8, 9]) data2 = paddle.to_tensor([10, 11, 12]) dist.scatter(data1, src=1) else: data1 = paddle.to_tensor([1, 2, 3]) data2 = paddle.to_tensor([4, 5, 6]) dist.scatter(data1, tensor_list=[data1, data2], src=1) print(data1, data2) # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0) # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1) """ return stream.scatter(tensor, tensor_list, src, group, sync_op) def scatter_object_list( out_object_list, in_object_list=None, src=0, group=None ): """ Scatter picklable objects from the source to all others. Similiar to scatter(), but python object can be passed in. Args: out_object_list (list): The list of objects to store the scattered objects. in_object_list (list): The list of objects to scatter. Only objects on the src rank will be scattered. src (int): The source rank in global view. group (Group): The group instance return by new_group or None for global default group. Returns: None. Warning: This API only supports the dygraph mode. Examples: .. code-block:: python # required: distributed import paddle.distributed as dist dist.init_parallel_env() out_object_list = [] if dist.get_rank() == 0: in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}] else: in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}] dist.scatter_object_list(out_object_list, in_object_list, src=1) print(out_object_list) # [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0) # [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1) """ assert ( framework.in_dygraph_mode() ), "scatter_object_list doesn't support static graph mode." rank = dist.get_rank() in_obj_tensors = [] in_obj_sizes = [] if rank == src: for obj in in_object_list: obj_tensor, obj_size = convert_object_to_tensor(obj) in_obj_tensors.append(obj_tensor) in_obj_sizes.append(obj_size) max_obj_size_tensor = max(in_obj_sizes) else: # NOTE: shape can be [] after 0D tensor support max_obj_size_tensor = paddle.empty([1], dtype="int64") stream.broadcast(max_obj_size_tensor, src) max_obj_size = int(max_obj_size_tensor.item()) # resize to the same size in_tensor_list = [] for tensor in in_obj_tensors: numpy_data = tensor.numpy() numpy_data = np.resize(numpy_data, [max_obj_size]) in_tensor = paddle.to_tensor(numpy_data) in_tensor_list.append(in_tensor) out_tensor = paddle.empty([max_obj_size], dtype="uint8") scatter(out_tensor, in_tensor_list if rank == src else None, src) # NOTE: shape can be [] after 0D tensor support out_tensor_size = paddle.empty([1], dtype="int64") scatter(out_tensor_size, in_obj_sizes if rank == src else None, src) out_object_list.clear() out_object_list.append( convert_tensor_to_object(out_tensor, out_tensor_size.item()) )