Unverified commit d4cf02bc authored by LiYuRio, committed by GitHub

Complete the dtypes for all_gather, add all_gather_object api (#44417)

Parent 768e50c9
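Before the per-file hunks, a minimal sketch of what this change enables: `all_gather` accepting the newly registered dtypes and the new `all_gather_object` API. The launch command, device setup and example values below are illustrative assumptions, not part of this diff; it presumes a two-rank dygraph job started with `python -m paddle.distributed.launch`.

```python
# Minimal sketch of the APIs touched by this commit (assumes a 2-rank job).
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# all_gather now covers the newly registered dtypes (bool, uint8, int8, ...).
flags = paddle.to_tensor([rank == 0, rank == 1], dtype='bool')
flag_list = []
dist.all_gather(flag_list, flags)

# all_gather_object gathers arbitrary picklable Python objects.
obj = {"rank": rank, "payload": [1, 2, 3]}
object_list = []
dist.all_gather_object(object_list, obj)
print(object_list)  # every rank ends up with the objects from all ranks
```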
......@@ -79,6 +79,15 @@ namespace distributed {
case experimental::DataType::INT64: \
func<int64_t>(args); \
break; \
case experimental::DataType::INT8: \
func<int8_t>(args); \
break; \
case experimental::DataType::UINT8: \
func<uint8_t>(args); \
break; \
case experimental::DataType::BOOL: \
func<bool>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
......
......@@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather,
ops::CAllGatherOpCPUKernel<double>,
ops::CAllGatherOpCPUKernel<int>,
ops::CAllGatherOpCPUKernel<int64_t>,
ops::CAllGatherOpCPUKernel<uint8_t>,
ops::CAllGatherOpCPUKernel<int8_t>,
ops::CAllGatherOpCPUKernel<bool>,
ops::CAllGatherOpCPUKernel<plat::float16>);
......@@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather,
ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::CAllGatherOpCUDAKernel<int>,
ops::CAllGatherOpCUDAKernel<uint8_t>,
ops::CAllGatherOpCUDAKernel<int8_t>,
ops::CAllGatherOpCUDAKernel<int64_t>,
ops::CAllGatherOpCUDAKernel<bool>,
ops::CAllGatherOpCUDAKernel<plat::float16>);
......@@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclFloat16;
} else if (type == framework::proto::VarType::INT8) {
return ncclInt8;
} else if (type == framework::proto::VarType::UINT8) {
return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
......@@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return ncclInt64;
} else if (type == experimental::DataType::FLOAT16) {
return ncclFloat16;
} else if (type == experimental::DataType::UINT8) {
return ncclUint8;
} else if (type == experimental::DataType::INT8) {
return ncclInt8;
} else if (type == experimental::DataType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == experimental::DataType::BFLOAT16) {
return ncclBfloat16;
......
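NCCL has no native boolean type, so the mapping above transports `BOOL` tensors as unsigned bytes (`ncclUint8`), alongside the new `UINT8` and `INT8` entries. A rough Python-side mirror of that mapping, purely for illustration (the dict and function below are not Paddle APIs; the strings stand in for the `ncclDataType_t` enumerators):

```python
# Illustrative mirror of the C++ ToNCCLDataType additions above.
import paddle

_NCCL_DTYPE = {
    paddle.float32: "ncclFloat",
    paddle.float64: "ncclDouble",
    paddle.int32: "ncclInt",
    paddle.int64: "ncclInt64",
    paddle.float16: "ncclFloat16",
    paddle.int8: "ncclInt8",
    paddle.uint8: "ncclUint8",
    paddle.bool: "ncclUint8",  # no native bool in NCCL; reuse uint8
}

def to_nccl_dtype(dtype):
    try:
        return _NCCL_DTYPE[dtype]
    except KeyError:
        raise TypeError("This datatype doesn't support NCCL communication.")
```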
......@@ -72,5 +72,7 @@ PD_REGISTER_KERNEL(split,
int64_t,
int,
bool,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split,
int64_t,
int,
bool,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -31,6 +31,7 @@ from .collective import broadcast # noqa: F401
from .collective import all_reduce # noqa: F401
from .collective import reduce # noqa: F401
from .collective import all_gather # noqa: F401
from .collective import all_gather_object # noqa: F401
from .collective import scatter # noqa: F401
from .collective import barrier # noqa: F401
from .collective import ReduceOp # noqa: F401
......@@ -71,7 +72,8 @@ __all__ = [ # noqa
"init_parallel_env", "gloo_init_parallel_env", "gloo_barrier",
"gloo_release", "QueueDataset", "split", "CountFilterEntry",
"ShowClickEntry", "get_world_size", "get_group", "all_gather",
"InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce",
"recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry", "ParallelMode",
"is_initialized", "isend", "irecv", "reduce_scatter"
"all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall",
"send", "reduce", "recv", "ReduceOp", "wait", "get_rank",
"ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv",
"reduce_scatter"
]
......@@ -14,6 +14,8 @@
import numpy as np
import os
import pickle
import io
from datetime import timedelta
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable
......@@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
Defaults to True.
......@@ -941,7 +943,6 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
......@@ -949,21 +950,26 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
init_parallel_env()
tensor_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
paddle.distributed.all_gather(tensor_list, data1)
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
paddle.distributed.all_gather(tensor_list, data2)
"""
if group is not None and not group.is_member():
return
def convert_to_complex(list_of_tensor):
list_of_complex = []
for tensor in list_of_tensor:
list_of_complex.append(paddle.as_complex(tensor))
return list_of_complex
is_input_complex = (tensor.dtype == paddle.complex64
or tensor.dtype == paddle.complex128)
if is_input_complex:
tensor = paddle.as_real(tensor)
if in_dygraph_mode():
group = _get_default_group() if group is None else group
if len(tensor_list) == 0:
......@@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
task = group.process_group.all_gather(tensor, out)
task.wait()
tensor_list.clear()
tensor_list.extend(paddle.split(out, group.nranks, 0))
list_of_tensor = paddle.split(out, group.nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
return
ring_id = 0 if group is None else group.id
......@@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
raise ValueError("The type of 'tensor_list' for all_gather "
"should be list.")
for elem in tensor_list:
check_variable_and_dtype(
elem, 'tensor_list',
['float16', 'float32', 'float64', 'int32', 'int64'],
'all_gather')
check_variable_and_dtype(
tensor, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
check_variable_and_dtype(elem, 'tensor_list', [
'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
'int8', 'uint8', 'complex64', 'complex128'
], 'all_gather')
check_variable_and_dtype(tensor, 'tensor', [
'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
'uint8', 'complex64', 'complex128'
], 'all_gather')
helper.append_op(type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
......@@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
'nranks': nranks
})
tensor_list.extend(paddle.split(out, nranks, 0))
list_of_tensor = paddle.split(out, nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor
def _convert_tensor_to_object(tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy())).load()
def all_gather_object(object_list, obj, group=None):
"""
Gather picklable objects from all participants so that every rank gets the result. Similar to all_gather(), but Python objects can be passed in.
Args:
object_list (list): A list of output objects. The data type of every element in the list is the same as the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
dist.init_parallel_env()
object_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
obj = {"foo": [1, 2, 3]}
paddle.distributed.all_gather_object(object_list, obj)
else:
obj = {"bar": [4, 5, 6]}
paddle.distributed.all_gather_object(object_list, obj)
"""
assert in_dygraph_mode(
), "all_gather_object doesn't support static graph mode."
tensor = _convert_object_to_tensor(obj)
tensor_list = []
all_gather(tensor_list, tensor, group)
for tensor in tensor_list:
object_list.append(_convert_tensor_to_object(tensor))
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
......
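The `_convert_object_to_tensor` / `_convert_tensor_to_object` pair above pickles an object into a flat uint8 tensor before the collective and unpickles it afterwards. A standalone, single-process sketch of the same round trip, with numpy standing in for the gathered tensor:

```python
# Single-process sketch of the pickle <-> uint8 round trip used by
# all_gather_object; numpy replaces the gathered tensor here.
import io
import pickle
import numpy as np

def object_to_bytes_array(obj):
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    return np.frombuffer(buf.getvalue(), dtype=np.uint8)

def bytes_array_to_object(arr):
    return pickle.Unpickler(io.BytesIO(arr.tobytes())).load()

data = {"foo": [1, 2, 3]}
assert bytes_array_to_object(object_to_bytes_array(data)) == data
```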
......@@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list(REMOVE_ITEM TEST_OPS test_new_group_api)
list(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
list(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
list(REMOVE_ITEM TEST_OPS test_collective_allgather_object_api)
list(REMOVE_ITEM TEST_OPS test_collective_alltoall_api)
list(REMOVE_ITEM TEST_OPS test_collective_global_gather)
list(REMOVE_ITEM TEST_OPS test_collective_global_scatter)
......@@ -1598,7 +1599,9 @@ if(APPLE)
endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 300)
set_tests_properties(test_collective_allgather_object_api PROPERTIES TIMEOUT
120)
set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200)
set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200)
......@@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_new_group_api
test_collective_broadcast_api
test_collective_allgather_api
test_collective_allgather_object_api
test_collective_alltoall_api
test_collective_global_gather
test_collective_global_scatter
......
......@@ -30,28 +30,64 @@ import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
import pickle
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import test_collective_api_base as test_base
paddle.enable_static()
class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase):
class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
def get_model(self, main_prog, startup_program, rank, dtype=None):
dtype = "float32" if dtype is None else dtype
with fluid.program_guard(main_prog, startup_program):
tensor_list = []
tindata = layers.data(name="tindata",
shape=[10, 1000],
dtype='float32')
tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype)
paddle.distributed.all_gather(tensor_list, tindata)
return tensor_list
def run_trainer(self, args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
paddle.distributed.init_parallel_env()
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(
device_id) #if args.use_gpu else fluid.CPUPlace()
elif args['backend'] == 'bkcl':
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
else:
place = fluid.CPUPlace()
indata = test_base.create_test_data(shape=(10, 1000),
dtype=args["dtype"],
seed=os.getpid())
assert args[
'static_mode'] == 1, "collective_allgather_api only supports static mode"
result = self.get_model(train_prog,
startup_prog,
rank,
dtype=args["dtype"])
exe = fluid.Executor(place)
exe.run(startup_prog)
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=fetch_list)
sys.stdout.buffer.write(pickle.dumps(out))
if __name__ == "__main__":
runtime_main(TestCollectiveAllgatherAPI, "allgather")
test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tensor_list = []
paddle.distributed.all_gather(tensor_list, tindata)
return [tensor.numpy() for tensor in tensor_list]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import test_collective_api_base as test_base
class TestCollectiveAllgatherObjectAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
object_list = []
paddle.distributed.all_gather_object(object_list, indata)
return object_list
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveAllgatherObjectAPI, "allgather_object")
......@@ -28,12 +28,212 @@ class TestCollectiveAllgatherAPI(TestDistBase):
pass
def test_allgather_nccl(self):
self.check_with_place("collective_allgather_api.py", "allgather",
"nccl")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float16")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="bool")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="uint8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="complex64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="complex128")
def test_allgather_gloo(self):
self.check_with_place("collective_allgather_api.py", "allgather",
"gloo", "3")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float16")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="bool")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="uint8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="complex64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="complex128")
def test_allgather_nccl_dygraph(self):
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float16")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="bool")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="uint8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="complex64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="complex128")
def test_allgather_gloo_dygraph(self):
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float16")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="bool")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="uint8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="complex64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="complex128")
if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import test_collective_api_base as test_base
class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_allgather_nccl(self):
self.check_with_place("collective_allgather_object_api_dygraph.py",
"allgather_object",
"nccl",
static_mode="0",
dtype="pylist")
self.check_with_place("collective_allgather_object_api_dygraph.py",
"allgather_object",
"nccl",
static_mode="0",
dtype="pydict")
def test_allgather_gloo_dygraph(self):
self.check_with_place("collective_allgather_object_api_dygraph.py",
"allgather_object",
"gloo",
"3",
static_mode="0",
dtype="pylist")
self.check_with_place("collective_allgather_object_api_dygraph.py",
"allgather_object",
"gloo",
"3",
static_mode="0",
dtype="pydict")
if __name__ == '__main__':
unittest.main()
......@@ -31,9 +31,77 @@ import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
def create_bool_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
data = np.random.choice([True, False], size=shape)
return data
def create_float_test_data(shape=None, dtype=None, seed=None):
if seed:
np.random.seed(seed)
data = np.random.random(shape).astype(dtype)
return data
def create_int_test_data(shape=None, dtype=None, seed=None):
if seed:
np.random.seed(seed)
data = np.random.randint(0, high=100, size=shape).astype(dtype)
return data
def create_complex_test_data(shape=None, dtype=None, seed=None):
if seed:
np.random.seed(seed)
data = np.random.random(shape).astype(dtype)
data.imag = np.random.random(shape)
return data
def create_pylist_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
data = np.random.random(shape).tolist()
return data
def create_pydict_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
key = [i for i in range(0, shape[0])]
value = np.random.random(shape).tolist()
data = dict(zip(key, value))
return data
def create_test_data(shape=None, dtype=None, seed=None):
assert shape, "Shape should be specified"
if dtype == "float32" or dtype == "float16" or dtype == "float64":
return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "bool":
return create_bool_test_data(shape=shape, seed=seed)
elif dtype == "int32" or dtype == "int64" or dtype == "int8" or dtype == "uint8":
return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "complex64" or dtype == "complex128":
return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "pylist":
return create_pylist_test_data(shape=shape, seed=seed)
elif dtype == "pydict":
return create_pydict_test_data(shape=shape, seed=seed)
else:
raise NotImplementedError("Unsupported dtype for creating test data.")
class TestCollectiveAPIRunnerBase(object):
def get_model(self, train_prog, startup_prog, rank, indata=None):
def get_model(self,
train_prog,
startup_prog,
rank,
indata=None,
dtype=None):
raise NotImplementedError(
"get model should be implemented by child class.")
......@@ -54,8 +122,9 @@ class TestCollectiveAPIRunnerBase(object):
place = fluid.XPUPlace(device_id)
else:
place = fluid.CPUPlace()
np.random.seed(os.getpid())
indata = np.random.random((10, 1000)).astype("float32")
indata = create_test_data(shape=(10, 1000),
dtype=args["dtype"],
seed=os.getpid())
if args['static_mode']:
result = self.get_model(train_prog, startup_prog, rank)
exe = fluid.Executor(place)
......@@ -83,6 +152,7 @@ def runtime_main(test_class, col_type):
args["backend"] = os.getenv("BACKEND")
args["path_id"] = int(os.getenv("PATH_ID"))
args["static_mode"] = int(os.getenv("STATIC_MODE"))
args["dtype"] = os.getenv("DTYPE")
model.run_trainer(args)
......@@ -203,18 +273,22 @@ class TestDistBase(unittest.TestCase):
static_mode="1",
check_error_log=False,
need_envs={},
eager_mode=True):
eager_mode=True,
dtype=None):
if backend == "nccl" or backend == "bkcl":
with_gloo = '0'
else:
with_gloo = '1'
required_envs = os.environ.copy()
dtype = "float32" if dtype is None else dtype
additional_envs = {
"NCCL_P2P_DISABLE": "1",
"STATIC_MODE": static_mode,
"PADDLE_WITH_GLOO": with_gloo,
"PADDLE_DISTRI_BACKEND": backend,
"BACKEND": backend,
"PATH_ID": path_id
"PATH_ID": path_id,
"DTYPE": dtype
}
required_envs.update(additional_envs)
required_envs.update(need_envs)
......@@ -234,16 +308,18 @@ class TestDistBase(unittest.TestCase):
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs)
np.random.seed(pid0)
input1 = np.random.random((10, 1000))
np.random.seed(pid1)
input2 = np.random.random((10, 1000))
input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
if col_type == "allgather":
need_result = np.vstack((input1, input2))
tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
self.assertTrue(np.allclose(tr_out0, need_result))
self.assertTrue(np.allclose(tr_out1, need_result))
if col_type == "allgather_object":
need_result = [input1, input2]
self.assertEqual(need_result, tr0_out)
self.assertEqual(need_result, tr1_out)
elif col_type == "broadcast":
need_result = input2
self.assertTrue(np.allclose(tr0_out, need_result))
......
......@@ -1737,7 +1737,7 @@ def split(x, num_or_sections, axis=0, name=None):
Split the input tensor into multiple sub-Tensors.
Args:
x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64.
x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
......@@ -1814,9 +1814,10 @@ def split(x, num_or_sections, axis=0, name=None):
_C_ops.split(input, out, *attrs)
return out
check_variable_and_dtype(
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'split')
check_variable_and_dtype(input, 'input', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
'int8'
], 'split')
check_type(num_or_sections, 'num_or_sections', (list, int, tuple), 'split')
check_type(dim, 'dim', (int, Variable), 'split')
if isinstance(dim, Variable):
......
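Since `all_gather` slices the gathered buffer back into per-rank tensors with `paddle.split`, the split kernels and this dtype check are widened in step. A quick single-process sketch of splitting one of the newly allowed dtypes (values here are illustrative):

```python
# paddle.split now also handles the uint8/int8 dtypes registered above.
import paddle

x = paddle.arange(8, dtype='int32').astype('uint8').reshape([4, 2])
a, b = paddle.split(x, num_or_sections=2, axis=0)
print(a.dtype, a.shape)  # paddle.uint8, [2, 2]
```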