s_to_r_reshard_function.cc 4.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

#include "glog/logging.h"
18
#include "paddle/phi/common/int_array.h"
19 20 21
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
L
LiYuRio 已提交
22 23 24
#include "paddle/phi/kernels/all_gather_kernel.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
25 26 27 28

namespace phi {
namespace distributed {

29 30
bool SToRReshardFunction::IsSuitable(const DistTensor& in,
                                     const TensorDistAttr& out_dist_attr) {
31 32
  bool flag = true;
  const auto& in_dist_attr = in.dist_attr();
33
  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
34

L
LiYuRio 已提交
35 36
  flag &= in_dist_attr.is_shard();
  flag &= out_dist_attr.is_replicated();
37

38 39
  const auto& in_process_mesh = in_dist_attr.process_mesh();
  const auto& out_process_mesh = out_dist_attr.process_mesh();
40 41 42 43 44

  flag &= (in_process_mesh.ndim() == 1);
  flag &= (out_process_mesh.ndim() == 1);
  flag &= (in_process_mesh == out_process_mesh);

45 46 47 48 49 50
  // Ensure the tensor is balanced split, or we need send/recv rather than
  // all_gather
  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
      GetSplitAxisWithDimsMapping(in_dims_mapping);
  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
  int64_t num_of_process = in_process_mesh.size();
51 52
  flag &= (in.local_dims()[static_cast<int>(split_axis)] * num_of_process ==
           in.dims()[static_cast<int>(split_axis)]);
53

54 55 56
  return flag;
}

57 58 59 60
// Converts the sharded input `in` into `out` with the distribution described
// by `out_dist_attr`.
//
// Steps performed here:
//   1. AllGather the local value over the input mesh's process ids, writing
//      into out's mutable tensor.
//   2. When the split axis is not 0, split that gathered tensor on axis 0
//      into one section per process and concat the sections on the split
//      axis, again writing into out's mutable tensor.
//   3. Set out's dist props from in.dims() and out_dist_attr.
void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
                               const DistTensor& in,
                               const TensorDistAttr& out_dist_attr,
                               DistTensor* out) {
  const auto& src_dist_attr = in.dist_attr();
  const auto& src_dims_mapping = src_dist_attr.dims_mapping();
  const auto& src_mesh = src_dist_attr.process_mesh();
  const auto& gather_ranks = src_mesh.process_ids();
  auto elem_dtype = in.dtype();

  // The precondition guarantees that the output process ids equal the input
  // process ids, so the participating processes must be exactly that set.
  RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                            AllGather,
                            elem_dtype,
                            gather_ranks,
                            in.value(),
                            gather_ranks.size(),
                            GetMutableTensor(out));

  std::map<int64_t, int64_t> axis_map =
      GetSplitAxisWithDimsMapping(src_dims_mapping);
  int64_t split_axis = axis_map.begin()->first;

  // For an input sharded on axis 0 the gathered tensor is used as is, so the
  // split + concat fix-up below is skipped.
  if (split_axis != 0) {
    // all_gather always concatenates on axis 0: cut the gathered result back
    // into per-process pieces along axis 0, then join those pieces along the
    // axis the input was actually split on.
    int64_t gather_axis = 0;
    int64_t piece_count = static_cast<int64_t>(gather_ranks.size());

    std::vector<int64_t> piece_sizes(
        piece_count, in.value().dims()[static_cast<int>(gather_axis)]);
    IntArray sections(piece_sizes);

    std::vector<DenseTensor> pieces;
    RESHARD_FUNCTOR(dev_ctx,
                    Split,
                    elem_dtype,
                    out->value(),
                    sections,
                    gather_axis,
                    &pieces);

    // Concat takes pointers to its inputs; collect them in split order.
    std::vector<const DenseTensor*> piece_ptrs;
    piece_ptrs.reserve(pieces.size());
    for (const auto& piece : pieces) {
      piece_ptrs.push_back(&piece);
    }

    RESHARD_FUNCTOR(dev_ctx,
                    Concat,
                    elem_dtype,
                    piece_ptrs,
                    split_axis,
                    GetMutableTensor(out));
  }

  // Both paths finish by stamping the global dims and the target dist attr.
  SetDistProps(out, in.dims(), out_dist_attr);
}

121 122
// Registers SToRReshardFunction via the REGISTER_RESHARD_FUNC macro.
// NOTE(review): the macro is defined outside this file; presumably it adds
// the class to the set of reshard functions that can be selected — confirm.
REGISTER_RESHARD_FUNC(SToRReshardFunction);

123 124
}  // namespace distributed
}  // namespace phi