sample_logits_op.cu 10.9 KB
Newer Older
X
xuezhong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
19

X
xuezhong 已提交
20 21 22 23 24
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/sample_logits_op.h"
25
#include "paddle/phi/kernels/funcs/math_function.h"
26
#include "paddle/phi/kernels/funcs/softmax.h"
X
xuezhong 已提交
27 28 29 30 31 32

namespace paddle {
namespace operators {

// Gathers values along dimension 1, similar to numpy.take_along_axis(axis=1).
//
// For every flat position `idx` in [0, size) of a row-major
// [rows, idx_slice_size] index tensor, reads the column index p_index[idx]
// and copies p_array[row * array_slice_size + p_index[idx]] into
// p_value[idx], where row = idx / idx_slice_size.
//
// Launch: 1-D grid of any size. The loop advances by the total number of
// launched threads, so every position is handled exactly once.
// `batch_size` is kept for signature compatibility but is not read.
//
// Offsets are computed in 64-bit: row * array_slice_size can exceed INT_MAX
// for large [rows, num_classes] arrays even when each factor fits in int.
// Column indices are not bounds-checked; callers must guarantee
// 0 <= p_index[idx] < array_slice_size.
template <typename T>
__global__ void GPUTakeAlongD1(size_t size,
                               const int batch_size,
                               const int array_slice_size,
                               const int idx_slice_size,
                               const T* p_array,
                               const int64_t* p_index,
                               T* p_value) {
  const int64_t n = static_cast<int64_t>(size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; idx < n; idx += stride) {
    const int64_t row = idx / idx_slice_size;
    p_value[idx] = p_array[row * array_slice_size + p_index[idx]];
  }
}

// Scatters values back along dimension 1, similar to
// numpy.put_along_axis(axis=1), except that duplicate column indices within a
// row are accumulated with += instead of overwritten.
//
// Expects `size == batch_size`: each loop iteration handles one whole row,
// walking its idx_slice_size entries serially. Because a row is owned by a
// single thread, the read-modify-write on p_array needs no atomics even when
// a row contains duplicate indices. Distinct rows write to disjoint slices
// of p_array (given in-range indices).
//
// `batch_size` is kept for signature compatibility but is not read.
// Offsets are computed in 64-bit to avoid int overflow on large
// [rows, array_slice_size] arrays. Column indices are not bounds-checked;
// callers must guarantee 0 <= p_index[...] < array_slice_size.
template <typename T>
__global__ void GPUPutAlongD1(size_t size,
                              const int batch_size,
                              const int array_slice_size,
                              const int idx_slice_size,
                              T* p_array,
                              const int64_t* p_index,
                              const T* p_value) {
  const int64_t n = static_cast<int64_t>(size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; row < n; row += stride) {
    const int64_t src_base = row * idx_slice_size;
    T* dst_row = p_array + row * array_slice_size;
    for (int j = 0; j < idx_slice_size; ++j) {
      dst_row[p_index[src_base + j]] += p_value[src_base + j];
    }
  }
}

// Writes p_array[k] = k % num_true for every k in [0, size). For a row-major
// [rows, num_true] label tensor this makes every row 0, 1, ..., num_true - 1.
// The loop advances by the total number of launched threads, so any 1-D grid
// size covers all elements. The template parameter T is not used.
template <typename T>
__global__ void GPUSetLabel(size_t size, const int num_true, int64_t* p_array) {
  const int total_threads = gridDim.x * blockDim.x;
  for (int pos = threadIdx.x + blockIdx.x * blockDim.x; pos < size;
       pos += total_threads) {
    p_array[pos] = pos % num_true;
  }
}

// Removes "accidental hits" from the sampled logits.
//
// The kernel treats the first num_true entries of each row of the row-major
// [rows, idx_slice_size] tensor p_index as the row's true labels and the
// remaining entries as sampled classes. For every sampled position whose
// class id equals one of its row's true labels, p_value at that position is
// lowered by 1e20. True-label positions are never modified. `size` is the
// number of elements to visit, expected to be rows * idx_slice_size.
//
// The 1e20 constant is converted to T once, so the float instantiation does
// its arithmetic in float instead of promoting every update to double.
template <typename T>
__global__ void gpu_compute_remove_accidental_hits(const int size,
                                                   const int num_true,
                                                   const int idx_slice_size,
                                                   const int64_t* p_index,
                                                   T* p_value) {
  const T kPenalty = static_cast<T>(1e20);
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < size;
       idx += stride) {
    const int col = idx % idx_slice_size;
    if (col < num_true) continue;  // true-label slot: never penalized
    const int row_base = idx - col;  // first element of this row
    const int64_t sampled_class = p_index[idx];
    for (int j = 0; j < num_true; ++j) {
      if (p_index[row_base + j] == sampled_class) {
        p_value[idx] -= kPenalty;
        break;
      }
    }
  }
}

// Forward CUDA kernel of the sample_logits operator.
//
// Reads Logits (indexed as [batch_size, num_classes]) and Labels (indexed as
// [batch_size, num_true]). When use_customized_samples is set,
// CustomizedSamples / CustomizedProbabilities must be the very same tensors
// as the Samples / Probabilities outputs.
//
// All work is enqueued on the device context's stream:
//   1. SampledLabels[i, j] = j for j in [0, num_true).
//   2. Unless customized samples are supplied, Samples and Probabilities are
//      produced by math::GPUSampleWithProb.
//   3. SampledLogits[i, k] = Logits[i, Samples[i, k]].
//   4. If remove_accidental_hits, sampled entries equal to a true label of
//      the same row are lowered by 1e20.
//   5. SampledLogits -= log(Probabilities), with both the log term and the
//      result passed through TolerableValue<T>.
//
// Kernel launches are skipped when there are no elements to process: a grid
// of zero blocks is an invalid CUDA launch configuration.
template <typename T, typename DeviceContext>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // get necessary inputs
    const phi::DenseTensor* logits = context.Input<phi::DenseTensor>("Logits");
    const phi::DenseTensor* labels = context.Input<phi::DenseTensor>("Labels");
    VLOG(3) << "Enter SampleLogitsCUDAKernel";

    // get necessary outputs
    phi::DenseTensor* samples = context.Output<phi::DenseTensor>("Samples");
    phi::DenseTensor* probabilities =
        context.Output<phi::DenseTensor>("Probabilities");
    phi::DenseTensor* sampled_logits =
        context.Output<phi::DenseTensor>("SampledLogits");
    phi::DenseTensor* sampled_labels =
        context.Output<phi::DenseTensor>("SampledLabels");

    // shapes
    const auto batch_size = logits->dims()[0];
    const auto num_classes = logits->dims()[1];
    const auto labels_dim = labels->dims();
    const auto num_true = labels_dim[1];
    const auto samples_dim = samples->dims();

    // attrs
    const auto num_samples = context.Attr<int>("num_samples");
    const bool use_customized_samples =
        context.Attr<bool>("use_customized_samples");
    const bool uniq = context.Attr<bool>("uniq");
    const bool remove_accidental_hits =
        context.Attr<bool>("remove_accidental_hits");

    // device context and the stream every kernel below is enqueued on
    auto& dev_ctx = context.cuda_device_context();
    auto stream = dev_ctx.stream();
    const int threads = 512;

    // allocate and zero the sampled logits
    sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
    phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
    set_zero(dev_ctx, sampled_logits, static_cast<T>(0));

    // sampled labels: every row becomes 0, 1, ..., num_true - 1
    auto sampled_labels_data =
        sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
    const size_t num_labels = batch_size * num_true;
    if (num_labels > 0) {
      const int grid = static_cast<int>((num_labels + threads - 1) / threads);
      GPUSetLabel<T><<<grid, threads, 0, stream>>>(
          num_labels, num_true, sampled_labels_data);
    }

    if (use_customized_samples) {
      const phi::DenseTensor* customized_samples =
          context.Input<phi::DenseTensor>("CustomizedSamples");
      const phi::DenseTensor* customized_probabilities =
          context.Input<phi::DenseTensor>("CustomizedProbabilities");
      PADDLE_ENFORCE_EQ(
          customized_samples,
          samples,
          platform::errors::InvalidArgument(
              "CustomizedSamples must be the same phi::DenseTensor with "
              "Samples when use_customized_samples = True"));
      PADDLE_ENFORCE_EQ(
          customized_probabilities,
          probabilities,
          platform::errors::InvalidArgument(
              "CustomizedProbabilities must be the same phi::DenseTensor with "
              "Probabilities when use_customized_samples = True"));
    } else {
      samples->mutable_data<int64_t>(context.GetPlace());
      probabilities->mutable_data<T>(samples_dim, context.GetPlace());
      // sampling
      const auto seed = context.Attr<int>("seed");
      auto sampler_with_prob = math::GPUSampleWithProb<T>();
      sampler_with_prob(dev_ctx,
                        seed,
                        num_classes,
                        uniq,
                        num_samples,
                        labels,
                        samples,
                        probabilities);
    }

    // gather sampled logits and remove accidental hits if needed
    const auto num_take = samples->dims()[1];
    const T* p_array = logits->data<T>();
    const int64_t* p_index = samples->data<int64_t>();
    T* p_value = sampled_logits->data<T>();

    const auto array_slice_size = num_classes;  // row width of Logits
    const auto idx_slice_size = num_take;       // row width of Samples

    // Both kernels below walk exactly the elements of Samples, so their flat
    // index range always agrees with the row width (idx_slice_size) they
    // divide by -- also when customized samples are supplied.
    const size_t num_sampled = batch_size * num_take;
    if (num_sampled > 0) {
      const int grid = static_cast<int>((num_sampled + threads - 1) / threads);
      GPUTakeAlongD1<T><<<grid, threads, 0, stream>>>(num_sampled,
                                                      batch_size,
                                                      array_slice_size,
                                                      idx_slice_size,
                                                      p_array,
                                                      p_index,
                                                      p_value);
      if (remove_accidental_hits) {
        gpu_compute_remove_accidental_hits<T><<<grid, threads, 0, stream>>>(
            num_sampled, num_true, idx_slice_size, p_index, p_value);
      }
    }

    // subtract logQ(y|x) from the sampled logits
    auto probs = EigenMatrix<T>::From(*probabilities);
    auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
    smp_logits.device(*dev_ctx.eigen_device()) =
        (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
            .unaryExpr(TolerableValue<T>());
  }
};

// Backward CUDA kernel of the sample_logits operator.
//
// Zero-fills Logits@GRAD, then for every row i and sampled position k adds
// SampledLogits@GRAD[i, k] into Logits@GRAD[i, Samples[i, k]]. Classes that
// were sampled more than once in a row are summed, because GPUPutAlongD1
// accumulates with += and processes each row in a single thread.
//
// Returns right after zero-filling when the batch is empty. There is nothing
// to scatter in that case, and a zero-block grid is an invalid launch
// configuration.
template <typename T, typename DeviceContext>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto logits_grad =
        context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
    const phi::DenseTensor* samples =
        context.Input<phi::DenseTensor>("Samples");
    const phi::DenseTensor* sampled_logits_grad =
        context.Input<phi::DenseTensor>(
            framework::GradVarName("SampledLogits"));

    logits_grad->mutable_data<T>(context.GetPlace());

    auto& dev_ctx = context.cuda_device_context();
    phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
    set_zero(dev_ctx, logits_grad, static_cast<T>(0));

    // scatter the sampled gradients back into logits_grad
    const auto batch_size = samples->dims()[0];
    if (batch_size == 0) return;

    const auto array_slice_size = logits_grad->dims()[1];  // num_classes
    const auto idx_slice_size = samples->dims()[1];        // samples per row

    T* p_array = logits_grad->data<T>();
    const int64_t* p_index = samples->data<int64_t>();
    const T* p_value = sampled_logits_grad->data<T>();

    // GPUPutAlongD1 expects one work item per row (size == batch_size).
    const int threads = 128;
    const size_t size = batch_size;
    const int grid = static_cast<int>((size + threads - 1) / threads);
    GPUPutAlongD1<T><<<grid, threads, 0, dev_ctx.stream()>>>(size,
                                                             batch_size,
                                                             array_slice_size,
                                                             idx_slice_size,
                                                             p_array,
                                                             p_index,
                                                             p_value);
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

// GPU registrations of the forward and backward sample_logits kernels for
// float and double element types.
PD_REGISTER_STRUCT_KERNEL(sample_logits, GPU, ALL_LAYOUT,
                          ops::SampleLogitsCUDAKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(sample_logits_grad, GPU, ALL_LAYOUT,
                          ops::SampleLogitsGradCUDAKernel, float, double) {}