// sample_logits_op.h (committed by xuezhong)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <cstdint>
#include <type_traits>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"

namespace paddle {
namespace operators {

// Short alias for the framework tensor type used throughout this header.
using Tensor = framework::Tensor;
// Alias for framework::EigenMatrix; by default the layout is row-major and
// indices use Eigen::DenseIndex.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Element-wise functor that replaces +/-infinity with +/-1e20 so that later
// arithmetic on the result does not propagate infinities. Finite values (and
// NaN, which compares unequal to INFINITY) are returned unchanged.
template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    // Whether T is a floating-point type is known at compile time, so reject
    // other types during compilation instead of re-checking (and potentially
    // throwing) once per element at run time inside HOSTDEVICE code.
    static_assert(std::is_floating_point<T>::value,
                  "TolerableValue should be float in sample_logits_op.");
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

// UNDERSTAND: something like take_along_axis in numpy.
// Gathers along dimension 1: for every row i and column j,
//   value[i][j] = array[i][index[i][j]].
// Shapes: array is (B, C), index is (B, K) holding int64 values, and value
// must already be allocated with shape (B, K). Every index must lie in
// [0, C); an out-of-range index raises an InvalidArgument error instead of
// reading past the end of the row (indices can come from user-supplied
// labels / customized samples).
template <typename T>
static void CPUTakeAlongD1(const platform::DeviceContext& ctx,
                           const framework::Tensor& array,
                           const framework::Tensor& index,
                           framework::Tensor* value) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
  // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
  PADDLE_ENFORCE_EQ(index.dims().size(), 2);
  PADDLE_ENFORCE_EQ(array.dims().size(), 2);
  PADDLE_ENFORCE_EQ(index.dims()[0], array.dims()[0]);
  PADDLE_ENFORCE_EQ(index.dims(), value->dims());

  // Dims are int64, so keep all counters 64-bit as well.
  const int64_t batch_size = index.dims()[0];
  const int64_t num_take = index.dims()[1];
  const int64_t array_slice_size = array.dims()[1];

  // UNDERSTAND: no allocations here
  const T* p_array = array.data<T>();
  const int64_t* p_index = index.data<int64_t>();
  T* p_value = value->data<T>();

  for (int64_t i = 0; i < batch_size; ++i) {
    const T* array_row = p_array + i * array_slice_size;
    const int64_t* index_row = p_index + i * num_take;
    T* value_row = p_value + i * num_take;
    for (int64_t j = 0; j < num_take; ++j) {
      const int64_t array_index = index_row[j];
      PADDLE_ENFORCE_GE(
          array_index, 0,
          platform::errors::InvalidArgument(
              "CPUTakeAlongD1: index[%d][%d] = %d must be >= 0.", i, j,
              array_index));
      PADDLE_ENFORCE_LT(
          array_index, array_slice_size,
          platform::errors::InvalidArgument(
              "CPUTakeAlongD1: index[%d][%d] = %d must be < %d (size of "
              "dimension 1 of the source tensor).",
              i, j, array_index, array_slice_size));
      value_row[j] = array_row[array_index];
    }
  }
}

// UNDERSTAND: something like put_along_axis in numpy, but values are
// accumulated: for every row i and column j,
//   array[i][index[i][j]] += value[i][j],
// so duplicate indices within a row sum their values. `array` is not cleared
// here; callers must initialize it beforehand.
// Shapes: array is (B, C), index is (B, K) holding int64 values, value is
// (B, K). Every index must lie in [0, C); an out-of-range index raises an
// InvalidArgument error instead of writing past the end of the row.
template <typename T>
static void CPUPutAlongD1(const platform::DeviceContext& ctx,
                          framework::Tensor* array,
                          const framework::Tensor& index,
                          const framework::Tensor& value) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
  // UNDERSTAND: check shape src(B, C), index(B, K), value should also be (B, K)
  PADDLE_ENFORCE_EQ(index.dims().size(), 2);
  PADDLE_ENFORCE_EQ(array->dims().size(), 2);
  PADDLE_ENFORCE_EQ(index.dims()[0], array->dims()[0]);
  PADDLE_ENFORCE_EQ(index.dims(), value.dims());

  // Dims are int64, so keep all counters 64-bit as well.
  const int64_t batch_size = index.dims()[0];
  const int64_t num_put = index.dims()[1];
  const int64_t array_slice_size = array->dims()[1];

  // UNDERSTAND: no allocations here
  T* p_array = array->data<T>();
  const int64_t* p_index = index.data<int64_t>();
  const T* p_value = value.data<T>();

  for (int64_t i = 0; i < batch_size; ++i) {
    T* array_row = p_array + i * array_slice_size;
    const int64_t* index_row = p_index + i * num_put;
    const T* value_row = p_value + i * num_put;
    for (int64_t j = 0; j < num_put; ++j) {
      const int64_t array_index = index_row[j];
      // An unchecked index here would be an out-of-bounds *write*.
      PADDLE_ENFORCE_GE(
          array_index, 0,
          platform::errors::InvalidArgument(
              "CPUPutAlongD1: index[%d][%d] = %d must be >= 0.", i, j,
              array_index));
      PADDLE_ENFORCE_LT(
          array_index, array_slice_size,
          platform::errors::InvalidArgument(
              "CPUPutAlongD1: index[%d][%d] = %d must be < %d (size of "
              "dimension 1 of the destination tensor).",
              i, j, array_index, array_slice_size));
      array_row[array_index] += value_row[j];
    }
  }
}

// UNDERSTAND: compute accidental hits among the samples and subtract a large
// constant (1e20) from the corresponding logits.
// Within each row, the first `num_true` entries of `samples` are treated as
// the true labels. Any later entry in the same row whose class id equals one
// of them is an accidental hit, and its logit in `sampled_logits` is reduced
// by 1e20. `sampled_logits` and `samples` are indexed with the same
// (rows, cols) layout taken from sampled_logits' dims. `ctx` is not used.
template <typename T>
static void compute_remove_accidental_hits(const platform::DeviceContext& ctx,
                                           framework::Tensor* sampled_logits,
                                           const framework::Tensor& samples,
                                           const int num_true) {
  const auto rows = sampled_logits->dims()[0];
  const auto cols = sampled_logits->dims()[1];
  T* logits_base = sampled_logits->data<T>();
  const int64_t* samples_base = samples.data<int64_t>();

  // Declared once outside the loop and cleared for every row.
  std::unordered_set<int64_t> true_ids;
  for (int row = 0; row < rows; ++row) {
    const int64_t* row_samples = samples_base + row * cols;
    T* row_logits = logits_base + row * cols;

    true_ids.clear();
    true_ids.insert(row_samples, row_samples + num_true);

    for (int col = num_true; col < cols; ++col) {
      if (true_ids.count(row_samples[col]) == 0) continue;
      row_logits[col] -= 1e20;
    }
  }
}

// CPU kernel of the sample_logits operator.
//
// Outputs produced from Logits / Labels:
//   * Samples / Probabilities: drawn here with math::SampleWithProb and a
//     math::LogUniformSampler seeded by the "seed" attr, or, when
//     use_customized_samples is set, taken from CustomizedSamples /
//     CustomizedProbabilities. Those inputs must then be the very same
//     tensors as the outputs.
//   * SampledLogits: Logits gathered at Samples (CPUTakeAlongD1). Accidental
//     hits are optionally penalised, then log(Probabilities) is subtracted.
//     Infinities are clamped to +/-1e20 by TolerableValue.
//   * SampledLabels: SampledLabels[i][j] = j.
template <typename T>
class SampleLogitsKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(context.GetPlace()), true,
                      platform::errors::InvalidArgument(
                          "This kernel only runs on CPU."));
    VLOG(3) << "Enter SampleLogitsKernel";
    // get necessary inputs
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Labels");

    // get necessary outputs
    Tensor* samples = context.Output<Tensor>("Samples");
    Tensor* probabilities = context.Output<Tensor>("Probabilities");
    Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
    Tensor* sampled_labels = context.Output<Tensor>("SampledLabels");

    // shapes: Logits is read as (batch_size, num_classes); Labels' second
    // dim is num_true (its first dim is presumably batch_size -- the
    // SampledLabels fill below relies on that). Samples' dims are expected
    // to be set already (presumably by InferShape; not visible here).
    const auto batch_size = logits->dims()[0];
    const auto num_classes = logits->dims()[1];
    const auto labels_dim = labels->dims();
    const auto num_true = labels_dim[1];
    const auto samples_dim = samples->dims();

    // attrs
    const auto num_samples = context.Attr<int>("num_samples");
    const bool use_customized_samples =
        context.Attr<bool>("use_customized_samples");
    const bool remove_accidental_hits =
        context.Attr<bool>("remove_accidental_hits");

    // device contexts
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // UNDERSTAND: allocate memories for temporaries
    sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
    // The code assumes the true labels occupy the leading num_true columns
    // of every Samples row, so SampledLabels[i][j] is simply j. Counters are
    // 64-bit to match the int64 dims.
    auto sampled_labels_data =
        sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
    for (int64_t i = 0; i < batch_size; ++i) {
      for (int64_t j = 0; j < num_true; ++j) {
        sampled_labels_data[i * num_true + j] = j;
      }
    }

    if (use_customized_samples) {
      const Tensor* customized_samples =
          context.Input<Tensor>("CustomizedSamples");
      const Tensor* customized_probabilities =
          context.Input<Tensor>("CustomizedProbabilities");
      PADDLE_ENFORCE_EQ(customized_samples, samples,
                        platform::errors::InvalidArgument(
                            "CustomizedSamples must be the same Tensor with "
                            "Samples when use_customized_samples = True"));
      PADDLE_ENFORCE_EQ(
          customized_probabilities, probabilities,
          platform::errors::InvalidArgument(
              "CustomizedProbabilities must be the same Tensor with "
              "Probabilities when use_customized_samples = True"));
    } else {
      samples->mutable_data<int64_t>(context.GetPlace());
      probabilities->mutable_data<T>(samples_dim, context.GetPlace());
      // UNDERSTAND: sampling
      const auto seed = context.Attr<int>("seed");
      auto sampler_with_prob =
          math::SampleWithProb<platform::CPUDeviceContext, T>();
      sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed),
                        num_samples, labels, samples, probabilities);
    }

    // UNDERSTAND: gather sampled logits and remove accidental hits if needed
    CPUTakeAlongD1<T>(dev_ctx, *logits, *samples, sampled_logits);
    if (remove_accidental_hits) {
      compute_remove_accidental_hits<T>(dev_ctx, sampled_logits, *samples,
                                        num_true);
    }

    // subtract logQ(y|x), i.e. log(Probabilities), from the sampled logits;
    // clamp infinities both on the log term and on the result.
    auto probs = EigenMatrix<T>::From(*probabilities);
    auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
    smp_logits.device(*dev_ctx.eigen_device()) =
        (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
            .unaryExpr(TolerableValue<T>());
  }
};

// CPU gradient kernel of sample_logits.
// d(Logits) is zero-filled and then d(SampledLogits) is scatter-added back
// into it at the class ids listed in Samples (CPUPutAlongD1 accumulates
// duplicate ids). Logits that were never sampled therefore receive a zero
// gradient.
template <typename T>
class SampleLogitsGradKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;
  void Compute(const framework::ExecutionContext& context) const override {
    Tensor* d_logits =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* sampled_ids = context.Input<Tensor>("Samples");
    const Tensor* d_sampled_logits =
        context.Input<Tensor>(framework::GradVarName("SampledLogits"));
    d_logits->mutable_data<T>(context.GetPlace());

    auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // Start from an all-zero gradient before accumulating.
    math::SetConstant<platform::CPUDeviceContext, T> fill_with;
    fill_with(cpu_ctx, d_logits, static_cast<T>(0));

    // Scatter-add the sampled-logit gradients back to their source columns.
    CPUPutAlongD1<T>(cpu_ctx, d_logits, *sampled_ids, *d_sampled_logits);
  }
};

}  // namespace operators
}  // namespace paddle