assign_pos_op.cu 3.8 KB
Newer Older
R
Roc 已提交
1 2 3 4 5 6 7 8 9 10 11 12
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

The file has been adapted from the two files:
     https://github.com/laekov/fastmoe/blob/master/cuda/local_exchange.cu
     https://github.com/laekov/fastmoe/blob/master/cuda/local_exchange.cuh
     Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
We retain the following license from the original files:
         Copyright 2021, Jiaao He
   Licensed under the Apache License, Version 2.0 (the "License").
*/
R
Roc 已提交
23 24

#include "paddle/fluid/operators/assign_pos_op.h"
25
#include "paddle/fluid/framework/op_registry.h"
R
Roc 已提交
26
#include "paddle/fluid/platform/float16.h"
27
#include "paddle/phi/backends/gpu/gpu_primitives.h"
R
Roc 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42

DECLARE_bool(avoid_op_randomness);

namespace paddle {
namespace operators {

// Threads per block used when launching AssignPos.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the number of blocks returned by NumBlocks.
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns ceil(N / kNumCUDAThreads), capped at kNumMaxinumNumBlocks.
// Returns 0 when N == 0; a kernel must not be launched with a 0-block grid.
// The arithmetic is done in int64_t so that tensor element counts (numel() is
// int64_t) are neither truncated nor overflow near INT_MAX.
static inline int NumBlocks(const int64_t N) {
  const int64_t needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return static_cast<int>(
      std::min<int64_t>(needed, static_cast<int64_t>(kNumMaxinumNumBlocks)));
}

// Scatters token positions into `out`, grouped by counter.
//
// For every element i of `numbers` whose value is non-negative, atomically
// decrements cum_count[numbers[i]] by one and stores i at out[p - 1], where p
// is the value returned by the atomic add (the pre-decrement counter value,
// per atomicAdd semantics). Elements whose value is negative are skipped.
//
// Side effect: `cum_count` is modified in place.
// The relative order of indices written for the same counter depends on how
// the atomics are scheduled.
//
// Launch: 1-D grid/block of any size; the loop strides over all `limit`
// elements.
template <typename T>
__global__ void AssignPos(T* cum_count,
                          const T* numbers,
                          T* out,
                          int64_t limit) {
  // 64-bit start/stride so element counts beyond INT_MAX are indexed safely.
  const int64_t start =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = start; i < limit; i += stride) {
    // Keep the counter id in T: narrowing to int could corrupt large ids.
    const T number_idx = numbers[i];
    if (number_idx > -1) {
      const T p =
          phi::CudaAtomicAdd(cum_count + number_idx, static_cast<T>(-1));
      out[p - 1] = static_cast<T>(i);
    }
  }
}

// GPU kernel for the `assign_pos` operator.
//
// Inputs:
//   cum_count   - per-counter cumulative counts of dtype T. The device kernel
//                 decrements these entries in place.
//   X           - counter id of every token, dtype T; negative ids are
//                 skipped.
//   eff_num_len - tensor whose first element is the length of Out; it may
//                 reside on CPU or on the device.
// Output:
//   Out         - 1-D tensor of length eff_num_len[0] holding token positions
//                 (indices into the flattened X) grouped by counter.
template <typename T, typename DeviceContext>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // assign_pos determines which tokens belong to each counter and writes
    // their positions contiguously, in counter order.
    const auto* cum_count = context.Input<phi::DenseTensor>("cum_count");
    const auto* numbers = context.Input<phi::DenseTensor>("X");
    const auto* eff_num_len = context.Input<phi::DenseTensor>("eff_num_len");
    auto* out = context.Output<phi::DenseTensor>("Out");

    auto place = context.GetPlace();
    auto numel = numbers->numel();
    // The device kernel updates cum_count's buffer in place.
    T* cum_data = const_cast<T*>(cum_count->data<T>());

    // Out's length is needed on the host, so fetch eff_num_len[0] to the CPU
    // (synchronously) when it lives elsewhere.
    int64_t cpu_eff_num_len_data = 0;
    if (platform::is_cpu_place(eff_num_len->place())) {
      cpu_eff_num_len_data = eff_num_len->data<T>()[0];
    } else {
      phi::DenseTensor cpu_eff_num_len;
      framework::TensorCopySync(
          *eff_num_len, platform::CPUPlace(), &cpu_eff_num_len);
      cpu_eff_num_len_data = cpu_eff_num_len.data<T>()[0];
    }

    const auto& dev_ctx = context.template device_context<phi::GPUContext>();
    framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data});
    auto out_data = out->mutable_data<T>(out_dims, place);

    // Nothing to scatter for an empty X. NumBlocks(0) is 0, and launching a
    // kernel with a 0-block grid is an invalid launch configuration.
    if (numel == 0) {
      return;
    }

    const T* num_data = numbers->data<T>();

    int blocks = NumBlocks(numel);
    int threads = kNumCUDAThreads;

    AssignPos<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        cum_data, num_data, out_data, numel);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers the GPU kernel for `assign_pos`; only int64_t is registered.
PD_REGISTER_STRUCT_KERNEL(
    assign_pos, GPU, ALL_LAYOUT, ops::AssignPosCUDAKernel, int64_t) {}