/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/recv_v2_op.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace operators {

#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
    NCCL_VERSION_CODE >= 2703
// Receives the runtime shape of an incoming tensor from rank `peer`,
// as the counterpart of send_v2's shape broadcast for dynamic_shape mode.
//
// Protocol (two messages, both INT32):
//   step1: receive the number of dimensions (a single int).
//   step2: receive that many ints, the dimension sizes themselves.
//
// Two transports are supported:
//   - group != nullptr: receive via the ProcessGroup API directly into CPU
//     tensors (`stream`/`comm` are ignored and may be null).
//   - group == nullptr: receive via raw ncclRecv into GPU staging tensors on
//     `stream` using `comm`, then copy synchronously to CPU. In this mode
//     both `stream` and `comm` must be non-null (enforced below).
//
// Returns the received shape as a DDim.
framework::DDim recv_shape_info(const platform::Place &place,
                                const gpuStream_t &stream,
                                platform::NCCLComm *comm,
                                const int &peer,
                                distributed::ProcessGroup *group) {
  if (!group) {
    PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr),
                      true,
                      platform::errors::InvalidArgument(
                          "NCCLComm and Stream should be provided if use NCCL "
                          "to send the shape info."));
  }

  // The shape metadata is always exchanged as 32-bit integers.
  phi::DataType shape_dtype = phi::DataType::INT32;
  ncclDataType_t nccl_dtype =
      platform::ToNCCLDataType(framework::TransToProtoVarType(shape_dtype));

  // step1: recv the shape size (number of dimensions).
  phi::DenseTensor gpu_shape_size_tensor(shape_dtype);
  if (!group) {
    gpu_shape_size_tensor.Resize({1});
    gpu_shape_size_tensor.mutable_data(place, shape_dtype);
    auto *gpu_data = gpu_shape_size_tensor.data<int>();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        gpu_data, 1, nccl_dtype, peer, comm->comm(), stream));
  }

  // Copy the shape size to CPU so the host can read it.
  // NOTE: stack-allocated; the original heap-allocated this tensor with
  // `new` and never freed it, leaking once per call.
  phi::DenseTensor cpu_shape_size_tensor(shape_dtype);
  cpu_shape_size_tensor.Resize({1});
  cpu_shape_size_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
  if (group) {
    // The vector element shares the tensor's buffer holder, so the received
    // payload is visible through cpu_shape_size_tensor afterwards.
    std::vector<phi::DenseTensor> shape_size_tensor;
    shape_size_tensor.emplace_back(cpu_shape_size_tensor);
    auto shape_size_task = group->Recv(shape_size_tensor, peer);
  } else {
    framework::TensorCopySync(
        gpu_shape_size_tensor, platform::CPUPlace(), &cpu_shape_size_tensor);
  }
  auto *cpu_data = cpu_shape_size_tensor.data<int>();
  int shape_size = cpu_data[0];
  VLOG(3) << "recv the shape size: " << shape_size << " from peer";

  // step2: recv the shape (one int per dimension).
  phi::DenseTensor gpu_shape_tensor(shape_dtype);
  if (!group) {
    gpu_shape_tensor.Resize({shape_size});
    gpu_shape_tensor.mutable_data(place, shape_dtype);
    auto *gpu_shape_data = gpu_shape_tensor.data<int>();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        gpu_shape_data, shape_size, nccl_dtype, peer, comm->comm(), stream));
  }

  // Copy the shape tensor to CPU (same leak fix as above: stack-allocated).
  phi::DenseTensor cpu_shape_tensor(shape_dtype);
  cpu_shape_tensor.Resize({shape_size});
  cpu_shape_tensor.mutable_data(platform::CPUPlace(), shape_dtype);
  if (group) {
    std::vector<phi::DenseTensor> shape_tensor;
    shape_tensor.emplace_back(cpu_shape_tensor);
    auto shape_task = group->Recv(shape_tensor, peer);
  } else {
    framework::TensorCopySync(
        gpu_shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
  }

  // Assemble the received dimension sizes into a DDim.
  auto *cpu_shape_data = cpu_shape_tensor.data<int>();
  std::vector<int> all_shape;
  all_shape.reserve(shape_size);
  for (int i = 0; i < shape_size; ++i) {
    all_shape.emplace_back(cpu_shape_data[i]);
  }
  framework::DDim new_dim;
  new_dim = new_dim.reshape(all_shape);
  VLOG(3) << "recv the shape: (" << new_dim << ") from peer";

  return new_dim;
}
#endif

// CUDA kernel for the recv_v2 op: receives a tensor (or a LoDTensorArray)
// from rank `peer` on communication ring `ring_id`.
//
// Dispatch order:
//   1. If a ProcessGroup is registered for `ring_id`, receive through it
//      (optionally negotiating the shape first when dynamic_shape is set).
//   2. Otherwise fall back to raw ncclRecv on either the calculation stream
//      (use_calc_stream) or the communicator's own stream.
// Requires NCCL >= 2.7.3 (p2p send/recv support); throws otherwise.
template <typename T, typename DeviceContext>
class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
    NCCL_VERSION_CODE >= 2703
    int rid = ctx.Attr<int>("ring_id");
    bool dynamic_shape = ctx.Attr<bool>("dynamic_shape");
    PADDLE_ENFORCE_GE(
        rid,
        0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for recv_v2 op must be non-negative.", rid));

    int peer = ctx.Attr<int>("peer");
    PADDLE_ENFORCE_GE(
        peer,
        0,
        platform::errors::InvalidArgument(
            "The peer (%d) for recv_v2 op must be non-negative.", peer));

    gpuStream_t stream = nullptr;
    auto place = ctx.GetPlace();
    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // Use ProcessGroup path: the group owns communicator and stream.
      distributed::ProcessGroup *pg = map->get(rid);
      std::vector<phi::DenseTensor> out_tensor;
      auto out = ctx.Output<phi::DenseTensor>("Out");
      auto out_dims = out->dims();

      if (dynamic_shape) {
        // Shape is negotiated at runtime; stream/comm are unused here.
        VLOG(3) << "recv_v2 will use dynamic shape with send_v2 for switch";
        framework::DDim new_dim = recv_shape_info(ctx.GetPlace(),
                                                  /* gpuStream_t */ nullptr,
                                                  /* NCCLComm* */ nullptr,
                                                  peer,
                                                  pg);
        out->Resize(new_dim);
        out->mutable_data<T>(new_dim, place);
      } else {
        out->mutable_data<T>(out_dims, place);
      }

      out_tensor.emplace_back(*out);
      auto task = pg->Recv(out_tensor, peer);
      return;
    }

    // Raw NCCL path.
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
    if (ctx.Attr<bool>("use_calc_stream")) {
      // should ExecutionContext for calc stream.
      stream = ctx.cuda_device_context().stream();
    } else {
      stream = comm->stream();
    }
    PADDLE_ENFORCE_LT(
        peer,
        comm->nranks(),
        platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                          "be less than comm->nranks (%d).",
                                          peer,
                                          comm->nranks()));

    int data_type = ctx.Attr<int>("dtype");
    framework::proto::VarType::Type type =
        framework::proto::VarType::Type(data_type);
    ncclDataType_t dtype = platform::ToNCCLDataType(type);

    auto *out_var = ctx.OutputVar("Out");
    if (out_var->IsType<framework::LoDTensorArray>()) {
      // LoDTensorArray: receive each element with its pre-set dims;
      // dynamic_shape is not supported for arrays.
      PADDLE_ENFORCE_EQ(
          dynamic_shape,
          false,
          platform::errors::InvalidArgument("Dynamic shape for send/recv not "
                                            "support LoDTensorArray for now."));
      auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
      for (size_t idx = 0; idx < out_array->size(); ++idx) {
        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
        auto out = &out_array->at(idx);
        auto out_dims = out->dims();
        out->mutable_data<T>(out_dims, place, 0);
        auto numel = out->numel();
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
            out->data<T>(), numel, dtype, peer, comm->comm(), stream));
        VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out_dims)
                << " from " << peer;
      }
      return;
    }

    // Single dense tensor.
    auto out = ctx.Output<phi::DenseTensor>("Out");
    auto out_dims = out->dims();
    auto numel = out->numel();

    if (dynamic_shape) {
      // Receive the shape from the peer first, then size the output buffer.
      VLOG(3) << "recv_v2 will use dynamic shape with send_v2";
      framework::DDim new_dim = recv_shape_info(place,
                                                stream,
                                                comm,
                                                peer,
                                                /* ProcessGroup* */ nullptr);
      out->Resize(new_dim);
      numel = out->numel();
      out->mutable_data<T>(new_dim, place);
    } else {
      out->mutable_data<T>(out_dims, place);
    }
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
    VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out->dims())
            << " from " << peer;
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with NCCL and "
        "NCCL version >= 2.7.3 is needed."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the recv_v2 GPU kernel for all supported element types.
// bfloat16 support requires NCCL >= 2.10 (version code 21000).
PD_REGISTER_STRUCT_KERNEL(recv_v2,
                          GPU,
                          ALL_LAYOUT,
                          ops::RecvOpV2CUDAKernel,
                          float,
                          double,
#if NCCL_VERSION_CODE >= 21000
                          plat::bfloat16,
#endif
                          int,
                          int64_t,
                          int8_t,
                          plat::float16) {
}