c_split_op.cu 4.5 KB
Newer Older
L
lilong12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>

#include "paddle/fluid/operators/collective/c_split_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
19
#include "paddle/phi/backends/gpu/gpu_primitives.h"
L
lilong12 已提交
20 21 22 23

namespace paddle {
namespace operators {

24 25
// Threads per block used for every SplitFromRank launch in this file.
static constexpr int64_t kNumCUDAThreads{512};
// Upper bound on the grid size returned by NumBlocks.
static constexpr int64_t kNumMaxinumNumBlocks{4096};
26

27
// Number of kNumCUDAThreads-sized blocks needed to give every one of N
// elements its own thread (ceil-division), capped at kNumMaxinumNumBlocks.
// Note: returns 0 when N == 0, so callers must not launch with the result in
// that case.
static inline int64_t NumBlocks(const int64_t N) {
  const int64_t wanted = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return wanted < kNumMaxinumNumBlocks ? wanted : kNumMaxinumNumBlocks;
}

// Copies this rank's column slice of a row-major input into `output`.
//
// `input` is read as a [rows, columns] matrix (linear index
// row * columns + col). Its columns are divided into `nranks` contiguous
// slices of width block = columns / nranks; slice number `rank`, i.e.
// columns [rank * block, (rank + 1) * block), is written densely into
// `output`, read as a [rows, block] matrix. Columns left over when `columns`
// is not a multiple of `nranks` are never copied.
//
// Each loop iteration produces exactly one output element, so the work is
// rows * block elements instead of visiting all rows * columns inputs and
// discarding the ones owned by other ranks. Consecutive iterations read
// consecutive input addresses within a row's slice and write consecutive
// output addresses.
//
// `limit` is the input element count (rows * columns) that the caller sizes
// the grid with; the loop bound is derived from `rows` and `columns` here.
// Since NumBlocks caps the grid size, correctness relies on
// CUDA_KERNEL_LOOP_TYPE striding across the whole index range, which is the
// same assumption the launch site already makes.
template <typename T>
__global__ void SplitFromRank(const T* input,
                              T* output,
                              const int64_t rows,
                              const int64_t columns,
                              const int rank,
                              const int nranks,
                              const int64_t limit) {
  // Slice geometry is loop-invariant: compute it once per thread.
  const int64_t block = columns / nranks;
  const int64_t start = block * rank;
  // Zero when block == 0, in which case the loop body never runs and the
  // divisions below cannot see a zero divisor.
  const int64_t out_numel = rows * block;

  CUDA_KERNEL_LOOP_TYPE(i, out_numel, int64_t) {
    const int64_t row = i / block;
    const int64_t col = i % block;
    output[i] = input[row * columns + start + col];
  }
}

55
template <typename T, typename DeviceContext>
class CSplitOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Writes to "Out" the `rank`-th of `nranks` equal slices of the last
  // dimension of "X". Out has the same shape as X except that its last
  // dimension is divided by `nranks`.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto x = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");

    int nranks = ctx.Attr<int>("nranks");
    int rank = ctx.Attr<int>("rank");
    auto place = ctx.GetPlace();

    PADDLE_ENFORCE_GE(rank,
                      0,
                      platform::errors::PreconditionNotMet(
                          "The value of rank (%d) for c_split must be "
                          "greater than or equal to 0.",
                          rank));
    PADDLE_ENFORCE_GE(nranks,
                      2,
                      platform::errors::PreconditionNotMet(
                          "The value of nranks (%d) for c_split must be "
                          "greater than or equal to 2.",
                          nranks));
    PADDLE_ENFORCE_LT(rank,
                      nranks,
                      platform::errors::PreconditionNotMet(
                          "The value of rank (%d) for c_split must be "
                          "less than that of nranks (%d).",
                          rank,
                          nranks));

    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
    auto dims = x->dims();
    auto dims_size = dims.size();
    // dims[dims_size - 1] below would index out of range for a 0-D tensor.
    PADDLE_ENFORCE_GE(dims_size,
                      1,
                      platform::errors::InvalidArgument(
                          "The input X of c_split must have at least one "
                          "dimension, but received a %d-D tensor.",
                          dims_size));
    // Size of the final (split) dimension.
    int64_t end_size = dims[dims_size - 1];
    // Every rank owns end_size / nranks columns; a remainder would otherwise
    // be silently dropped from all ranks' outputs.
    PADDLE_ENFORCE_EQ(end_size % nranks,
                      0,
                      platform::errors::InvalidArgument(
                          "The last dimension (%d) of the input X of c_split "
                          "must be divisible by nranks (%d).",
                          end_size,
                          nranks));

    // Product of all leading dimensions: the number of rows of the
    // [rows, end_size] view the kernel works on.
    auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
    int64_t remain_numel = phi::product(remain_ddim);

    int64_t limit = x->numel();

    dims[dims_size - 1] /= nranks;
    out->mutable_data<T>(dims, place);

    // An empty input produces an (already allocated) empty output. Launching
    // with NumBlocks(0) == 0 blocks would be an invalid launch configuration.
    if (limit == 0) {
      return;
    }

    int64_t blocks = NumBlocks(limit);
    int64_t threads = kNumCUDAThreads;

    SplitFromRank<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
                                                               out->data<T>(),
                                                               remain_numel,
                                                               end_size,
                                                               rank,
                                                               nranks,
                                                               limit);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// GPU registration of the c_split kernel. The bfloat16 instantiation is only
// compiled when NCCL_VERSION_CODE >= 21000 and CUDA_VERSION >= 11000.
PD_REGISTER_STRUCT_KERNEL(c_split,
                          GPU,
                          ALL_LAYOUT,
                          ops::CSplitOpCUDAKernel,
                          float,
                          double,
                          int,
                          int64_t,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
                          paddle::platform::bfloat16,
#endif
                          paddle::platform::float16) {}