diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index e27ae8ef6c4a10a5bd5baf8d53c976482442eb90..59cf3f782d00e47cc13a245730598ef27d969044 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -1032,12 +1032,12 @@ def _convert_object_to_tensor(obj):
     _pickler(f).dump(obj)
     data = np.frombuffer(f.getvalue(), dtype=np.uint8)
     tensor = paddle.to_tensor(data)
-    return tensor
+    return tensor, tensor.numel()
 
 
-def _convert_tensor_to_object(tensor):
+def _convert_tensor_to_object(tensor, len_of_tensor):
     _unpickler = pickle.Unpickler
-    return _unpickler(io.BytesIO(tensor.numpy())).load()
+    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
 
 
 def all_gather_object(object_list, obj, group=None):
@@ -1076,12 +1076,25 @@ def all_gather_object(object_list, obj, group=None):
     assert in_dygraph_mode(
     ), "all_gather_object doesn't support static graph mode."
 
-    tensor = _convert_object_to_tensor(obj)
+    tensor, len_of_tensor = _convert_object_to_tensor(obj)
+
+    # gather the pickled length (len_of_tensor) from every rank
+    list_len_of_tensor = []
+    all_gather(list_len_of_tensor, len_of_tensor, group)
+    # take the largest length found in the gathered list
+    max_len_of_tensor = int(max(list_len_of_tensor).item())
+    # resize the input tensor to the max length to avoid a hang in all_gather
+    # Note(liyurui): Maybe we should support variable-length all_gather?
+    # This numpy round trip is likely inefficient; resize isn't supported in python.
+    numpy_data = tensor.numpy()
+    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
+    input_tensor = paddle.to_tensor(numpy_data)
 
     tensor_list = []
-    all_gather(tensor_list, tensor, group)
-    for tensor in tensor_list:
-        object_list.append(_convert_tensor_to_object(tensor))
+    all_gather(tensor_list, input_tensor, group)
+    for i, tensor in enumerate(tensor_list):
+        object_list.append(
+            _convert_tensor_to_object(tensor, list_len_of_tensor[i]))
 
 
 def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
index 79457571aca90f6d0e879c55e9550e639df3e947..e52da771b8957bcbcfc23bad37635dcf7026e419 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
@@ -63,6 +63,8 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
 def create_pylist_test_data(shape=None, seed=None):
     if seed:
         np.random.seed(seed)
+    # Generate a random two-element shape (each dim in [0, 100)) for the xxx_object APIs
+    shape = np.random.randint(0, high=100, size=(2)).tolist()
     data = np.random.random(shape).tolist()
     return data
 