/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/partial_send_op.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/framework/convert_utils.h"

namespace paddle {
namespace operators {

27
template <typename T, typename DeviceContext>
28 29 30 31 32
class PartialSendCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
    NCCL_VERSION_CODE >= 2703
33
    auto x = ctx.Input<phi::DenseTensor>("X");
34 35 36 37 38 39 40
    int numel = x->numel();
    int rid = ctx.Attr<int>("ring_id");
    int peer = ctx.Attr<int>("peer");
    int num = ctx.Attr<int>("num");
    int id = ctx.Attr<int>("id");

    PADDLE_ENFORCE_GE(
41 42
        rid,
        0,
43 44 45
        platform::errors::InvalidArgument(
            "The ring_id (%d) for partial_send op must be non-negative.", rid));
    PADDLE_ENFORCE_GE(
46 47
        peer,
        0,
48 49
        platform::errors::InvalidArgument(
            "The peer (%d) for partial_send op must be non-negative.", peer));
50 51
    PADDLE_ENFORCE_GE(num,
                      1,
52 53 54
                      platform::errors::InvalidArgument(
                          "The num (%d) for partial_send op must >=1", num));
    PADDLE_ENFORCE_EQ(
55 56
        (id >= 0 && id < num),
        true,
57 58 59
        platform::errors::InvalidArgument(
            "The id (%d) for partial_send op must >=0 and <num (%d)", id, num));
    PADDLE_ENFORCE_EQ(
60 61
        (numel % num),
        0,
62 63 64
        platform::errors::InvalidArgument(
            "The input numel (%d) must be divisible by num(%d)", numel, num));

65 66
    int64_t send_numel = numel / num;
    int64_t offset = send_numel * id;
67

68 69 70 71 72
    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // Use ProcessGroup
      distributed::ProcessGroup* pg = map->get(rid);
      phi::DenseTensor tmp = *x;
73
      auto task = pg->Send(tmp, peer, offset, send_numel, /*sync_op*/ true);
74 75 76 77 78 79
      task->Wait();
    } else {
      gpuStream_t stream = nullptr;
      auto place = ctx.GetPlace();
      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
      if (ctx.Attr<bool>("use_calc_stream")) {
80 81
        // should ExecutionContext for calc stream.
        stream = ctx.cuda_device_context().stream();
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
      } else {
        stream = comm->stream();
      }
      PADDLE_ENFORCE_LT(peer,
                        comm->nranks(),
                        platform::errors::InvalidArgument(
                            "The value of peer (%d) you set must "
                            "be less than comm->nranks (%d).",
                            peer,
                            comm->nranks()));

      ncclDataType_t dtype =
          platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::ncclSend(x->data<T>() + offset,
                                      send_numel,
                                      dtype,
                                      peer,
                                      comm->comm(),
                                      stream));
      VLOG(3) << "rank " << comm->rank() << " send " << send_numel
              << " from offset[" << offset << "] to " << peer;
    }
106 107 108 109 110 111 112 113 114 115 116 117 118 119
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with NCCL "
        "and NCCL version >= 2.7.3 is needed."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

120 121 122 123 124 125
PD_REGISTER_STRUCT_KERNEL(partial_send,
                          GPU,
                          ALL_LAYOUT,
                          ops::PartialSendCUDAKernel,
                          float,
                          double,
L
LiYuRio 已提交
126
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
127
                          plat::bfloat16,
128
#endif
129 130 131 132
                          int,
                          int64_t,
                          plat::float16) {
}