未验证 提交 8fc9a817 编写于 作者: L LiYuRio 提交者: GitHub

support r to s unbalanced split (#56149)

上级 476bc134
......@@ -78,11 +78,9 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
<< " There will have " << num_of_process
<< " process participate in.";
// TODO(liyurui): Consider the tensor can not be balanced split,
// for example, the shape of tensor is {6} but want to split it by 4
// process.
IntArray sections(std::vector<int64_t>(
num_of_process, in.dims()[split_axis] / num_of_process));
std::vector<int64_t> split_num_vec =
BalancedSplit(in.dims()[split_axis], num_of_process);
IntArray sections(split_num_vec);
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
......@@ -90,6 +88,8 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
VLOG(3) << "The shape of physical tensor after split is "
<< out_physical_tensor_cur_rank.dims();
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
......
......@@ -27,45 +27,26 @@ std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis) {
size_t out_number = sections.size();
std::vector<DenseTensor> result(out_number);
std::vector<MetaTensor> out_meta;
std::vector<MetaTensor*> out_meta_ptr;
out_meta.reserve(out_number);
out_meta_ptr.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
out_meta.emplace_back(result[i]);
out_meta_ptr.emplace_back(&out_meta.back());
}
SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr);
std::vector<DenseTensor*> outs;
for (size_t i = 0; i < out_number; ++i) {
outs.emplace_back(&result[i]);
}
std::vector<DenseTensor> result;
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
......
......@@ -189,5 +189,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
return comm_context;
}
// Distribute total_nums over num_of_pieces as evenly as possible.
// The first (total_nums % num_of_pieces) pieces each receive one extra
// element, e.g. BalancedSplit(12, 5) -> {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
  const int64_t base_size = total_nums / num_of_pieces;
  const int64_t num_larger_pieces = total_nums % num_of_pieces;
  std::vector<int64_t> pieces;
  pieces.reserve(num_of_pieces);
  for (int64_t idx = 0; idx < num_of_pieces; ++idx) {
    pieces.push_back(idx < num_larger_pieces ? base_size + 1 : base_size);
  }
  return pieces;
}
} // namespace distributed
} // namespace phi
......@@ -69,5 +69,10 @@ uint16_t GetMasterPort();
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
// Split a number into the given count of pieces as evenly as possible.
// For example, splitting 12 into 5 pieces returns
// {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
} // namespace distributed
} // namespace phi
......@@ -50,31 +50,43 @@ void SplitWithNumStridedKernel(const Context& dev_ctx,
std::vector<DenseTensor*> out);
template <typename T, typename Context>
std::vector<DenseTensor> Split(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& sections,
const Scalar& axis) {
size_t out_number;
out_number = sections.GetData().size();
void Split(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& sections,
const Scalar& axis,
std::vector<DenseTensor>* result) {
size_t out_number = sections.GetData().size();
std::vector<MetaTensor> out_meta;
std::vector<MetaTensor*> out_meta_ptr;
out_meta.reserve(out_number);
out_meta_ptr.reserve(out_number);
std::vector<DenseTensor> result(out_number);
result->resize(out_number);
for (size_t i = 0; i < out_number; ++i) {
out_meta.emplace_back(&result[i]);
out_meta.emplace_back(&result->at(i));
out_meta_ptr.push_back(&out_meta.back());
}
SplitInferMeta(x, sections, axis, out_meta_ptr);
std::vector<DenseTensor*> outs;
outs.reserve(out_meta.size());
for (size_t i = 0; i < out_meta.size(); ++i) {
outs.push_back(&result[i]);
outs.push_back(&result->at(i));
}
SplitKernel<T, Context>(dev_ctx, x, sections, axis, outs);
}
template <typename T, typename Context>
// Convenience overload of Split that owns the output storage: allocates
// one DenseTensor per requested section and delegates to the in-place
// Split overload.
std::vector<DenseTensor> Split(const Context& dev_ctx,
                               const DenseTensor& x,
                               const IntArray& sections,
                               const Scalar& axis) {
  size_t out_number = sections.GetData().size();
  std::vector<DenseTensor> result(out_number);
  // T is not deducible from the in-place overload's argument list (it only
  // appears inside its body via SplitKernel<T, Context>), so it must be
  // supplied explicitly or the call fails to compile.
  Split<T, Context>(dev_ctx, x, sections, axis, &result);
  return result;
}
......
......@@ -121,6 +121,36 @@ TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
CHECK_EQ(output->dims(), DDim({6, 2}));
}
// Verifies replicated-to-shard reshard when the split is unbalanced: the
// sharded dim (6 rows) does not divide evenly across the 4-process mesh.
TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh_unbalance_split) {
// Run the check from the perspective of rank 1 in the mesh.
setenv("PADDLE_TRAINER_ID", "1", 1);
std::vector<int64_t> tensor_shape = {6, 8};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
// 1-D mesh of 4 processes along axis "x".
std::vector<int64_t> mesh_shape = {4};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x"};
ProcessMesh mesh(mesh_shape, process_ids, dim_names);
std::shared_ptr<DistTensor> input =
ConstructReplicatedDistCPU(context, tensor_shape, mesh);
// Output dist attr: shard dim 0 over mesh axis 0, keep dim 1 replicated.
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {0, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
// 6 rows over 4 processes splits unevenly as {2, 2, 1, 1}; rank 1 is
// expected to keep a 2x8 slice (16 elements) — NOTE(review): the
// rank-to-piece mapping is assumed from the expected dims below; confirm
// against RToSReshardFunction::Eval.
CHECK_EQ(output->numel(), 16);
CHECK_EQ(output->dims(), DDim({2, 8}));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "0", 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册