/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>
#include <random>
#include <set>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using Sampler = math::Sampler;
using DDim = framework::DDim;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

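// Fills the SampleLabels output: for each input row, copy its true labels
// first and then append the negative class ids, taken verbatim from the
// custom_neg_classes attribute when it is non-empty, otherwise drawn from
// the given sampler.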
template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext &context,
                    Sampler *sampler) {
  auto label = context.Input<Tensor>("Label");
  const int64_t *label_data = label->data<int64_t>();
  auto label_dims = label->dims();
  //  int num_total_classes = context.Attr<int>("num_total_classes");
  // custom_neg_classes is only used by unit tests to fix the negative samples.
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
  int64_t *sample_labels_data =
      sample_labels->mutable_data<int64_t>(context.GetPlace());

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
  int index = 0;
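  // Each row of SampleLabels is laid out as [true labels..., negatives...];
  // the loops below fill every row up to sample_labels_dims[1] entries.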
  for (int64_t i = 0; i < label_dims[0]; ++i) {
    int j = 0;
    for (; j < num_label; ++j) {
      sample_labels_data[index++] = label_data[i * num_label + j];
    }
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
        // TODO(wanghaoshuang): support more distribution sampling
        sample_labels_data[index++] = sampler->Sample();
      }
    }
  }
}

template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_neg_samples = context.Attr<int>("num_neg_samples");

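    // Build the negative sampler: 0 = uniform over the class ids,
    // 1 = log-uniform, 2 = a custom distribution given by the CustomDist*
    // inputs.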
    Sampler *sampler;
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
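        // The custom distribution is sampled through precomputed alias-method
        // tables; each of the three inputs must hold one entry per class.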
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    PrepareSamples<DeviceContext, T>(context, sampler);
    auto sample_labels = context.Output<Tensor>("SampleLabels");
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
    auto sample_out = context.Output<Tensor>("SampleLogits");
    T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T *sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    auto out = context.Output<Tensor>("Cost");
    T *out_data = out->mutable_data<T>(context.GetPlace());
    int64_t num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
    int64_t sampled_labels_num = sample_labels->dims()[1];
    //    T b = 1. / num_total_classes * num_neg_samples;
    // forward bias
    auto bias = context.Input<Tensor>("Bias");
    if (bias != nullptr) {
      const T *bias_data = bias->data<T>();
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = 0;
      }
    }
    // forward mul
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
    }
    // forward cost
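    // Per-example NCE cost, summed over the sampled columns:
    //   true class:      -log(o / (o + b))
    //   negative sample: -log(b / (o + b))
    // where o is the sigmoid output computed above and
    // b = num_neg_samples * P(target) under the sampler's distribution.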
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
      for (int64_t j = 0; j < sampled_labels_num; ++j) {
        int64_t target = sample_labels_data[i * sampled_labels_num + j];
        T o = sample_out_data[i * sampled_labels_num + j];
        float b = sampler->Probability(target) * num_neg_samples;
        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
        out_data[i] += w * cost;
      }
    }
    delete sampler;
  }
};

template <typename DeviceContext, typename T>
class NCEGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
    const T *d_out_data = d_out->data<T>();
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
    const T *sample_out_data = sample_out->data<T>();
    auto sample_labels = context.Input<Tensor>("SampleLabels");
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T *sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }

    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
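    // Rebuild the same sampler as the forward pass, so that
    // sampler->Probability() matches the distribution the samples came from.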
    Sampler *sampler;
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    //    T b = 1. / num_total_classes * num_neg_samples;
    Tensor sample_grad;  // tmp tensor
    T *sample_grad_data =
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
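    // Gradient of the cost w.r.t. the pre-sigmoid logit z, with o = sigmoid(z):
    //   true class:      dcost/dz = (b / (o + b)) * (o - 1)
    //   negative sample: dcost/dz = o * (1 - o) / (o + b)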
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      int64_t label_idx = i % sample_labels->dims()[1];
      int64_t sample_idx = i / sample_labels->dims()[1];
      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
      T o = sample_out_data[i];
      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
      sample_grad_data[i] = label_idx < num_true_class
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
      sample_grad_data[i] *= d_out_data[sample_idx];
    }

    bool is_sparse = context.Attr<bool>("is_sparse");

    if (!is_sparse) {
      // get d_bias
      auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
      if (d_bias != nullptr) {
        T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
        std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
        }
      }
      // get d_w
      auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
      if (d_w != nullptr) {
        auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
        std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
        auto d_w_matrix = EigenMatrix<T>::From(*d_w);
        auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_w_matrix.chip(sample_labels_data[i], 0) +=
              x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
              sample_grad_data[i];
        }
      }
    } else {
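      // Sparse path: emit the Bias and Weight gradients as SelectedRows that
      // hold only the rows touched by the sampled labels.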
      std::vector<int64_t> labels;
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        labels.push_back(sample_labels_data[i]);
      }
      // Deduplicate and sort the sampled label ids; they become the rows of
      // the SelectedRows gradients.
      std::set<int64_t> st(labels.begin(), labels.end());
      labels.assign(st.begin(), st.end());

      auto *bias_var = context.InputVar("Bias");
      DDim bias_dim;
      if (bias_var->IsType<LoDTensor>()) {
        bias_dim = context.Input<LoDTensor>("Bias")->dims();
      } else if (bias_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Bias");
        bias_dim = table_t->value().dims();
      } else {
        PADDLE_THROW(
            "The parameter Bias of a NCE_OP "
            "must be either LoDTensor or SelectedRows");
      }

      auto d_bias =
          context.Output<SelectedRows>(framework::GradVarName("Bias"));
      d_bias->set_rows(labels);
      d_bias->set_height(bias_dim[0]);

      d_bias->mutable_value()->Resize(
          {static_cast<int64_t>(labels.size()), bias_dim[1]});
      T *d_bias_data =
          d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[d_bias->Index(sample_labels_data[i])] +=
            sample_grad_data[i];
      }

      auto *table_var = context.InputVar("Weight");
      DDim table_dim;
      if (table_var->IsType<LoDTensor>()) {
        table_dim = context.Input<LoDTensor>("Weight")->dims();
      } else if (table_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Weight");
        table_dim = table_t->value().dims();
      } else {
        PADDLE_THROW(
            "The parameter Weight of a NCE_OP "
            "must be either LoDTensor or SelectedRows");
      }

      auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));

      d_w->set_rows(labels);
      d_w->set_height(table_dim[0]);

      auto *d_table_value = d_w->mutable_value();
      d_table_value->Resize(
          {static_cast<int64_t>(labels.size()), table_dim[1]});
      auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);

      auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
            sample_grad_data[i];
      }
    }

    // get d_x
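    // d_x[row] accumulates sample_grad * Weight[label] over all the labels
    // sampled for that row.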
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
    if (d_x != nullptr) {
      auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }

    delete sampler;
  }
};
}  // namespace operators
}  // namespace paddle