/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>

#include <algorithm>
#include <iterator>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/operators/math/sampler.h"

#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace operators {

32
using Tensor = framework::Tensor;
33 34
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
35
using Sampler = math::Sampler;
36
using DDim = framework::DDim;
W
wanghaoshuang 已提交
37 38 39 40 41

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
43
void PrepareSamples(const framework::ExecutionContext &context,
P
pangyoki 已提交
44
                    Sampler *sampler, Tensor *sample_labels) {
W
wanghaoshuang 已提交
45
  auto label = context.Input<Tensor>("Label");
46
  const int64_t *label_data = label->data<int64_t>();
W
wanghaoshuang 已提交
47
  auto label_dims = label->dims();
W
wanghaoshuang 已提交
48
  // for unitest
W
wanghaoshuang 已提交
49 50
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
W
wanghaoshuang 已提交
51 52

  auto sample_labels_dims = sample_labels->dims();
53
  int64_t *sample_labels_data =
W
wanghaoshuang 已提交
54
      sample_labels->mutable_data<int64_t>(context.GetPlace());
W
wanghaoshuang 已提交
55 56

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
W
wanghaoshuang 已提交
57
  int index = 0;
58
  for (int64_t i = 0; i < label_dims[0]; ++i) {
W
wanghaoshuang 已提交
59 60
    int j = 0;
    for (; j < num_label; ++j) {
W
wanghaoshuang 已提交
61
      sample_labels_data[index++] = label_data[i * num_label + j];
W
wanghaoshuang 已提交
62
    }
W
wanghaoshuang 已提交
63 64
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
W
wanghaoshuang 已提交
65 66 67 68
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
W
wanghaoshuang 已提交
69
        // TODO(wanghaoshuang): support more distribution sampling
70
        sample_labels_data[index++] = sampler->Sample();
W
wanghaoshuang 已提交
71
      }
W
wanghaoshuang 已提交
72 73 74 75
    }
  }
}

template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
77 78
class NCEKernel : public framework::OpKernel<T> {
 public:
79
  void Compute(const framework::ExecutionContext &context) const override {
80 81 82 83
    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_neg_samples = context.Attr<int>("num_neg_samples");
P
pangyoki 已提交
84
    bool is_test = context.Attr<bool>("is_test");
85

86
    Sampler *sampler;
87 88 89 90 91 92 93 94 95 96
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
97 98 99 100
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

101 102
        PADDLE_ENFORCE_EQ(
            dist_probs->numel(), num_total_classes,
103 104 105 106 107 108
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_probs->numel(), num_total_classes));
109 110
        PADDLE_ENFORCE_EQ(
            dist_alias->numel(), num_total_classes,
111 112 113 114 115 116
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistAlias) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_alias->numel(), num_total_classes));
117 118
        PADDLE_ENFORCE_EQ(
            dist_alias_probs->numel(), num_total_classes,
119 120 121 122 123 124 125
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in "
                "Input(CustomDistAliasProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAliasProbs).numel() = %d, "
                "Attr(num_total_classes) = %d.",
                dist_alias_probs->numel(), num_total_classes));
126 127 128 129 130 131

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
132 133
        break;
      }
F
Feiyu Chan 已提交
134 135 136 137 138 139
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported SamplerType. SamplerType should be 0: Uniform, "
            "1: LogUniform or 2: CostumDist. Received SamplerType: %d",
            sampler_type));
      }
140 141
    }

P
pangyoki 已提交
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
    std::vector<int64_t> sample_out_dims;
    auto label = context.Input<Tensor>("Label");
    Tensor *sample_labels;
    Tensor *sample_out;
    Tensor sample_labels_tmp, sample_out_tmp;
    if (is_test) {
      // set dims of output(SampleOut)
      int num_true_classes = label->dims().size() == 2 ? label->dims()[1] : 1;
      sample_out_dims.push_back((context.Input<Tensor>("Input"))->dims()[0]);
      sample_out_dims.push_back(
          (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));

      sample_labels = &sample_labels_tmp;
      sample_labels->Resize(framework::make_ddim(sample_out_dims));

      sample_out = &sample_out_tmp;
      sample_out->Resize(framework::make_ddim(sample_out_dims));
    } else {
      sample_labels = context.Output<Tensor>("SampleLabels");
      sample_out = context.Output<Tensor>("SampleLogits");
    }

    PrepareSamples<DeviceContext, T>(context, sampler, sample_labels);
165
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
166 167

    for (int x = 0; x < sample_labels->numel(); x++) {
168
      PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
169 170 171 172 173
                        platform::errors::InvalidArgument(
                            "ValueError: Every sample label should be "
                            "non-negative. But received: "
                            "Input(SampleLabels)[%d] = %d",
                            x, sample_labels_data[x]));
174 175
    }

176
    T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
W
wanghaoshuang 已提交
177
    auto sample_weight = context.Input<Tensor>("SampleWeight");
178
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
179 180 181
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
182
    auto out = context.Output<Tensor>("Cost");
183
    T *out_data = out->mutable_data<T>(context.GetPlace());
184
    int64_t num_true_class = 1;
W
wanghaoshuang 已提交
185 186 187
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
188 189
    int64_t sampled_labels_num = sample_labels->dims()[1];
    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
190
    // forward bias
W
wanghaoshuang 已提交
191
    auto bias = context.Input<Tensor>("Bias");
W
wanghaoshuang 已提交
192
    if (bias != nullptr) {
193
      const T *bias_data = bias->data<T>();
194
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
195 196 197
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
198
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
199 200 201 202
        sample_out_data[i] = 0;
      }
    }
    // forward mul
W
wanghaoshuang 已提交
203
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
T
tangwei12 已提交
204

T
tangwei12 已提交
205 206 207 208 209 210 211 212
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
W
wanghaoshuang 已提交
213
    }
T
tangwei12 已提交
214

W
wanghaoshuang 已提交
215
    // forward cost
216
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
W
wanghaoshuang 已提交
217 218
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
219 220 221 222 223
      for (int64_t j = 0; j < sampled_labels_num; ++j) {
        int64_t target = sample_labels_data[i * sampled_labels_num + j];
        T o = sample_out_data[i * sampled_labels_num + j];
        float b = sampler->Probability(target) * num_neg_samples;
        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
W
wanghaoshuang 已提交
224 225 226
        out_data[i] += w * cost;
      }
    }
227
    delete sampler;
W
wanghaoshuang 已提交
228 229 230
  }
};

template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
232 233
class NCEGradKernel : public framework::OpKernel<T> {
 public:
234
  void Compute(const framework::ExecutionContext &context) const override {
W
wanghaoshuang 已提交
235
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
236
    const T *d_out_data = d_out->data<T>();
W
wanghaoshuang 已提交
237 238
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
239
    const T *sample_out_data = sample_out->data<T>();
W
wanghaoshuang 已提交
240
    auto sample_labels = context.Input<Tensor>("SampleLabels");
241
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
242
    auto sample_weight = context.Input<Tensor>("SampleWeight");
243
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
244 245 246
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
247 248
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
249 250 251 252
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
253 254 255

    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
256
    Sampler *sampler;
257 258 259 260 261 262 263 264 265 266
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
267 268 269 270
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

271 272
        PADDLE_ENFORCE_EQ(
            dist_probs->numel(), num_total_classes,
273 274 275 276 277 278
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_probs->numel(), num_total_classes));
279 280
        PADDLE_ENFORCE_EQ(
            dist_alias->numel(), num_total_classes,
281 282 283 284 285 286
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in Input(CustomDistAlias) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
                "= %d.",
                dist_alias->numel(), num_total_classes));
287 288
        PADDLE_ENFORCE_EQ(
            dist_alias_probs->numel(), num_total_classes,
289 290 291 292 293 294 295
            platform::errors::InvalidArgument(
                "ShapeError: The number of elements in "
                "Input(CustomDistAliasProbs) "
                "should be equal to the number of total classes. But Received: "
                "Input(CustomDistAliasProbs).numel() = %d, "
                "Attr(num_total_classes) = %d.",
                dist_alias_probs->numel(), num_total_classes));
296 297 298 299 300 301

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
302 303
        break;
      }
F
Feiyu Chan 已提交
304 305 306 307 308 309
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported SamplerType. SamplerType should be 0: Uniform, "
            "1: LogUniform or 2: CostumDist. Received SamplerType: %d",
            sampler_type));
      }
310 311 312
    }

    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
313
    Tensor sample_grad;  // tmp tensor
314
    T *sample_grad_data =
W
wanghaoshuang 已提交
315 316
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
317
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
318 319 320
      int64_t label_idx = i % sample_labels->dims()[1];
      int64_t sample_idx = i / sample_labels->dims()[1];
      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
W
wanghaoshuang 已提交
321
      T o = sample_out_data[i];
322 323
      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
      sample_grad_data[i] = label_idx < num_true_class
W
wanghaoshuang 已提交
324 325
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
326
      sample_grad_data[i] *= d_out_data[sample_idx];
W
wanghaoshuang 已提交
327
    }
328

329 330 331 332 333 334 335 336 337 338
    // get d_bias
    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
    if (d_bias != nullptr) {
      T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
      }
    }

339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
    bool is_sparse = context.Attr<bool>("is_sparse");

    if (!is_sparse) {
      // get d_w
      auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
      if (d_w != nullptr) {
        auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
        std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
        auto d_w_matrix = EigenMatrix<T>::From(*d_w);
        auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_w_matrix.chip(sample_labels_data[i], 0) +=
              x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
              sample_grad_data[i];
        }
      }
    } else {
      std::vector<int64_t> labels;
357
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
358
        labels.push_back(sample_labels_data[i]);
W
wanghaoshuang 已提交
359
      }
360 361 362 363 364 365 366 367 368 369 370
      std::set<T> st(labels.begin(), labels.end());
      labels.assign(st.begin(), st.end());

      auto *table_var = context.InputVar("Weight");
      DDim table_dim;
      if (table_var->IsType<LoDTensor>()) {
        table_dim = context.Input<LoDTensor>("Weight")->dims();
      } else if (table_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Weight");
        table_dim = table_t->value().dims();
      } else {
F
Feiyu Chan 已提交
371
        PADDLE_THROW(platform::errors::InvalidArgument(
372
            "The parameter Weight of a NCE_OP "
F
Feiyu Chan 已提交
373
            "must be either LoDTensor or SelectedRows"));
374 375 376 377 378 379 380 381 382 383 384 385 386 387
      }

      auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));

      d_w->set_rows(labels);
      d_w->set_height(table_dim[0]);

      auto *d_table_value = d_w->mutable_value();
      d_table_value->Resize(
          {static_cast<int64_t>(labels.size()), table_dim[1]});
      auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);

      auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
W
wanghaoshuang 已提交
388
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
389
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
390
        d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
391
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
W
wanghaoshuang 已提交
392 393 394
            sample_grad_data[i];
      }
    }
395

W
wanghaoshuang 已提交
396
    // get d_x
W
wanghaoshuang 已提交
397
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
W
wanghaoshuang 已提交
398
    if (d_x != nullptr) {
399
      auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
Y
Yang Yu 已提交
400
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
W
wanghaoshuang 已提交
401
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
W
wanghaoshuang 已提交
402
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
403
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
404
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
W
wanghaoshuang 已提交
405 406 407
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
408

409
    delete sampler;
W
wanghaoshuang 已提交
410 411 412 413
  }
};
}  // namespace operators
}  // namespace paddle