sample_logits_op.cu 9.7 KB
Newer Older
X
xuezhong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/sample_logits_op.h"

namespace paddle {
namespace operators {

// Row-wise gather, similar to numpy's take_along_axis along axis 1.
//
// For every flat position `idx` in [0, size) this writes
//   p_value[idx] = p_array[row * array_slice_size + p_index[idx]]
// with row = idx / idx_slice_size.
//
//   size             - number of index/value elements to process.
//   batch_size       - not read by the kernel body.
//   array_slice_size - row stride (in elements) of p_array.
//   idx_slice_size   - row stride (in elements) of p_index and p_value.
//   p_array          - source matrix.
//   p_index          - column indices into the matching row of p_array;
//                      values are not range-checked here.
//   p_value          - destination, same layout as p_index.
//
// Launch layout: 1-D grid of 1-D blocks. Threads advance by
// blockDim.x * gridDim.x, so any launch size covers the whole range.
template <typename T>
__global__ void GPUTakeAlongD1(size_t size, const int batch_size,
                               const int array_slice_size,
                               const int idx_slice_size, const T* p_array,
                               const int64_t* p_index, T* p_value) {
  // The loop counter has the same width as `size`, and offsets are formed in
  // 64 bits: row * array_slice_size exceeds the int range as soon as the
  // source matrix holds more than 2^31 elements.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockDim.x) * blockIdx.x +
                       threadIdx.x;

  for (size_t idx = first; idx < size; idx += stride) {
    const int64_t row =
        static_cast<int64_t>(idx / static_cast<size_t>(idx_slice_size));
    const int64_t col = p_index[idx];
    p_value[idx] = p_array[row * array_slice_size + col];
  }
}

// Row-wise scatter-add, similar to numpy's put_along_axis along axis 1 except
// that duplicate indices within a row accumulate (+=) instead of overwriting.
//
// For every row r in [0, size) and every j in [0, idx_slice_size):
//   p_array[r * array_slice_size + p_index[r * idx_slice_size + j]] +=
//       p_value[r * idx_slice_size + j]
//
//   size             - number of rows to process (the caller in this file
//                      passes the batch size).
//   batch_size       - not read by the kernel body.
//   array_slice_size - row stride (in elements) of p_array.
//   idx_slice_size   - row stride (in elements) of p_index and p_value.
//   p_array          - destination matrix, updated in place.
//   p_index          - column indices into the matching row of p_array;
//                      values are not range-checked here.
//   p_value          - values to add, same layout as p_index.
//
// Launch layout: 1-D grid of 1-D blocks. Each thread walks whole rows
// sequentially, so the += on duplicate indices of one row is performed by a
// single thread.
template <typename T>
__global__ void GPUPutAlongD1(size_t size, const int batch_size,
                              const int array_slice_size,
                              const int idx_slice_size, T* p_array,
                              const int64_t* p_index, const T* p_value) {
  // Row offsets are formed in 64 bits; row * array_slice_size does not fit in
  // an int once the destination holds more than 2^31 elements.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockDim.x) * blockIdx.x +
                       threadIdx.x;

  for (size_t row = first; row < size; row += stride) {
    const int64_t array_base = static_cast<int64_t>(row) * array_slice_size;
    const int64_t idx_base = static_cast<int64_t>(row) * idx_slice_size;
    for (int j = 0; j < idx_slice_size; ++j) {
      const int64_t col = p_index[idx_base + j];
      p_array[array_base + col] += p_value[idx_base + j];
    }
  }
}

// Fills p_array[idx] = idx % num_true for idx in [0, size), i.e. the repeating
// pattern 0, 1, ..., num_true - 1.
//
//   size     - number of elements to write.
//   num_true - period of the pattern; must be non-zero.
//   p_array  - destination buffer with at least `size` elements.
//
// The template parameter T is not used by the body; it is kept so existing
// GPUSetLabel<T> launches keep compiling.
//
// Launch layout: 1-D grid of 1-D blocks. Threads advance by
// blockDim.x * gridDim.x, so any launch size covers the whole range.
template <typename T>
__global__ void GPUSetLabel(size_t size, const int num_true, int64_t* p_array) {
  // The counter has the same width as `size`, avoiding a signed/unsigned
  // comparison and wrap-around of a 32-bit index.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockDim.x) * blockIdx.x +
                       threadIdx.x;

  for (size_t idx = first; idx < size; idx += stride) {
    p_array[idx] = static_cast<int64_t>(idx % static_cast<size_t>(num_true));
  }
}

// Penalises accidental hits: a sampled entry whose index equals one of the
// true-label indices of the same row gets 1e20 subtracted from its value.
//
// Each row of p_index / p_value has idx_slice_size entries. Columns
// [0, num_true) are treated as the true-label entries and are never modified;
// every later column is compared against those first num_true indices of its
// row, and the first match triggers the subtraction.
//
//   size           - number of index/value elements to process.
//   num_true       - number of leading true-label columns per row.
//   idx_slice_size - row stride (in elements) of p_index and p_value.
//   p_index        - index matrix.
//   p_value        - value matrix, updated in place.
//
// Launch layout: 1-D grid of 1-D blocks. Threads advance by
// blockDim.x * gridDim.x, so any launch size covers the whole range.
template <typename T>
__global__ void gpu_compute_remove_accidental_hits(const int size,
                                                   const int num_true,
                                                   const int idx_slice_size,
                                                   const int64_t* p_index,
                                                   T* p_value) {
  // A 64-bit counter keeps idx += stride from wrapping when size is close to
  // the int limit.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t first = static_cast<int64_t>(blockDim.x) * blockIdx.x +
                        threadIdx.x;

  for (int64_t idx = first; idx < size; idx += stride) {
    const int64_t col = idx % idx_slice_size;
    if (col < num_true) continue;  // true-label column: nothing to remove

    const int64_t row_begin = idx - col;
    const int64_t candidate = p_index[idx];
    for (int j = 0; j < num_true; ++j) {
      if (p_index[row_begin + j] == candidate) {
        // Cast the penalty to T so a float instantiation is not silently
        // promoted to double arithmetic by the bare 1e20 literal.
        p_value[idx] -= static_cast<T>(1e20);
        break;
      }
    }
  }
}

// CUDA forward kernel of the sample_logits op.
//
// Compute() performs, in order:
//   1. zero-fills "SampledLogits" and fills "SampledLabels" with the pattern
//      0..num_true-1 for every row;
//   2. obtains "Samples"/"Probabilities", either by sharing the
//      "CustomizedSamples"/"CustomizedProbabilities" inputs or by running
//      math::GPUSampleWithProb;
//   3. gathers the logits at the sampled indices into "SampledLogits";
//   4. optionally subtracts a large constant from accidental hits;
//   5. subtracts log(probabilities) from the gathered logits.
// All kernels are enqueued on the CUDA device context's stream.
template <typename T>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;
  void Compute(const framework::ExecutionContext& context) const override {
    // get necessary inputs
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Labels");
    VLOG(3) << "Enter SampleLogitsCUDAKernel";

    // get necessary outputs
    Tensor* samples = context.Output<Tensor>("Samples");
    Tensor* probabilities = context.Output<Tensor>("Probabilities");
    Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
    Tensor* sampled_labels = context.Output<Tensor>("SampledLabels");

    // shapes
    const auto batch_size = logits->dims()[0];
    const auto num_classes = logits->dims()[1];
    const auto labels_dim = labels->dims();
    const auto num_true = labels_dim[1];
    // Captured before "Samples" may be re-pointed at the customized input.
    const auto samples_dim = samples->dims();

    // attrs
    const auto num_samples = context.Attr<int>("num_samples");
    const bool use_customized_samples =
        context.Attr<bool>("use_customized_samples");
    const bool uniq = context.Attr<bool>("uniq");
    const bool remove_accidental_hits =
        context.Attr<bool>("remove_accidental_hits");

    // device context and the stream every kernel below is launched on
    auto& dev_ctx = context.cuda_device_context();
    auto stream = dev_ctx.stream();

    // Allocate the gathered-logits output and start it from zero.
    sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(dev_ctx, sampled_logits, static_cast<T>(0));

    // SampledLabels: 0, 1, ..., num_true - 1 repeated for every row.
    auto sampled_labels_data =
        sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
    const int threads = 512;
    const size_t label_count = batch_size * num_true;
    const int label_grid = (label_count + threads - 1) / threads;
    GPUSetLabel<T><<<label_grid, threads, 0, stream>>>(label_count, num_true,
                                                       sampled_labels_data);

    if (use_customized_samples) {
      // Caller-provided samples: alias their storage instead of sampling.
      const Tensor* customized_samples =
          context.Input<Tensor>("CustomizedSamples");
      const Tensor* customized_probabilities =
          context.Input<Tensor>("CustomizedProbabilities");
      samples->ShareDataWith(*customized_samples);
      probabilities->ShareDataWith(*customized_probabilities);
    } else {
      samples->mutable_data<int64_t>(context.GetPlace());
      probabilities->mutable_data<T>(samples_dim, context.GetPlace());
      // Draw the samples and their probabilities on the device.
      const auto seed = context.Attr<int>("seed");
      auto sampler_with_prob = math::GPUSampleWithProb<T>();
      sampler_with_prob(dev_ctx, seed, num_classes, uniq, num_samples, labels,
                        samples, probabilities);
    }

    // Gather sampled logits and remove accidental hits if needed.
    const T* p_array = logits->data<T>();
    const int64_t* p_index = samples->data<int64_t>();
    T* p_value = sampled_logits->data<T>();

    // src slice size
    const auto array_slice_size = logits->dims()[1];
    // index slice size; also the number of entries taken per row
    const auto idx_slice_size = samples->dims()[1];

    const size_t take_count = batch_size * idx_slice_size;
    const int take_grid = (take_count + threads - 1) / threads;
    GPUTakeAlongD1<T><<<take_grid, threads, 0, stream>>>(
        take_count, batch_size, array_slice_size, idx_slice_size, p_array,
        p_index, p_value);

    if (remove_accidental_hits) {
      // NOTE(review): the element count here comes from the num_samples
      // attribute while the row stride comes from samples->dims()[1]; this
      // presumably relies on both describing the same width - confirm for
      // the customized-samples path.
      const size_t hits_count = batch_size * (num_true + num_samples);
      const int hits_grid = (hits_count + threads - 1) / threads;
      gpu_compute_remove_accidental_hits<
          T><<<hits_grid, threads, 0, stream>>>(hits_count, num_true,
                                                idx_slice_size, p_index,
                                                p_value);
    }

    // subtracted sampled logits with logQ(y|x)
    auto probs = EigenMatrix<T>::From(*probabilities);
    auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
    smp_logits.device(*dev_ctx.eigen_device()) =
        (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
            .unaryExpr(TolerableValue<T>());
  }
};

// CUDA backward kernel of the sample_logits op.
//
// Zero-fills the gradient of "Logits" and then adds the gradient of
// "SampledLogits" into it at the column positions stored in "Samples", one
// thread per row (see GPUPutAlongD1).
template <typename T>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;
  void Compute(const framework::ExecutionContext& context) const override {
    auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* samples = context.Input<Tensor>("Samples");
    const Tensor* sampled_logits_grad =
        context.Input<Tensor>(framework::GradVarName("SampledLogits"));

    // Allocate the output gradient and clear it before accumulating into it.
    logits_grad->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.cuda_device_context();
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(dev_ctx, logits_grad, static_cast<T>(0));

    // Scatter the incoming gradient back onto the logits gradient.
    const auto batch_size = samples->dims()[0];
    // row stride of the destination (logits gradient)
    const auto array_slice_size = logits_grad->dims()[1];
    // row stride of the index / value matrices
    const auto idx_slice_size = samples->dims()[1];

    T* dst = logits_grad->data<T>();
    const int64_t* indices = samples->data<int64_t>();
    const T* src = sampled_logits_grad->data<T>();

    // One thread per row, 128 threads per block.
    const int threads = 128;
    const size_t rows = batch_size;
    const int blocks = (rows + threads - 1) / threads;

    GPUPutAlongD1<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        rows, batch_size, array_slice_size, idx_slice_size, dst, indices, src);
  }
};

}  // namespace operators
}  // namespace paddle
// Short alias used by the kernel registrations below.
namespace ops = paddle::operators;

// Pass the float and double instantiations of the forward kernel to the
// registration macro under the op name `sample_logits`.
REGISTER_OP_CUDA_KERNEL(sample_logits, ops::SampleLogitsCUDAKernel<float>,
                        ops::SampleLogitsCUDAKernel<double>);
// Same for the gradient kernel under the op name `sample_logits_grad`.
REGISTER_OP_CUDA_KERNEL(sample_logits_grad,
                        ops::SampleLogitsGradCUDAKernel<float>,
                        ops::SampleLogitsGradCUDAKernel<double>);