Unverified commit 163152aa, authored by LiYuRio, committed by GitHub

move test case from cpp to python (#56333)

Parent 87bc6f2c
@@ -80,6 +80,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
set_tests_properties(test_reshard_r_to_s
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.fluid import core


class TestReshardRToS:
    def __init__(self):
        # All test parameters are passed in through environment variables by
        # the launching unit test (see test_reshard_r_to_s.py).
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._shard = eval(os.getenv("shard"))
        self._backend = os.getenv("backend")
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
            place = paddle.CPUPlace()
        elif self._backend == "gpu":
            place = paddle.CUDAPlace(dist.get_rank())

        dev_ctx = core.DeviceContext.create(place)
        a = paddle.ones(self._shape)

        # The input is replicated; the output shards dimension self._shard
        # over mesh axis "x".
        in_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs[self._shard] = "x"

        dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=in_shard_specs
        )
        out_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=out_shard_specs
        )

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)

        reshard_func = core.RToSReshardFunction()
        assert reshard_func.is_suitable(input_tensor, out_dist_attr)

        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)

        # Expected local shape on this rank: an even dimension is split
        # evenly between the two ranks; an odd one leaves the extra element
        # on rank 0.
        out_shape = list(self._shape)
        if out_shape[self._shard] % 2 == 0:
            out_shape[self._shard] = out_shape[self._shard] // 2
        else:
            out_shape[self._shard] = (
                out_shape[self._shard] // 2
                if dist.get_rank() == 1
                else out_shape[self._shard] // 2 + 1
            )

        assert np.equal(out.shape, out_shape).all()


if __name__ == '__main__':
    TestReshardRToS().run_test_case()
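The script above is the per-rank body that the unit test below launches on two devices: every parameter (shape, dtype, seeds, shard, backend) arrives through environment variables, and the only check is the local shard shape after the R-to-S reshard. As a minimal sketch of that expectation (the helper name below is illustrative, not part of the test), the assertion reduces to:

def expected_local_shape(shape, shard, rank):
    # Mirrors the check in reshard_r_to_s.py for its fixed 2-process mesh:
    # an even dimension splits evenly; an odd dimension leaves the larger
    # half on rank 0 and the smaller half on rank 1.
    out = list(shape)
    if out[shard] % 2 == 0:
        out[shard] //= 2
    else:
        out[shard] = out[shard] // 2 if rank == 1 else out[shard] // 2 + 1
    return out

print(expected_local_shape((10, 20), shard=0, rank=0))  # [5, 20]
print(expected_local_shape((5, 7), shard=1, rank=0))    # [5, 4]
print(expected_local_shape((5, 7), shard=1, rank=1))    # [5, 3]

For the shapes exercised below, (10, 20) splits evenly on either dimension, while (5, 7) splits as 3 + 2 on dimension 0 and 4 + 3 on dimension 1.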
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import collective.test_communication_api_base as test_base


class TestReshardRToS(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "shape": "(10, 20)",
            "dtype": "float32",
            "seeds": str(self._seeds),
        }
        self._changeable_envs = {
            "shape": ["(10, 20)", "(5, 7)"],
            "shard": ["0", "1"],
            "backend": ["cpu", "gpu"],
        }

    def test_reshard_r_to_s(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "reshard_r_to_s.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
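CommunicationTestDistBase runs reshard_r_to_s.py on two devices for each generated environment, and gen_product_envs_list expands the fixed and changeable dictionaries into one environment per case. Its implementation is not part of this diff; a sketch of the assumed expansion, a plain Cartesian product merged over the defaults, is:

import itertools

def product_envs(default_envs, changeable_envs):
    # Assumed behaviour of test_base.gen_product_envs_list (not shown here):
    # one merged dict per combination of the changeable values.
    keys = list(changeable_envs)
    for values in itertools.product(*(changeable_envs[k] for k in keys)):
        envs = dict(default_envs)
        envs.update(zip(keys, values))
        yield envs

With the dictionaries above this yields 2 * 2 * 2 = 8 cases, for example {"shape": "(5, 7)", "dtype": "float32", "seeds": ..., "shard": "1", "backend": "gpu"}.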
@@ -9,11 +9,6 @@ if(WITH_DISTRIBUTE)
dist_tensor_test
SRCS dist_tensor_test.cc
DEPS phi)
cc_test(
test_reshard_r_to_s
SRCS test_reshard_r_to_s.cc
DEPS phi)
endif()
cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
namespace tests {
std::shared_ptr<DistTensor> ConstructReplicatedDistCPU(
phi::CPUContext* dev_ctx,
const std::vector<int64_t>& shape,
const ProcessMesh& mesh) {
phi::CPUPlace cpu_place = dev_ctx->GetPlace();
const DDim dims(shape.data(), shape.size());
int64_t num_of_elems = 1;
for (const auto& value : shape) {
num_of_elems *= value;
}
phi::DenseTensor input_dense;
float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
std::vector<float> vec(num_of_elems);
memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
std::shared_ptr<TensorDistAttr> dist_attr =
std::make_shared<TensorDistAttr>(shape);
std::vector<int64_t> dims_mapping(shape.size(), -1);
dist_attr->set_dims_mapping(dims_mapping);
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense),
input_dense.meta(),
dist_attr);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::shared_ptr<DistTensor> ConstructReplicatedDistGPU(
phi::GPUContext* dev_ctx,
const std::vector<int64_t>& shape,
const ProcessMesh& mesh) {
phi::GPUPlace gpu_place = dev_ctx->GetPlace();
phi::CPUPlace cpu_place;
const DDim dims(shape.data(), shape.size());
int64_t num_of_elems = 1;
for (const auto& value : shape) {
num_of_elems *= value;
}
phi::DenseTensor input_dense;
phi::DenseTensor input_dense_gpu;
float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
std::vector<float> vec(num_of_elems);
memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
phi::Copy(*dev_ctx, input_dense, gpu_place, true, &input_dense_gpu);
std::shared_ptr<TensorDistAttr> dist_attr =
std::make_shared<TensorDistAttr>(shape);
std::vector<int64_t> dims_mapping(shape.size(), -1);
dist_attr->set_dims_mapping(dims_mapping);
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense_gpu),
input_dense_gpu.meta(),
dist_attr);
}
#endif

TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "1", 1);
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {-1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 12);
CHECK_EQ(output->dims(), DDim({6, 2}));
}

TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh_unbalance_split) {
setenv("PADDLE_TRAINER_ID", "1", 1);
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {0, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 16);
CHECK_EQ(output->dims(), DDim({2, 8}));
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "0", 0);
std::vector<int64_t> tensor_shape = {6, 8, 4};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
std::vector<int64_t> mesh_shape = {6};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {0, -1, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistGPU(context, tensor_shape, mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 32);
CHECK_EQ(output->dims(), DDim({1, 8, 4}));
}
#endif

TEST(reshard_r_to_s, r_to_s_diff_placement) {
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::vector<int64_t> out_process_ids = {2, 3, 4, 5};
ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {-1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(out_mesh);
RToSReshardFunction r_to_s_func;
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
}

TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) {
std::vector<int64_t> tensor_shape = {6, 12};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
}

} // namespace tests
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
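The removed C++ cases pin the same behaviour directly against RToSReshardFunction: the first three check the local shard shape on larger meshes (4 and 6 processes) under a fixed PADDLE_TRAINER_ID, and the last two only check that IsSuitable rejects a reshard onto a different process mesh or onto a multi-dimensional sharding, so they carry no shape expectations. The shape expectations follow from the usual balanced split in which the first size % num_ranks ranks receive one extra slice; a small sketch of that arithmetic (assumed here, the reshard kernel itself is not shown) reproduces the CHECK_EQ values:

def shard_dim(size, num_ranks, rank):
    # Balanced split: the first (size % num_ranks) ranks get one extra element.
    base, extra = divmod(size, num_ranks)
    return base + 1 if rank < extra else base

# r_to_s_same_placement_cpu_1d_mesh: shape (6, 8), 4 ranks, split dim 1, rank 1
print(shard_dim(8, 4, 1))  # 2 -> dims {6, 2}, numel 12
# r_to_s_same_placement_cpu_1d_mesh_unbalance_split: split dim 0, rank 1
print(shard_dim(6, 4, 1))  # 2 -> dims {2, 8}, numel 16
# r_to_s_same_placement_gpu_1d_mesh: shape (6, 8, 4), 6 ranks, split dim 0, rank 0
print(shard_dim(6, 6, 0))  # 1 -> dims {1, 8, 4}, numel 32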