/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>

#include <algorithm>
#include <iterator>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
28

W
wanghaoshuang 已提交
29 30 31
namespace paddle {
namespace operators {

32
using Tensor = framework::Tensor;
33 34
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
35
using Sampler = math::Sampler;
36
using DDim = framework::DDim;
W
wanghaoshuang 已提交
37 38 39 40 41

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

Q
QI JUN 已提交
42
template <typename DeviceContext, typename T>
43 44
void PrepareSamples(const framework::ExecutionContext &context,
                    Sampler *sampler) {
W
wanghaoshuang 已提交
45
  auto label = context.Input<Tensor>("Label");
46
  const int64_t *label_data = label->data<int64_t>();
W
wanghaoshuang 已提交
47
  auto label_dims = label->dims();
48
  //  int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
49
  // for unitest
W
wanghaoshuang 已提交
50 51
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
W
wanghaoshuang 已提交
52 53 54

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
55
  int64_t *sample_labels_data =
W
wanghaoshuang 已提交
56
      sample_labels->mutable_data<int64_t>(context.GetPlace());
W
wanghaoshuang 已提交
57 58

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
W
wanghaoshuang 已提交
59
  int index = 0;
60
  for (int64_t i = 0; i < label_dims[0]; ++i) {
W
wanghaoshuang 已提交
61 62
    int j = 0;
    for (; j < num_label; ++j) {
W
wanghaoshuang 已提交
63
      sample_labels_data[index++] = label_data[i * num_label + j];
W
wanghaoshuang 已提交
64
    }
W
wanghaoshuang 已提交
65 66
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
W
wanghaoshuang 已提交
67 68 69 70
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
W
wanghaoshuang 已提交
71
        // TODO(wanghaoshuang): support more distribution sampling
72
        sample_labels_data[index++] = sampler->Sample();
W
wanghaoshuang 已提交
73
      }
W
wanghaoshuang 已提交
74 75 76 77
    }
  }
}

Q
QI JUN 已提交
78
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
79 80
class NCEKernel : public framework::OpKernel<T> {
 public:
81
  void Compute(const framework::ExecutionContext &context) const override {
82 83 84 85 86
    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_neg_samples = context.Attr<int>("num_neg_samples");

87
    Sampler *sampler;
88 89 90 91 92 93 94 95 96 97
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
98 99 100 101 102 103 104 105 106 107 108 109 110
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
111 112 113 114 115 116
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    PrepareSamples<DeviceContext, T>(context, sampler);
W
wanghaoshuang 已提交
117
    auto sample_labels = context.Output<Tensor>("SampleLabels");
118
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
119
    auto sample_out = context.Output<Tensor>("SampleLogits");
120
    T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
W
wanghaoshuang 已提交
121 122
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
123
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
124 125 126
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
127
    auto out = context.Output<Tensor>("Cost");
128
    T *out_data = out->mutable_data<T>(context.GetPlace());
129
    int64_t num_true_class = 1;
W
wanghaoshuang 已提交
130 131 132
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
133 134
    int64_t sampled_labels_num = sample_labels->dims()[1];
    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
135
    // forward bias
W
wanghaoshuang 已提交
136
    auto bias = context.Input<Tensor>("Bias");
W
wanghaoshuang 已提交
137
    if (bias != nullptr) {
138
      const T *bias_data = bias->data<T>();
139
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
140 141 142
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
143
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
144 145 146 147
        sample_out_data[i] = 0;
      }
    }
    // forward mul
W
wanghaoshuang 已提交
148
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
T
tangwei12 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204

    // for remote prefetch
    auto epmap = context.Attr<std::vector<std::string>>("epmap");

    if (!epmap.empty()) {
      // if epmap is not empty, then the parameter will be fetched from remote
      // parameter
      // server

      std::vector<int64_t> labels;
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        labels.push_back(sample_labels_data[i]);
      }
      std::set<T> st(labels.begin(), labels.end());
      labels.assign(st.begin(), st.end());

      auto &local_scope = context.scope().NewScope();
      auto height_sections = context.Attr<std::vector<int>>("height_sections");
      auto table_names = context.Attr<std::vector<std::string>>("table_names");

      framework::Variable *ids = local_scope.Var("Ids");
      framework::Variable *weight = local_scope.Var("Weight");

#ifdef PADDLE_WITH_DISTRIBUTE
      operators::distributed::prefetch("Ids", "Weight", table_names, epmap,
                                       height_sections, context);
#else
      PADDLE_THROW(
          "paddle is not compiled with distribute support, can not do "
          "parameter prefetch!");

      auto weight_mat = EigenMatrix<T>::From(*(weight->Get<T>()));
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        std::vector<int64_t>::iterator it =
            std::find(labels.begin(), labels.end(), sample_labels_data[i]);
        int idx = std::distance(labels.begin(), it);

        Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
            (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
             weight_mat.chip(idx, 0))
                .sum();
        sample_out_data[i] += result(0);
        sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
      }
#endif
    } else {
      auto weight_mat =
          EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
            (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
             weight_mat.chip(sample_labels_data[i], 0))
                .sum();
        sample_out_data[i] += result(0);
        sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
      }
W
wanghaoshuang 已提交
205
    }
T
tangwei12 已提交
206

W
wanghaoshuang 已提交
207
    // forward cost
208
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
W
wanghaoshuang 已提交
209 210
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
211 212 213 214 215
      for (int64_t j = 0; j < sampled_labels_num; ++j) {
        int64_t target = sample_labels_data[i * sampled_labels_num + j];
        T o = sample_out_data[i * sampled_labels_num + j];
        float b = sampler->Probability(target) * num_neg_samples;
        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
W
wanghaoshuang 已提交
216 217 218
        out_data[i] += w * cost;
      }
    }
219
    delete sampler;
W
wanghaoshuang 已提交
220 221 222
  }
};

Q
QI JUN 已提交
223
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
224 225
class NCEGradKernel : public framework::OpKernel<T> {
 public:
226
  void Compute(const framework::ExecutionContext &context) const override {
W
wanghaoshuang 已提交
227
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
228
    const T *d_out_data = d_out->data<T>();
W
wanghaoshuang 已提交
229 230
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
231
    const T *sample_out_data = sample_out->data<T>();
W
wanghaoshuang 已提交
232
    auto sample_labels = context.Input<Tensor>("SampleLabels");
233
    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
234
    auto sample_weight = context.Input<Tensor>("SampleWeight");
235
    const T *sample_weight_data = nullptr;
W
wanghaoshuang 已提交
236 237 238
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
239 240
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
241 242 243 244
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
245 246 247

    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
248
    Sampler *sampler;
249 250 251 252 253 254 255 256 257 258
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
259 260 261 262 263 264 265 266 267 268 269 270 271
        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");

        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);

        const float *probs_data = dist_probs->data<float>();
        const int *alias_data = dist_alias->data<int>();
        const float *alias_probs_data = dist_alias_probs->data<float>();
        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
                                          alias_data, alias_probs_data, seed);
272 273 274 275 276 277
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
278
    Tensor sample_grad;  // tmp tensor
279
    T *sample_grad_data =
W
wanghaoshuang 已提交
280 281
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
282
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
283 284 285
      int64_t label_idx = i % sample_labels->dims()[1];
      int64_t sample_idx = i / sample_labels->dims()[1];
      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
W
wanghaoshuang 已提交
286
      T o = sample_out_data[i];
287 288
      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
      sample_grad_data[i] = label_idx < num_true_class
W
wanghaoshuang 已提交
289 290
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
291
      sample_grad_data[i] *= d_out_data[sample_idx];
W
wanghaoshuang 已提交
292
    }
293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320

    bool is_sparse = context.Attr<bool>("is_sparse");

    if (!is_sparse) {
      // get d_bias
      auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
      if (d_bias != nullptr) {
        T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
        std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
        }
      }
      // get d_w
      auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
      if (d_w != nullptr) {
        auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
        std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
        auto d_w_matrix = EigenMatrix<T>::From(*d_w);
        auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
          d_w_matrix.chip(sample_labels_data[i], 0) +=
              x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
              sample_grad_data[i];
        }
      }
    } else {
      std::vector<int64_t> labels;
321
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
322
        labels.push_back(sample_labels_data[i]);
W
wanghaoshuang 已提交
323
      }
324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
      std::set<T> st(labels.begin(), labels.end());
      labels.assign(st.begin(), st.end());

      auto *bias_var = context.InputVar("Bias");
      DDim bias_dim;
      if (bias_var->IsType<LoDTensor>()) {
        bias_dim = context.Input<LoDTensor>("Bias")->dims();
      } else if (bias_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Bias");
        bias_dim = table_t->value().dims();
      } else {
        PADDLE_THROW(
            "The parameter Bias of a NCE_OP "
            "must be either LoDTensor or SelectedRows");
      }

      auto d_bias =
          context.Output<SelectedRows>(framework::GradVarName("Bias"));
      d_bias->set_rows(labels);
      d_bias->set_height(bias_dim[0]);

      d_bias->mutable_value()->Resize(
          {static_cast<int64_t>(labels.size()), bias_dim[1]});
      T *d_bias_data =
          d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[d_bias->Index(sample_labels_data[i])] +=
            sample_grad_data[i];
      }

      auto *table_var = context.InputVar("Weight");
      DDim table_dim;
      if (table_var->IsType<LoDTensor>()) {
        table_dim = context.Input<LoDTensor>("Weight")->dims();
      } else if (table_var->IsType<SelectedRows>()) {
        auto *table_t = context.Input<SelectedRows>("Weight");
        table_dim = table_t->value().dims();
      } else {
        PADDLE_THROW(
            "The parameter Weight of a NCE_OP "
            "must be either LoDTensor or SelectedRows");
      }

      auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));

      d_w->set_rows(labels);
      d_w->set_height(table_dim[0]);

      auto *d_table_value = d_w->mutable_value();
      d_table_value->Resize(
          {static_cast<int64_t>(labels.size()), table_dim[1]});
      auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);

      auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
W
wanghaoshuang 已提交
380
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
381
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
382
        d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
383
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
W
wanghaoshuang 已提交
384 385 386
            sample_grad_data[i];
      }
    }
387

W
wanghaoshuang 已提交
388
    // get d_x
W
wanghaoshuang 已提交
389
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
W
wanghaoshuang 已提交
390
    if (d_x != nullptr) {
391
      auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
Y
Yang Yu 已提交
392
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
W
wanghaoshuang 已提交
393
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
W
wanghaoshuang 已提交
394
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
395
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
396
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
W
wanghaoshuang 已提交
397 398 399
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
400

401
    delete sampler;
W
wanghaoshuang 已提交
402 403 404 405
  }
};
}  // namespace operators
}  // namespace paddle