/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_concat_op.h"

#include <vector>

#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/api/include/tensor.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

31
template <typename T, typename DeviceContext>
L
lilong12 已提交
32 33 34
class CConcatOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
35 36
    auto x = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
37 38
    ncclDataType_t dtype =
        platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
L
lilong12 已提交
39 40 41 42 43

    int nranks = ctx.Attr<int>("nranks");
    int rank = ctx.Attr<int>("rank");
    int rid = ctx.Attr<int>("ring_id");
    auto place = ctx.GetPlace();
44 45
    PADDLE_ENFORCE_GE(rank,
                      0,
L
lilong12 已提交
46 47 48 49
                      platform::errors::PreconditionNotMet(
                          "The value of rank (%d) for c_concat must be "
                          "greater than or equal to 0.",
                          rank));
50 51
    PADDLE_ENFORCE_GE(nranks,
                      2,
L
lilong12 已提交
52 53 54 55
                      platform::errors::PreconditionNotMet(
                          "The value of nranks (%d) for c_concat must be "
                          "greater than or equal to 2.",
                          nranks));
56 57
    PADDLE_ENFORCE_LT(rank,
                      nranks,
L
lilong12 已提交
58 59 60
                      platform::errors::PreconditionNotMet(
                          "The value of rank (%d) for c_concat must be "
                          "less than that of nranks (%d).",
61 62
                          rank,
                          nranks));
L
lilong12 已提交
63 64

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
65
    phi::DenseTensor temp_out;
L
lilong12 已提交
66 67 68 69
    framework::DDim temp_out_dims = x->dims();
    temp_out_dims[0] *= nranks;
    temp_out.mutable_data<T>(temp_out_dims, place);

70 71 72 73 74 75 76 77 78 79 80 81 82
    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // Use ProcessGroup
      distributed::ProcessGroup* pg = map->get(rid);
      std::vector<phi::DenseTensor> in_tensor;
      std::vector<phi::DenseTensor> out_tensor;
      in_tensor.push_back(*x);
      out_tensor.push_back(temp_out);
      auto task = pg->AllGather(in_tensor, out_tensor);
      task->Wait();
    } else {
      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
      PADDLE_ENFORCE_EQ(
83 84 85 86
          nranks,
          comm->nranks(),
          platform::errors::InvalidArgument(
              "nranks: %s should equal to %s", nranks, comm->nranks()));
87 88 89 90 91

      int64_t send_numel = x->numel();
      const T* send_buff = x->data<T>();
      T* recv_buff = temp_out.data<T>();
      gpuStream_t stream = nullptr;
92 93
      // should ExecutionContext for calc stream.
      stream = ctx.cuda_device_context().stream();
94

95 96 97 98 99 100 101
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::ncclAllGather(send_buff,
                                           recv_buff,
                                           send_numel,
                                           static_cast<ncclDataType_t>(dtype),
                                           comm->comm(),
                                           stream));
102
    }
L
lilong12 已提交
103

104
    std::vector<phi::DenseTensor> inputs;
L
lilong12 已提交
105 106 107 108 109 110
    int axis = x->dims().size() - 1;
    auto out_dims = x->dims();
    out_dims[out_dims.size() - 1] *= nranks;
    int rows_per_tensor = x->dims()[0];
    int offset = 0;
    for (int i = 0; i < nranks; i++) {
111
      phi::DenseTensor temp = temp_out.Slice(offset, offset + rows_per_tensor);
L
lilong12 已提交
112 113 114 115
      inputs.emplace_back(temp);
      offset += rows_per_tensor;
    }

L
Leo Chen 已提交
116
    math::ConcatFunctor<phi::GPUContext, T> functor;
L
lilong12 已提交
117
    out->mutable_data<T>(out_dims, place);
L
Leo Chen 已提交
118
    auto& dev_ctx2 = ctx.template device_context<phi::GPUContext>();
L
lilong12 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131
    functor(dev_ctx2, inputs, axis, out);
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU."));
#endif
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

132 133 134 135 136 137 138 139
PD_REGISTER_STRUCT_KERNEL(c_concat,
                          GPU,
                          ALL_LAYOUT,
                          ops::CConcatOpCUDAKernel,
                          float,
                          double,
                          int,
                          int64_t,
L
LiYuRio 已提交
140
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
141
                          plat::bfloat16,
X
xu98bin 已提交
142
#endif
143 144
                          plat::float16) {
}