nce_op.h 15.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
W
wanghaoshuang 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
W
wanghaoshuang 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
W
wanghaoshuang 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
W
wanghaoshuang 已提交
14 15 16

#pragma once

W
wanghaoshuang 已提交
17
#include <math.h>

#include <algorithm>
#include <iterator>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor"
28

W
wanghaoshuang 已提交
29 30 31
namespace paddle {
namespace operators {

// Short aliases for the framework types used throughout this header.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using Sampler = math::Sampler;
using DDim = framework::DDim;

// Row-major Eigen matrix view over a framework::Tensor.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

Q
QI JUN 已提交
42
template <typename DeviceContext, typename T>
43 44
void PrepareSamples(const framework::ExecutionContext &context,
                    Sampler *sampler) {
W
wanghaoshuang 已提交
45
  auto label = context.Input<Tensor>("Label");
46
  const int64_t *label_data = label->data<int64_t>();
W
wanghaoshuang 已提交
47
  auto label_dims = label->dims();
W
wanghaoshuang 已提交
48
  // for unitest
W
wanghaoshuang 已提交
49 50
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
W
wanghaoshuang 已提交
51 52 53

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
54
  int64_t *sample_labels_data =
W
wanghaoshuang 已提交
55
      sample_labels->mutable_data<int64_t>(context.GetPlace());
W
wanghaoshuang 已提交
56 57

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
W
wanghaoshuang 已提交
58
  int index = 0;
59
  for (int64_t i = 0; i < label_dims[0]; ++i) {
W
wanghaoshuang 已提交
60 61
    int j = 0;
    for (; j < num_label; ++j) {
W
wanghaoshuang 已提交
62
      sample_labels_data[index++] = label_data[i * num_label + j];
W
wanghaoshuang 已提交
63
    }
W
wanghaoshuang 已提交
64 65
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
W
wanghaoshuang 已提交
66 67 68 69
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
W
wanghaoshuang 已提交
70
        // TODO(wanghaoshuang): support more distribution sampling
71
        sample_labels_data[index++] = sampler->Sample();
W
wanghaoshuang 已提交
72
      }
W
wanghaoshuang 已提交
73 74 75 76
    }
  }
}

Q
QI JUN 已提交
77
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
78 79
class NCEKernel : public framework::OpKernel<T> {
 public:
80
  void Compute(const framework::ExecutionContext &context) const override {
81 82 83 84 85
    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_neg_samples = context.Attr<int>("num_neg_samples");

86
    Sampler *sampler;
87 88 89 90 91 92 93 94 95 96
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
97 98 99 100
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

101 102
        PADDLE_ENFORCE_EQ(
            dist_probs->numel(), num_total_classes,
103 104 105 106 107 108
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_probs->numel(), num_total_classes));
109 110
        PADDLE_ENFORCE_EQ(
            dist_alias->numel(), num_total_classes,
111 112 113 114 115 116
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistAlias) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_alias->numel(), num_total_classes));
117 118
        PADDLE_ENFORCE_EQ(
            dist_alias_probs->numel(), num_total_classes,
119 120 121 122 123 124 125
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in "
                "Input(CustomDistAliasProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAliasProbs).numel() = %d, "
                "Attr(num_total_classes) = %d.",
                dist_alias_probs->numel(), num_total_classes));
126 127 128 129 130 131

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
132 133
        break;
      }
F
Feiyu Chan 已提交
134 135 136 137 138 139
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported SamplerType. SamplerType should be 0: Uniform, "
            "1: LogUniform or 2: CostumDist. Received SamplerType: %d",
            sampler_type));
      }
140 141 142
    }

    PrepareSamples<DeviceContext, T>(context, sampler);
W
wanghaoshuang 已提交
143
    auto sample_labels = context.Output<Tensor>("SampleLabels");
144
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
145 146

    for (int x = 0; x < sample_labels->numel(); x++) {
147
      PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
148 149 150 151 152
                        platform::errors::InvalidArgument(
                            "ValueError: Every sample label should be "
                            "non-negative. But received: "
                            "Input(SampleLabels)[%d] = %d",
                            x, sample_labels_data[x]));
153 154
    }

W
wanghaoshuang 已提交
155
    auto sample_out = context.Output<Tensor>("SampleLogits");
156
    T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
W
wanghaoshuang 已提交
157 158
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
159
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
160 161 162
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
163
    auto out = context.Output<Tensor>("Cost");
164
    T *out_data = out->mutable_data<T>(context.GetPlace());
165
    int64_t num_true_class = 1;
W
wanghaoshuang 已提交
166 167 168
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
169 170
    int64_t sampled_labels_num = sample_labels->dims()[1];
    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
171
    // forward bias
W
wanghaoshuang 已提交
172
    auto bias = context.Input<Tensor>("Bias");
W
wanghaoshuang 已提交
173
    if (bias != nullptr) {
174
      const T *bias_data = bias->data<T>();
175
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
176 177 178
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
179
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
180 181 182 183
        sample_out_data[i] = 0;
      }
    }
    // forward mul
W
wanghaoshuang 已提交
184
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
T
tangwei12 已提交
185

T
tangwei12 已提交
186 187 188 189 190 191 192 193
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
W
wanghaoshuang 已提交
194
    }
T
tangwei12 已提交
195

W
wanghaoshuang 已提交
196
    // forward cost
197
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
W
wanghaoshuang 已提交
198 199
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
200 201 202 203 204
      for (int64_t j = 0; j < sampled_labels_num; ++j) {
        int64_t target = sample_labels_data[i * sampled_labels_num + j];
        T o = sample_out_data[i * sampled_labels_num + j];
        float b = sampler->Probability(target) * num_neg_samples;
        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
W
wanghaoshuang 已提交
205 206 207
        out_data[i] += w * cost;
      }
    }
208
    delete sampler;
W
wanghaoshuang 已提交
209 210 211
  }
};

Q
QI JUN 已提交
212
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
213 214
class NCEGradKernel : public framework::OpKernel<T> {
 public:
215
  void Compute(const framework::ExecutionContext &context) const override {
W
wanghaoshuang 已提交
216
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
217
    const T *d_out_data = d_out->data<T>();
W
wanghaoshuang 已提交
218 219
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
220
    const T *sample_out_data = sample_out->data<T>();
W
wanghaoshuang 已提交
221
    auto sample_labels = context.Input<Tensor>("SampleLabels");
222
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
223
    auto sample_weight = context.Input<Tensor>("SampleWeight");
224
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
225 226 227
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
228 229
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
230 231 232 233
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
234 235 236

    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
237
    Sampler *sampler;
238 239 240 241 242 243 244 245 246 247
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
248 249 250 251
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

252 253
        PADDLE_ENFORCE_EQ(
            dist_probs->numel(), num_total_classes,
254 255 256 257 258 259
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_probs->numel(), num_total_classes));
260 261
        PADDLE_ENFORCE_EQ(
            dist_alias->numel(), num_total_classes,
262 263 264 265 266 267
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistAlias) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_alias->numel(), num_total_classes));
268 269
        PADDLE_ENFORCE_EQ(
            dist_alias_probs->numel(), num_total_classes,
270 271 272 273 274 275 276
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in "
                "Input(CustomDistAliasProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAliasProbs).numel() = %d, "
                "Attr(num_total_classes) = %d.",
                dist_alias_probs->numel(), num_total_classes));
277 278 279 280 281 282

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
283 284
        break;
      }
F
Feiyu Chan 已提交
285 286 287 288 289 290
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported SamplerType. SamplerType should be 0: Uniform, "
            "1: LogUniform or 2: CostumDist. Received SamplerType: %d",
            sampler_type));
      }
291 292 293
    }

    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
294
    Tensor sample_grad;  // tmp tensor
295
    T *sample_grad_data =
W
wanghaoshuang 已提交
296 297
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
298
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
299 300 301
      int64_t label_idx = i % sample_labels->dims()[1];
      int64_t sample_idx = i / sample_labels->dims()[1];
      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
W
wanghaoshuang 已提交
302
      T o = sample_out_data[i];
303 304
      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
      sample_grad_data[i] = label_idx < num_true_class
W
wanghaoshuang 已提交
305 306
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
307
      sample_grad_data[i] *= d_out_data[sample_idx];
W
wanghaoshuang 已提交
308
    }
309

310 311 312 313 314 315 316 317 318 319
    // get d_bias
    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
    if (d_bias != nullptr) {
      T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
      }
    }

320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
    bool is_sparse = context.Attr<bool>("is_sparse");

    if (!is_sparse) {
      // get d_w
      auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
      if (d_w != nullptr) {
        auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
        std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
        auto d_w_matrix = EigenMatrix<T>::From(*d_w);
        auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_w_matrix.chip(sample_labels_data[i], 0) +=
              x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
              sample_grad_data[i];
        }
      }
    } else {
      std::vector<int64_t> labels;
338
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
339
        labels.push_back(sample_labels_data[i]);
W
wanghaoshuang 已提交
340
      }
341 342 343 344 345 346 347 348 349 350 351
      std::set<T> st(labels.begin(), labels.end());
      labels.assign(st.begin(), st.end());

      auto *table_var = context.InputVar("Weight");
      DDim table_dim;
      if (table_var->IsType<LoDTensor>()) {
        table_dim = context.Input<LoDTensor>("Weight")->dims();
      } else if (table_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Weight");
        table_dim = table_t->value().dims();
      } else {
F
Feiyu Chan 已提交
352
        PADDLE_THROW(platform::errors::InvalidArgument(
353
            "The parameter Weight of a NCE_OP "
F
Feiyu Chan 已提交
354
            "must be either LoDTensor or SelectedRows"));
355 356 357 358 359 360 361 362 363 364 365 366 367 368
      }

      auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));

      d_w->set_rows(labels);
      d_w->set_height(table_dim[0]);

      auto *d_table_value = d_w->mutable_value();
      d_table_value->Resize(
          {static_cast<int64_t>(labels.size()), table_dim[1]});
      auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);

      auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
W
wanghaoshuang 已提交
369
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
370
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
371
        d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
372
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
W
wanghaoshuang 已提交
373 374 375
            sample_grad_data[i];
      }
    }
376

W
wanghaoshuang 已提交
377
    // get d_x
W
wanghaoshuang 已提交
378
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
W
wanghaoshuang 已提交
379
    if (d_x != nullptr) {
380
      auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
Y
Yang Yu 已提交
381
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
W
wanghaoshuang 已提交
382
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
W
wanghaoshuang 已提交
383
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
384
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
385
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
W
wanghaoshuang 已提交
386 387 388
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
389

390
    delete sampler;
W
wanghaoshuang 已提交
391 392 393 394
  }
};
}  // namespace operators
}  // namespace paddle