From 163152aad17935709fdd45debcabee61c6a6ecf1 Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Wed, 16 Aug 2023 20:13:30 +0800
Subject: [PATCH] move test case from cpp to python (#56333)

---
 test/auto_parallel/CMakeLists.txt             |   3 +
 test/auto_parallel/reshard_r_to_s.py          |  75 ++++++
 test/auto_parallel/test_reshard_r_to_s.py     |  46 ++++
 test/cpp/auto_parallel/CMakeLists.txt         |   5 -
 test/cpp/auto_parallel/test_reshard_r_to_s.cc | 242 ------------------
 5 files changed, 124 insertions(+), 247 deletions(-)
 create mode 100644 test/auto_parallel/reshard_r_to_s.py
 create mode 100644 test/auto_parallel/test_reshard_r_to_s.py
 delete mode 100644 test/cpp/auto_parallel/test_reshard_r_to_s.cc

diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 0422c3be9c5..2bfdb89ffa0 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -80,6 +80,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
   set_tests_properties(test_reshard_s_to_r
                        PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
+  py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
+  set_tests_properties(test_reshard_r_to_s
+                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
   # End of unittests WITH multi cards and timeout
 
   # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
diff --git a/test/auto_parallel/reshard_r_to_s.py b/test/auto_parallel/reshard_r_to_s.py
new file mode 100644
index 00000000000..d650b6b402d
--- /dev/null
+++ b/test/auto_parallel/reshard_r_to_s.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.fluid import core
+
+
+class TestReshardRToS:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.ones(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[self._shard] = "x"
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        reshard_func = core.RToSReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+        out_shape = list(self._shape)
+
+        if out_shape[self._shard] % 2 == 0:
+            out_shape[self._shard] = out_shape[self._shard] // 2
+        else:
+            out_shape[self._shard] = (
+                out_shape[self._shard] // 2
+                if dist.get_rank() == 1
+                else out_shape[self._shard] // 2 + 1
+            )
+
+        assert np.equal(out.shape, out_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardRToS().run_test_case()
diff --git a/test/auto_parallel/test_reshard_r_to_s.py b/test/auto_parallel/test_reshard_r_to_s.py
new file mode 100644
index 00000000000..187fa40918d
--- /dev/null
+++ b/test/auto_parallel/test_reshard_r_to_s.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardRToS(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "shape": "(10, 20)",
+            "dtype": "float32",
+            "seeds": str(self._seeds),
+        }
+        self._changeable_envs = {
+            "shape": ["(10, 20)", "(5, 7)"],
+            "shard": ["0", "1"],
+            "backend": ["cpu", "gpu"],
+        }
+
+    def test_reshard_r_to_s(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_r_to_s.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt
index 6e0ea8db1e0..c5912a6fa10 100644
--- a/test/cpp/auto_parallel/CMakeLists.txt
+++ b/test/cpp/auto_parallel/CMakeLists.txt
@@ -9,11 +9,6 @@ if(WITH_DISTRIBUTE)
     dist_tensor_test
     SRCS dist_tensor_test.cc
     DEPS phi)
-
-  cc_test(
-    test_reshard_r_to_s
-    SRCS test_reshard_r_to_s.cc
-    DEPS phi)
 endif()
 
 cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi)
diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc
deleted file mode 100644
index 4885bc321e0..00000000000
--- a/test/cpp/auto_parallel/test_reshard_r_to_s.cc
+++ /dev/null
@@ -1,242 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <iostream>
-#include "glog/logging.h"
-#include "gtest/gtest.h"
-#include "paddle/phi/backends/all_context.h"
-#include "paddle/phi/backends/context_pool.h"
-#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
-#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
-#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
-#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
-#include "paddle/phi/core/tensor_utils.h"
-
-namespace phi {
-namespace distributed {
-namespace auto_parallel {
-namespace tests {
-
-std::shared_ptr<DistTensor> ConstructReplicatedDistCPU(
-    phi::CPUContext* dev_ctx,
-    const std::vector<int64_t>& shape,
-    const ProcessMesh& mesh) {
-  phi::CPUPlace cpu_place = dev_ctx->GetPlace();
-  const DDim dims(shape.data(), shape.size());
-
-  int64_t num_of_elems = 1;
-  for (const auto& value : shape) {
-    num_of_elems *= value;
-  }
-
-  phi::DenseTensor input_dense;
-  float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
-
-  std::vector<float> vec(num_of_elems);
-  memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
-
-  std::shared_ptr<TensorDistAttr> dist_attr =
-      std::make_shared<TensorDistAttr>(shape);
-
-  std::vector<int64_t> dims_mapping(shape.size(), -1);
-  dist_attr->set_dims_mapping(dims_mapping);
-  dist_attr->set_process_mesh(mesh);
-
-  return std::make_shared<DistTensor>(
-      std::make_shared<DenseTensor>(input_dense),
-      input_dense.meta(),
-      dist_attr);
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-std::shared_ptr<DistTensor> ConstructReplicatedDistGPU(
-    phi::GPUContext* dev_ctx,
-    const std::vector<int64_t>& shape,
-    const ProcessMesh& mesh) {
-  phi::GPUPlace gpu_place = dev_ctx->GetPlace();
-  phi::CPUPlace cpu_place;
-  const DDim dims(shape.data(), shape.size());
-
-  int64_t num_of_elems = 1;
-  for (const auto& value : shape) {
-    num_of_elems *= value;
-  }
-
-  phi::DenseTensor input_dense;
-  phi::DenseTensor input_dense_gpu;
-  float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
-
-  std::vector<float> vec(num_of_elems);
-  memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
-  phi::Copy(*dev_ctx, input_dense, gpu_place, true, &input_dense_gpu);
-
-  std::shared_ptr<TensorDistAttr> dist_attr =
-      std::make_shared<TensorDistAttr>(shape);
-
-  std::vector<int64_t> dims_mapping(shape.size(), -1);
-  dist_attr->set_dims_mapping(dims_mapping);
-  dist_attr->set_process_mesh(mesh);
-
-  return std::make_shared<DistTensor>(
-      std::make_shared<DenseTensor>(input_dense_gpu),
-      input_dense_gpu.meta(),
-      dist_attr);
-}
-#endif
-
-TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
-  setenv("PADDLE_TRAINER_ID", "1", 1);
-
-  std::vector<int64_t> tensor_shape = {6, 8};
-  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-  auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
-
-  std::vector<int64_t> mesh_shape = {4};
-  std::vector<int64_t> process_ids = {0, 1, 2, 3};
-  std::vector<std::string> dim_names = {"x"};
-  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
-
-  std::shared_ptr<DistTensor> input =
-      ConstructReplicatedDistCPU(context, tensor_shape, mesh);
-
-  std::shared_ptr<TensorDistAttr> out_dist_attr =
-      std::make_shared<TensorDistAttr>(tensor_shape);
-  std::vector<int64_t> out_dims_mapping = {-1, 0};
-  out_dist_attr->set_dims_mapping(out_dims_mapping);
-  out_dist_attr->set_process_mesh(mesh);
-
-  RToSReshardFunction r_to_s_func;
-  std::shared_ptr<DistTensor> output =
-      r_to_s_func.Eval(context, *input, out_dist_attr);
-
-  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
-  CHECK_EQ(output->numel(), 12);
-  CHECK_EQ(output->dims(), DDim({6, 2}));
-}
-
-TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh_unbalance_split) {
-  setenv("PADDLE_TRAINER_ID", "1", 1);
-
-  std::vector<int64_t> tensor_shape = {6, 8};
-  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-  auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
-
-  std::vector<int64_t> mesh_shape = {4};
-  std::vector<int64_t> process_ids = {0, 1, 2, 3};
-  std::vector<std::string> dim_names = {"x"};
-  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
-
-  std::shared_ptr<DistTensor> input =
-      ConstructReplicatedDistCPU(context, tensor_shape, mesh);
-
-  std::shared_ptr<TensorDistAttr> out_dist_attr =
-      std::make_shared<TensorDistAttr>(tensor_shape);
-  std::vector<int64_t> out_dims_mapping = {0, -1};
-  out_dist_attr->set_dims_mapping(out_dims_mapping);
-  out_dist_attr->set_process_mesh(mesh);
-
-  RToSReshardFunction r_to_s_func;
-  std::shared_ptr<DistTensor> output =
-      r_to_s_func.Eval(context, *input, out_dist_attr);
-
-  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
-  CHECK_EQ(output->numel(), 16);
-  CHECK_EQ(output->dims(), DDim({2, 8}));
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
-  setenv("PADDLE_TRAINER_ID", "0", 0);
-
-  std::vector<int64_t> tensor_shape = {6, 8, 4};
-  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-  auto* context = reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
-
-  std::vector<int64_t> mesh_shape = {6};
-  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
-  std::vector<std::string> dim_names = {"x"};
-  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
-
-  std::shared_ptr<TensorDistAttr> out_dist_attr =
-      std::make_shared<TensorDistAttr>(tensor_shape);
-  std::vector<int64_t> out_dims_mapping = {0, -1, -1};
-  out_dist_attr->set_dims_mapping(out_dims_mapping);
-  out_dist_attr->set_process_mesh(mesh);
-
-  std::shared_ptr<DistTensor> input =
-      ConstructReplicatedDistGPU(context, tensor_shape, mesh);
-
-  RToSReshardFunction r_to_s_func;
-  std::shared_ptr<DistTensor> output =
-      r_to_s_func.Eval(context, *input, out_dist_attr);
-
-  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
-  CHECK_EQ(output->numel(), 32);
-  CHECK_EQ(output->dims(), DDim({1, 8, 4}));
-}
-#endif
-
-TEST(reshard_r_to_s, r_to_s_diff_placement) {
-  std::vector<int64_t> tensor_shape = {6, 8};
-  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-  auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
-
-  std::vector<int64_t> mesh_shape = {4};
-  std::vector<int64_t> process_ids = {0, 1, 2, 3};
-  std::vector<std::string> dim_names = {"x"};
-  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
-
-  std::shared_ptr<DistTensor> input =
-      ConstructReplicatedDistCPU(context, tensor_shape, mesh);
-
-  std::vector<int64_t> out_process_ids = {2, 3, 4, 5};
-  ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names);
-  std::shared_ptr<TensorDistAttr> out_dist_attr =
-      std::make_shared<TensorDistAttr>(tensor_shape);
-  std::vector<int64_t> out_dims_mapping = {-1, 0};
-  out_dist_attr->set_dims_mapping(out_dims_mapping);
-  out_dist_attr->set_process_mesh(out_mesh);
-
-  RToSReshardFunction r_to_s_func;
-  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
-}
-
-TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) {
-  std::vector<int64_t> tensor_shape = {6, 12};
-  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-  auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
-
-  std::vector<int64_t> mesh_shape = {4, 2};
-  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
-  std::vector<std::string> dim_names = {"x", "y"};
-  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
-
-  std::shared_ptr<DistTensor> input =
-      ConstructReplicatedDistCPU(context, tensor_shape, mesh);
-
-  std::shared_ptr<TensorDistAttr> out_dist_attr =
-      std::make_shared<TensorDistAttr>(tensor_shape);
-  std::vector<int64_t> out_dims_mapping = {1, 0};
-  out_dist_attr->set_dims_mapping(out_dims_mapping);
-  out_dist_attr->set_process_mesh(mesh);
-
-  RToSReshardFunction r_to_s_func;
-
-  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
-}
-
-}  // namespace tests
-}  // namespace auto_parallel
-}  // namespace distributed
-}  // namespace phi
--
GitLab