未验证 提交 e48cb42b 编写于 作者: L LiYuRio 提交者: GitHub

fix all_gather_object with various length, test=allcases (#44718)

上级 3e8708bc
......@@ -1032,12 +1032,12 @@ def _convert_object_to_tensor(obj):
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor
return tensor, tensor.numel()
def _convert_tensor_to_object(tensor, len_of_tensor=None):
    """Unpickle the Python object stored as bytes in ``tensor``.

    Args:
        tensor: Object exposing ``numpy()`` that returns a 1-D uint8 buffer
            holding a pickled payload, possibly followed by padding.
        len_of_tensor: Number of leading bytes that belong to the pickled
            payload. ``None`` (the default) uses the whole buffer, which
            keeps the earlier one-argument call form working.

    Returns:
        The object reconstructed by ``pickle.Unpickler``.
    """
    # NOTE(review): pickle executes arbitrary code while loading; this is only
    # safe when every buffer comes from a trusted source -- confirm at callers.
    _unpickler = pickle.Unpickler
    # Drop any trailing padding so only the real payload is handed to pickle;
    # slicing with None keeps the full buffer.
    payload = tensor.numpy()[:len_of_tensor]
    return _unpickler(io.BytesIO(payload)).load()
def all_gather_object(object_list, obj, group=None):
......@@ -1076,12 +1076,25 @@ def all_gather_object(object_list, obj, group=None):
assert in_dygraph_mode(
), "all_gather_object doesn't support static graph mode."
tensor = _convert_object_to_tensor(obj)
tensor, len_of_tensor = _convert_object_to_tensor(obj)
# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# Resize the input tensor to the max length to avoid a hang in all_gather.
# Note(liyurui): Maybe we should support variable-length all_gather?
# For now this operation is inefficient because resize is not supported in Python, so the data round-trips through NumPy.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)
tensor_list = []
all_gather(tensor_list, tensor, group)
for tensor in tensor_list:
object_list.append(_convert_tensor_to_object(tensor))
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i]))
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
......
......@@ -63,6 +63,8 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
def create_pylist_test_data(shape=None, seed=None):
    """Return a randomly shaped 2-D nested list of random floats.

    Args:
        shape: Accepted for signature compatibility; the value is overwritten
            by a randomly drawn shape and therefore has no effect.
        seed: Optional seed for NumPy's global random state. Any value other
            than ``None`` (including ``0``) makes the result reproducible.

    Returns:
        A nested Python list produced by ``np.random.random(shape).tolist()``,
        where each of the two dimensions is drawn from ``[0, 100)``; a
        dimension of 0 yields empty data.
    """
    # Compare against None rather than truthiness so that seed=0, a valid
    # seed, is not silently ignored. This reseeds NumPy's global RNG.
    if seed is not None:
        np.random.seed(seed)
    # Generate random shape test case for xxx_object api
    shape = np.random.randint(0, high=100, size=2).tolist()
    return np.random.random(shape).tolist()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册