未验证 提交 7a92e74b 编写于 作者: W Wen Sun 提交者: GitHub

Completes basic dtypes for collective api in eager mode (#45574)

上级 1137677a
......@@ -738,14 +738,23 @@ void* GetPointerByOffset(void* raw_pointer,
} else if (type == experimental::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
} else if (type == experimental::DataType::INT8) {
return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::UINT8) {
return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::BOOL) {
return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -124,6 +124,8 @@ PD_REGISTER_KERNEL(concat,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -121,6 +121,7 @@ PD_REGISTER_KERNEL(concat,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
......@@ -78,7 +78,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
......@@ -92,6 +92,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
)
set_tests_properties(test_collective_alltoall_single PROPERTIES TIMEOUT "350")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_alltoall_single_api MODULES
test_collective_alltoall_single_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_single_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_barrier_api MODULES test_collective_barrier_api ENVS
......@@ -117,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_broadcast_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -141,6 +149,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_collective_global_scatter
PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_isend_irecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_optimizer MODULES test_collective_optimizer ENVS
......@@ -186,6 +201,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
)
set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT "350")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_reduce_scatter_api MODULES
test_collective_reduce_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_scatter_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_scatter MODULES test_collective_scatter ENVS
......@@ -212,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
......@@ -45,12 +45,9 @@ class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tindata = paddle.split(tindata, 2, axis=0)
tout_data = []
paddle.distributed.alltoall(tindata, tout_data)
output_data = []
for data in tout_data:
output_data.append(data.numpy())
return output_data
toutdata = []
paddle.distributed.alltoall(tindata, toutdata)
return [data.numpy() for data in toutdata]
if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import test_collective_api_base as test_base
class TestCollectiveAllToAllSingleAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
toutdata = paddle.to_tensor(indata)
paddle.distributed.alltoall_single(tindata, toutdata)
return [toutdata.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveAllToAllSingleAPI, "alltoall")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveBroadcastAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
paddle.distributed.broadcast(tindata, src=1)
return [tindata.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveBroadcastAPI, "broadcast")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveIsendIrecvAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
if rank == 0:
task = paddle.distributed.isend(tindata, dst=1)
else:
task = paddle.distributed.irecv(tindata, src=0)
task.wait()
return [tindata.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveIsendIrecvAPI, "sendrecv")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveReduceAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
paddle.distributed.reduce(tindata, dst=0)
return [tindata.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveReduceAPI, "reduce")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveReduceScatterAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
paddle.distributed.reduce_scatter(subdata1, [subdata1, subdata2])
return [subdata1.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveReduceScatterAPI, "reduce_scatter")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base
class TestCollectiveScatterAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
if rank == 0:
paddle.distributed.scatter(subdata1, src=1)
else:
paddle.distributed.scatter(subdata1,
tensor_list=[subdata1, subdata2],
src=1)
return [subdata1.numpy()]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveScatterAPI, "scatter")
......@@ -31,10 +31,16 @@ class TestCollectiveAllToAllAPI(TestDistBase):
self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")
def test_alltoall_nccl_dygraph(self):
self.check_with_place("collective_alltoall_api_dygraph.py",
"alltoall",
"nccl",
static_mode="0")
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_alltoall_api_dygraph.py",
"alltoall",
"nccl",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import test_collective_api_base as test_base
class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_alltooall_single_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_alltoall_single_api_dygraph.py",
"alltoall",
"nccl",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
......@@ -35,6 +35,31 @@ class TestCollectiveBroadcastAPI(TestDistBase):
self.check_with_place("collective_broadcast_api.py", "broadcast",
"gloo", "0")
def test_broadcast_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_broadcast_api_dygraph.py",
"broadcast",
"nccl",
static_mode="0",
dtype=dtype)
def test_broadcast_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_broadcast_api_dygraph.py",
"broadcast",
"gloo",
"0",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import test_collective_api_base as test_base
class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_isend_irecv_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_isend_irecv_api_dygraph.py",
"sendrecv",
"nccl",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
......@@ -38,6 +38,31 @@ class TestCollectiveReduceAPI(TestDistBase):
def test_reduce_gloo(self):
self.check_with_place("collective_reduce_api.py", "reduce", "gloo", "1")
def test_reduce_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_api_dygraph.py",
"reduce",
"nccl",
static_mode="0",
dtype=dtype)
def test_reduce_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_api_dygraph.py",
"reduce",
"gloo",
"1",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import test_collective_api_base as test_base
class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_reduce_scatter_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_scatter_api_dygraph.py",
"reduce_scatter",
"nccl",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
......@@ -34,6 +34,31 @@ class TestCollectiveScatterAPI(TestDistBase):
def test_scatter_nccl(self):
self.check_with_place("collective_scatter_api.py", "scatter", "nccl")
def test_scatter_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_scatter_api_dygraph.py",
"scatter",
"nccl",
static_mode="0",
dtype=dtype)
def test_scatter_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_scatter_api_dygraph.py",
"scatter",
"gloo",
"4",
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
unittest.main()
......@@ -33,11 +33,16 @@ class TestCollectiveSendRecvAPI(TestDistBase):
# "nccl")
def test_sendrecv_nccl_dygraph(self):
if paddle.fluid.core.is_compiled_with_cuda():
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
]
for dtype in dtypes_to_test:
self.check_with_place("collective_sendrecv_api_dygraph.py",
"sendrecv",
"nccl",
static_mode='0')
static_mode="0",
dtype=dtype)
if __name__ == '__main__':
......
......@@ -8,23 +8,26 @@ test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHO
test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_isend_irecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......
......@@ -335,6 +335,12 @@ class TestDistBase(unittest.TestCase):
need_result2 = need_result[need_result.shape[0] // 2:]
np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
elif col_type == "reduce_scatter":
need_result = input1 + input2
need_result1 = need_result[0:need_result.shape[0] // 2]
need_result2 = need_result[need_result.shape[0] // 2:]
np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
elif col_type == "allreduce":
need_result = input1 + input2
np.testing.assert_allclose(tr0_out[0],
......
......@@ -1015,7 +1015,7 @@ def concat(x, axis=0, name=None):
Args:
x (list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
float32, float64, int32, int64, uint8. All the Tensors in ``x`` must have same data type.
float32, float64, int32, int64, int8, uint8. All the Tensors in ``x`` must have same data type.
axis (int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
......@@ -1073,10 +1073,10 @@ def concat(x, axis=0, name=None):
check_type(input, 'input', (list, tuple, Variable), 'concat')
if not isinstance(input, Variable):
for id, x in enumerate(input):
check_variable_and_dtype(
x, 'input[' + str(id) + ']',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'concat')
check_variable_and_dtype(x, 'input[' + str(id) + ']', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64',
'int8', 'unit8'
], 'concat')
if x.dtype != input[0].dtype:
raise TypeError(
"All the Tensors in the input must have the same data type."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册