/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/recv_v2_op.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace operators {

#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
    NCCL_VERSION_CODE >= 2703
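// Receive a dynamic tensor shape from `peer` ahead of the payload: the sender
// first transmits the number of dimensions, then the dimension values, both
// as int32. If a ProcessGroup is supplied, the handshake goes through it;
// otherwise raw ncclRecv calls on the given comm/stream are used.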
framework::DDim recv_shape_info(const platform::Place &place,
                                const gpuStream_t &stream,
                                platform::NCCLComm *comm,
                                const int &peer,
                                distributed::ProcessGroup *group) {
  if (!group) {
    PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr),
                      true,
                      platform::errors::InvalidArgument(
                          "NCCLComm and Stream should be provided if NCCL "
                          "is used to receive the shape info."));
  }

  paddle::experimental::DataType shape_dtype =
      paddle::experimental::DataType::INT32;
  ncclDataType_t nccl_dtype =
      platform::ToNCCLDataType(framework::TransToProtoVarType(shape_dtype));

  // step1: recv the rank (number of dimensions) of the incoming shape
  phi::DenseTensor gpu_shape_size_tensor(shape_dtype);
  if (!group) {
    gpu_shape_size_tensor.Resize({1});
    gpu_shape_size_tensor.mutable_data(place, shape_dtype);
    auto *gpu_data = gpu_shape_size_tensor.data<int>();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        gpu_data, 1, nccl_dtype, peer, comm->comm(), stream));
  }

  // copy the shape size tensor to cpu
  phi::DenseTensor cpu_shape_size_tensor(shape_dtype);
  cpu_shape_size_tensor.Resize({1});
  cpu_shape_size_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
  if (group) {
    std::vector<phi::DenseTensor> shape_size_tensor;
    shape_size_tensor.emplace_back(cpu_shape_size_tensor);
    auto shape_size_task = group->Recv(shape_size_tensor, peer);
    // make sure the recv has finished before the host buffer is read
    shape_size_task->Wait();
  } else {
    framework::TensorCopySync(
        gpu_shape_size_tensor, platform::CPUPlace(), &cpu_shape_size_tensor);
  }
  auto *cpu_data = cpu_shape_size_tensor.data<int>();
  int shape_size = cpu_data[0];
  VLOG(3) << "recv the shape size: " << shape_size << " from peer";

  // step2: recv the shape values
  phi::DenseTensor gpu_shape_tensor(shape_dtype);
  if (!group) {
    gpu_shape_tensor.Resize({shape_size});
    gpu_shape_tensor.mutable_data(place, shape_dtype);
    auto *gpu_shape_data = gpu_shape_tensor.data<int>();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        gpu_shape_data, shape_size, nccl_dtype, peer, comm->comm(), stream));
  }

  // copy the shape tensor to cpu
  phi::DenseTensor cpu_shape_tensor(shape_dtype);
  cpu_shape_tensor.Resize({shape_size});
  cpu_shape_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
  if (group) {
    std::vector<phi::DenseTensor> shape_tensor;
    shape_tensor.emplace_back(cpu_shape_tensor);
    auto shape_task = group->Recv(shape_tensor, peer);
    // make sure the recv has finished before the host buffer is read
    shape_task->Wait();
  } else {
    framework::TensorCopySync(
        gpu_shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
  }
  auto *cpu_shape_data = cpu_shape_tensor.data<int>();
  std::vector<int> all_shape;
  for (int i = 0; i < shape_size; ++i) {
    all_shape.emplace_back(cpu_shape_data[i]);
  }
  framework::DDim new_dim;
  new_dim = new_dim.reshape(all_shape);
  VLOG(3) << "recv the shape: (" << new_dim << ") from peer";

  return new_dim;
}
#endif

template <typename T>
class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
    NCCL_VERSION_CODE >= 2703
    int rid = ctx.Attr<int>("ring_id");
    bool dynamic_shape = ctx.Attr<bool>("dynamic_shape");
    PADDLE_ENFORCE_GE(
        rid,
        0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for recv_v2 op must be non-negative.", rid));

    int peer = ctx.Attr<int>("peer");
    PADDLE_ENFORCE_GE(
        peer,
        0,
        platform::errors::InvalidArgument(
            "The peer (%d) for recv_v2 op must be non-negative.", peer));

    gpuStream_t stream = nullptr;
    auto place = ctx.GetPlace();
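    // Prefer the ProcessGroup path when a group is registered for this
    // ring id; otherwise fall back to the raw NCCL comm context below.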
    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // Use ProcessGroup
      distributed::ProcessGroup *pg = map->get(rid);
      std::vector<phi::DenseTensor> out_tensor;
      auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
      auto out = ctx.Output<phi::DenseTensor>("Out");
      auto out_dims = out->dims();

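      // With dynamic_shape, first receive the real shape from the peer and
      // resize Out before receiving the payload.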
      if (dynamic_shape) {
        VLOG(3) << "recv_v2 will use dynamic shape with send_v2 for switch";
        framework::DDim new_dim = recv_shape_info(ctx.GetPlace(),
                                                  /* gpuStream_t */ nullptr,
                                                  /* NCCLComm* */ nullptr,
                                                  peer,
                                                  pg);
        out->Resize(new_dim);
        out->mutable_data<T>(new_dim, place);
      } else {
        out->mutable_data<T>(out_dims, place);
      }

      out_tensor.emplace_back(*out);
      auto task = pg->Recv(out_tensor, peer);
      return;
    }
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
    if (ctx.Attr<bool>("use_calc_stream")) {
      // use the calculation stream obtained from the ExecutionContext
      stream = ctx.cuda_device_context().stream();
    } else {
      stream = comm->stream();
    }
    PADDLE_ENFORCE_LT(
        peer,
        comm->nranks(),
        platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                          "be less than comm->nranks (%d).",
                                          peer,
                                          comm->nranks()));

    int data_type = ctx.Attr<int>("dtype");
    framework::proto::VarType::Type type =
        framework::proto::VarType::Type(data_type);
    ncclDataType_t dtype = platform::ToNCCLDataType(type);

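    // A LoDTensorArray output receives its element tensors one by one, each
    // with its pre-set dims; dynamic shape is not supported in this case.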
    auto *out_var = ctx.OutputVar("Out");
    if (out_var->IsType<framework::LoDTensorArray>()) {
      PADDLE_ENFORCE_EQ(
          dynamic_shape,
          false,
          platform::errors::InvalidArgument("Dynamic shape for send/recv does "
                                            "not support LoDTensorArray."));
      auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
      for (size_t idx = 0; idx < out_array->size(); ++idx) {
        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
        auto out = &out_array->at(idx);
        auto out_dims = out->dims();
        out->mutable_data<T>(out_dims, place, 0);
        auto numel = out->numel();
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
            out->data<T>(), numel, dtype, peer, comm->comm(), stream));
        VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out_dims)
                << " from " << peer;
      }
      return;
    }

    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
    auto out = ctx.Output<phi::DenseTensor>("Out");
    auto out_dims = out->dims();
    auto numel = out->numel();

    if (dynamic_shape) {
      VLOG(3) << "recv_v2 will use dynamic shape with send_v2";
      framework::DDim new_dim = recv_shape_info(place,
                                                stream,
                                                comm,
                                                peer,
                                                /* ProcessGroup* */ nullptr);
      out->Resize(new_dim);
      numel = out->numel();
      out->mutable_data<T>(new_dim, place);
    } else {
      out->mutable_data<T>(out_dims, place);
    }
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
    VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out->dims())
            << " from " << peer;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with NCCL and "
        "NCCL version >= 2.7.3 is needed."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

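// NOTE: the bfloat16 kernel is only registered when NCCL provides native
// bfloat16 support (NCCL >= 2.10, i.e. NCCL_VERSION_CODE >= 21000).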
REGISTER_OP_CUDA_KERNEL(recv_v2,
                        ops::RecvOpV2CUDAKernel<float>,
                        ops::RecvOpV2CUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
                        ops::RecvOpV2CUDAKernel<plat::bfloat16>,
#endif
                        ops::RecvOpV2CUDAKernel<int>,
                        ops::RecvOpV2CUDAKernel<int64_t>,
                        ops::RecvOpV2CUDAKernel<int8_t>,
                        ops::RecvOpV2CUDAKernel<plat::float16>);