未验证 提交 9f3b5f15 编写于 作者: L LiYuRio 提交者: GitHub

[Reshard] Implement replicated to split with same placement (#55552)

* Implement replicated to split reshard function

* fix link error in clang

* refine split functor

* simplify reshard code
上级 f5830c05
......@@ -26,6 +26,9 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#endif
namespace py = pybind11;
......@@ -107,6 +110,11 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
}
void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
.def(py::init<>());
#endif
py::class_<ProcessMesh>(*m, "ProcessMesh")
.def(py::init<>())
.def(py::init<const std::vector<int64_t> &,
......
proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")
if(WITH_DISTRIBUTE)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_utils.cc
r_to_s_reshard_function.cc)
endif()
collect_srcs(
core_srcs
SRCS
......@@ -7,4 +20,4 @@ collect_srcs(
process_mesh.cc
dist_attr.cc
dist_mapper.cc
dist_tensor.cc)
${DISTRIBUTED_SRCS})
......@@ -31,7 +31,7 @@ class DistTensor final
public:
/// \brief Construct a dist tensor and allocate space.
/// \param a The allocator used to allocate space.
/// \param meta The meta data of dense tensor.
/// \param meta The meta data of dist tensor.
DistTensor(Allocator* a,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "glog/logging.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/kernel_factory.h"
namespace phi {
namespace distributed {
bool RToSReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
flag &= IsDimsMappingReplicated(in_dims_mapping);
flag &= IsDimsMappingShard(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
return flag;
}
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
const phi::DeviceContext& dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
const auto& out_process_mesh = out_dist_attr->process_mesh();
const DenseTensor& in_physical_tensor_cur_rank = in.value();
DenseTensor out_physical_tensor_cur_rank;
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(out_dims_mapping);
std::vector<int64_t> coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;
PADDLE_ENFORCE_LT(
mesh_axis,
out_process_mesh.ndim(),
phi::errors::OutOfRange(
"The mesh axis %lld exceed the size of process mesh %lld.",
mesh_axis,
out_process_mesh.ndim()));
int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
<< ". Split will use axis " << mesh_axis << " of process_mesh."
<< " There will have " << num_of_process
<< " process participate in.";
// TODO(liyurui): Consider the tensor can not be balanced split,
// for example, the shape of tensor is {6} but want to split it by 4
// process.
IntArray sections(std::vector<int64_t>(
num_of_process, in.dims()[split_axis] / num_of_process));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
out_dist_attr);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
namespace phi {
namespace distributed {
class RToSReshardFunction final : public ReshardFunction {
public:
RToSReshardFunction() = default;
~RToSReshardFunction() = default;
bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
namespace phi {
namespace distributed {} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
namespace phi {
class DeviceContext;
namespace distributed {
namespace auto_parallel {
class TensorDistAttr;
} // namespace auto_parallel
class DistTensor;
using auto_parallel::TensorDistAttr;
class ReshardFunction {
public:
ReshardFunction() = default;
virtual ~ReshardFunction() = default;
virtual bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/split_kernel.h"
namespace phi {
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis) {
size_t out_number = sections.size();
std::vector<DenseTensor> result(out_number);
std::vector<MetaTensor> out_meta;
std::vector<MetaTensor*> out_meta_ptr;
out_meta.reserve(out_number);
out_meta_ptr.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
out_meta.emplace_back(result[i]);
out_meta_ptr.emplace_back(&out_meta.back());
}
SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr);
std::vector<DenseTensor*> outs;
for (size_t i = 0; i < out_number; ++i) {
outs.emplace_back(&result[i]);
}
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The split in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <vector>
#include "paddle/phi/common/int_array.h"
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis);
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
namespace phi {
namespace distributed {
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
dims_mapping.end(),
[](int64_t value) { return value != -1; });
}
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
return std::all_of(dims_mapping.begin(),
dims_mapping.end(),
[](int64_t value) { return value == -1; });
}
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids();
int64_t ndims_mesh = process_shape.size();
int64_t cur_global_rank = GetCurGlobalRank();
VLOG(3) << "Searching current global rank " << cur_global_rank
<< " in process_mesh " << process_mesh;
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
PADDLE_ENFORCE_NE(
iter,
process_ids.end(),
phi::errors::NotFound("Rank %lld cannot be found in process_mesh",
cur_global_rank));
int64_t flat_idx_in_mesh = iter - process_ids.begin();
std::vector<int64_t> coord(ndims_mesh, -1);
for (int64_t i = ndims_mesh - 1; i >= 0; --i) {
coord[i] = flat_idx_in_mesh % process_shape[i];
flat_idx_in_mesh /= process_shape[i];
}
return coord;
}
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping) {
std::map<int64_t, int64_t> split_axis_to_mesh_axis;
for (size_t i = 0; i < dims_mapping.size(); ++i) {
if (dims_mapping[i] != -1) {
split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
}
}
return split_axis_to_mesh_axis;
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <map>
#include <vector>
namespace phi {
namespace distributed {
namespace auto_parallel {
class ProcessMesh;
} // namespace auto_parallel
using auto_parallel::ProcessMesh;
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);
int64_t GetCurGlobalRank();
// Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1].
std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
// If the index i's value in dims_mapping is x ( x != -1), means the ith axis of
// tensor need be split by xth axis of process_mesh. The function analyze the
// input vector, return a key-value map of tensor_split_axis and
// process_mesh_split_axis.
// For example, if dims_mapping is [-1, 1, -1, 0], will return {1: 1, 3: 0}.
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);
} // namespace distributed
} // namespace phi
......@@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split,
bool,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::complex64,
phi::complex128) {}
PD_REGISTER_KERNEL(split_with_num,
CPU,
......
......@@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split,
bool,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::complex64,
phi::complex128) {}
PD_REGISTER_KERNEL(split_with_num,
GPU,
......
......@@ -9,6 +9,11 @@ if(WITH_DISTRIBUTE)
dist_tensor_test
SRCS dist_tensor_test.cc
DEPS phi)
cc_test(
test_reshard_r_to_s
SRCS test_reshard_r_to_s.cc
DEPS phi)
endif()
cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
namespace tests {
std::shared_ptr<DistTensor> ConstructReplicatedDistCPU(
phi::CPUContext* dev_ctx,
const std::vector<int64_t>& shape,
const ProcessMesh& mesh) {
phi::CPUPlace cpu_place = dev_ctx->GetPlace();
const DDim dims(shape.data(), shape.size());
int64_t num_of_elems = 1;
for (const auto& value : shape) {
num_of_elems *= value;
}
phi::DenseTensor input_dense;
float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
std::vector<float> vec(num_of_elems);
memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
std::shared_ptr<TensorDistAttr> dist_attr =
std::make_shared<TensorDistAttr>(shape);
std::vector<int64_t> dims_mapping(shape.size(), -1);
dist_attr->set_dims_mapping(dims_mapping);
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense), dist_attr);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::shared_ptr<DistTensor> ConstructReplicatedDistGPU(
phi::GPUContext* dev_ctx,
const std::vector<int64_t>& shape,
const ProcessMesh& mesh) {
phi::GPUPlace gpu_place = dev_ctx->GetPlace();
phi::CPUPlace cpu_place;
const DDim dims(shape.data(), shape.size());
int64_t num_of_elems = 1;
for (const auto& value : shape) {
num_of_elems *= value;
}
phi::DenseTensor input_dense;
phi::DenseTensor input_dense_gpu;
float* input_dense_ptr = input_dense.mutable_data<float>(dims, cpu_place);
std::vector<float> vec(num_of_elems);
memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float));
phi::Copy(*dev_ctx, input_dense, gpu_place, true, &input_dense_gpu);
std::shared_ptr<TensorDistAttr> dist_attr =
std::make_shared<TensorDistAttr>(shape);
std::vector<int64_t> dims_mapping(shape.size(), -1);
dist_attr->set_dims_mapping(dims_mapping);
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense_gpu), dist_attr);
}
#endif
TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "1", 1);
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {-1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(*context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 12);
CHECK_EQ(output->dims(), DDim({6, 2}));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "0", 0);
std::vector<int64_t> tensor_shape = {6, 8, 4};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
std::vector<int64_t> mesh_shape = {6};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {0, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistGPU(context, tensor_shape, mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(*context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 32);
CHECK_EQ(output->dims(), DDim({1, 8, 4}));
}
#endif
TEST(reshard_r_to_s, r_to_s_diff_placement) {
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::vector<int64_t> out_process_ids = {2, 3, 4, 5};
ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {-1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(out_mesh);
RToSReshardFunction r_to_s_func;
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
}
TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) {
std::vector<int64_t> tensor_shape = {6, 12};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
std::vector<int64_t> mesh_shape = {4, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {1, 0};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false);
}
} // namespace tests
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册