/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>

#include <memory>
#include <random>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sampler.h"

#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using Sampler = math::Sampler;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
35 36
void PrepareSamples(const framework::ExecutionContext& context,
                    Sampler* sampler) {
W
wanghaoshuang 已提交
37
  auto label = context.Input<Tensor>("Label");
W
wanghaoshuang 已提交
38
  const int64_t* label_data = label->data<int64_t>();
W
wanghaoshuang 已提交
39
  auto label_dims = label->dims();
40
  //  int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
41
  // for unitest
W
wanghaoshuang 已提交
42 43
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
W
wanghaoshuang 已提交
44 45 46

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
W
wanghaoshuang 已提交
47 48
  int64_t* sample_labels_data =
      sample_labels->mutable_data<int64_t>(context.GetPlace());
W
wanghaoshuang 已提交
49 50

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
W
wanghaoshuang 已提交
51
  int index = 0;
52
  for (int64_t i = 0; i < label_dims[0]; ++i) {
W
wanghaoshuang 已提交
53 54
    int j = 0;
    for (; j < num_label; ++j) {
W
wanghaoshuang 已提交
55
      sample_labels_data[index++] = label_data[i * num_label + j];
W
wanghaoshuang 已提交
56
    }
W
wanghaoshuang 已提交
57 58
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
W
wanghaoshuang 已提交
59 60 61 62
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
W
wanghaoshuang 已提交
63
        // TODO(wanghaoshuang): support more distribution sampling
64
        sample_labels_data[index++] = sampler->Sample();
W
wanghaoshuang 已提交
65
      }
W
wanghaoshuang 已提交
66 67 68 69
    }
  }
}

Q
QI JUN 已提交
70
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
71 72 73
class NCEKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_neg_samples = context.Attr<int>("num_neg_samples");

    Sampler* sampler;
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
        auto custom_dist = context.Input<Tensor>("CustomDistribution");
        const float* custom_dist_data = custom_dist->data<float>();
        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
        sampler = new math::CustomSampler(num_total_classes - 1,
                                          custom_dist_data, seed);
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    PrepareSamples<DeviceContext, T>(context, sampler);
W
wanghaoshuang 已提交
101
    auto sample_labels = context.Output<Tensor>("SampleLabels");
W
wanghaoshuang 已提交
102
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
103 104 105 106 107 108 109 110
    auto sample_out = context.Output<Tensor>("SampleLogits");
    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
111
    auto out = context.Output<Tensor>("Cost");
W
wanghaoshuang 已提交
112
    T* out_data = out->mutable_data<T>(context.GetPlace());
113
    int64_t num_true_class = 1;
W
wanghaoshuang 已提交
114 115 116
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
117 118
    int64_t sampled_labels_num = sample_labels->dims()[1];
    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
119
    // forward bias
W
wanghaoshuang 已提交
120
    auto bias = context.Input<Tensor>("Bias");
W
wanghaoshuang 已提交
121 122
    if (bias != nullptr) {
      const T* bias_data = bias->data<T>();
123
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
124 125 126
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
127
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
128 129 130 131
        sample_out_data[i] = 0;
      }
    }
    // forward mul
W
wanghaoshuang 已提交
132 133
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
134
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
135
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
136
          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
W
wanghaoshuang 已提交
137 138 139
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
W
wanghaoshuang 已提交
140
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
W
wanghaoshuang 已提交
141 142
    }
    // forward cost
143
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
W
wanghaoshuang 已提交
144 145
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
146 147 148 149 150
      for (int64_t j = 0; j < sampled_labels_num; ++j) {
        int64_t target = sample_labels_data[i * sampled_labels_num + j];
        T o = sample_out_data[i * sampled_labels_num + j];
        float b = sampler->Probability(target) * num_neg_samples;
        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
W
wanghaoshuang 已提交
151 152 153
        out_data[i] += w * cost;
      }
    }
154
    delete sampler;
W
wanghaoshuang 已提交
155 156 157
  }
};

Q
QI JUN 已提交
158
template <typename DeviceContext, typename T>
W
wanghaoshuang 已提交
159 160 161
class NCEGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
W
wanghaoshuang 已提交
162 163
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
    const T* d_out_data = d_out->data<T>();
W
wanghaoshuang 已提交
164 165 166 167
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
    const T* sample_out_data = sample_out->data<T>();
    auto sample_labels = context.Input<Tensor>("SampleLabels");
W
wanghaoshuang 已提交
168
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
W
wanghaoshuang 已提交
169 170 171 172 173
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
W
wanghaoshuang 已提交
174 175
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
W
wanghaoshuang 已提交
176 177 178 179
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204

    int sampler_type = context.Attr<int>("sampler");
    int seed = context.Attr<int>("seed");
    Sampler* sampler;
    switch (sampler_type) {
      case 0: {
        sampler = new math::UniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 1: {
        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
        break;
      }
      case 2: {
        auto custom_dist = context.Input<Tensor>("CustomDistribution");
        const float* custom_dist_data = custom_dist->data<float>();
        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
        sampler = new math::CustomSampler(num_total_classes - 1,
                                          custom_dist_data, seed);
        break;
      }
      default: { PADDLE_THROW("Unsupported SamplerType."); }
    }

    //    T b = 1. / num_total_classes * num_neg_samples;
W
wanghaoshuang 已提交
205 206 207 208
    Tensor sample_grad;  // tmp tensor
    T* sample_grad_data =
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
209
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
210 211 212
      int64_t label_idx = i % sample_labels->dims()[1];
      int64_t sample_idx = i / sample_labels->dims()[1];
      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
W
wanghaoshuang 已提交
213
      T o = sample_out_data[i];
214 215
      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
      sample_grad_data[i] = label_idx < num_true_class
W
wanghaoshuang 已提交
216 217
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
218
      sample_grad_data[i] *= d_out_data[sample_idx];
W
wanghaoshuang 已提交
219 220
    }
    // get d_bias
W
wanghaoshuang 已提交
221
    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
W
wanghaoshuang 已提交
222 223
    if (d_bias != nullptr) {
      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
W
wanghaoshuang 已提交
224
      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
225
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
226 227 228 229
        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
      }
    }
    // get d_w
W
wanghaoshuang 已提交
230
    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
W
wanghaoshuang 已提交
231
    if (d_w != nullptr) {
W
wanghaoshuang 已提交
232 233
      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
W
wanghaoshuang 已提交
234
      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
W
wanghaoshuang 已提交
235
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
236
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
W
wanghaoshuang 已提交
237
        d_w_matrix.chip(sample_labels_data[i], 0) +=
238
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
W
wanghaoshuang 已提交
239 240 241 242
            sample_grad_data[i];
      }
    }
    // get d_x
W
wanghaoshuang 已提交
243
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
W
wanghaoshuang 已提交
244
    if (d_x != nullptr) {
Y
Yang Yu 已提交
245 246
      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
W
wanghaoshuang 已提交
247
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
W
wanghaoshuang 已提交
248
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
249
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
250
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
W
wanghaoshuang 已提交
251 252 253
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
254
    delete sampler;
W
wanghaoshuang 已提交
255 256 257 258
  }
};
}  // namespace operators
}  // namespace paddle