/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>
#include <random>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

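// Fills the "SampleLabels" output. For each input row, the first num_label
// entries are copied from "Label" (the true classes); the remaining entries
// are negative classes, taken verbatim from the "custom_neg_classes"
// attribute when it is non-empty, otherwise drawn uniformly at random from
// [0, num_total_classes).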
template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext& context) {
  auto label = context.Input<Tensor>("Label");
  const int64_t* label_data = label->data<int64_t>();
  auto label_dims = label->dims();
  int num_total_classes = context.Attr<int>("num_total_classes");
  // fixed negative classes, used by unit tests to make sampling deterministic
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
  // random engine for drawing negative classes
  std::random_device rd;
  std::mt19937 rng(rd());
  std::uniform_int_distribution<int> rand(0, num_total_classes - 1);

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
  int64_t* sample_labels_data =
      sample_labels->mutable_data<int64_t>(context.GetPlace());

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
  int index = 0;
  for (int64_t i = 0; i < label_dims[0]; ++i) {
    int j = 0;
    for (; j < num_label; ++j) {
      sample_labels_data[index++] = label_data[i * num_label + j];
    }
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
        // TODO(wanghaoshuang): support more distribution sampling
        sample_labels_data[index++] = rand(rng);
      }
    }
  }
}

template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PrepareSamples<DeviceContext, T>(context);
    auto sample_labels = context.Output<Tensor>("SampleLabels");
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
    auto sample_out = context.Output<Tensor>("SampleLogits");
    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    auto out = context.Output<Tensor>("Cost");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int64_t num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
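    // Under a uniform noise distribution, each of the num_neg_samples draws
    // hits a given class with probability 1 / num_total_classes, so b is the
    // noise score that NCE assigns to every class.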
    T b = 1. / num_total_classes * num_neg_samples;
    // forward bias
    auto bias = context.Input<Tensor>("Bias");
    if (bias != nullptr) {
      const T* bias_data = bias->data<T>();
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = 0;
      }
    }
    // forward mul
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
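    // Each sampled logit is the dot product of an input row and the weight
    // row of its sampled label: chip(r, 0) selects row r, and the summed
    // elementwise product is the dot product for one sample.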
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
      // squash the logit through a sigmoid: o = 1 / (1 + exp(-z))
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
    }
    // forward cost
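    // NCE scores each sampled label by the probability that it came from the
    // data rather than from the noise distribution, P(data) = o / (o + b):
    // true classes are charged -log(o / (o + b)), sampled negatives are
    // charged -log(b / (o + b)), both weighted by the per-row sample weight.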
    for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
      int64_t j = 0;
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
      // for true classes
      for (; j < num_true_class; ++j) {
        T o = sample_out_data[i * sample_out->dims()[1] + j];
        T cost = -log(o / (o + b));
        out_data[i] += w * cost;
      }
      // for sampled neg classes
      for (; j < sample_labels->dims()[1]; ++j) {
        T o = sample_out_data[i * sample_out->dims()[1] + j];
        T cost = -log(b / (o + b));
        out_data[i] += w * cost;
      }
    }
  }
};

template <typename DeviceContext, typename T>
class NCEGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
    const T* d_out_data = d_out->data<T>();
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
    const T* sample_out_data = sample_out->data<T>();
    auto sample_labels = context.Input<Tensor>("SampleLabels");
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
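    // same uniform-noise score as in the forward pass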
    T b = 1. / num_total_classes * num_neg_samples;
    Tensor sample_grad;  // tmp tensor
    T* sample_grad_data =
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
    // backward cost
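    // Differentiate the NCE cost w.r.t. the pre-sigmoid logit z, where
    // o = sigmoid(z) and do/dz = o * (1 - o):
    //   true class:       d/dz[-log(o / (o + b))] = (b / (o + b)) * (o - 1)
    //   sampled negative: d/dz[-log(b / (o + b))] = o * (1 - o) / (o + b)
    // The result is then scaled by the incoming cost gradient of its row.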
    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
      T o = sample_out_data[i];
      T w = sample_weight == nullptr
                ? 1
                : sample_weight_data[i / sample_labels->dims()[1]];
      sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
    }
    // get d_bias
    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
    if (d_bias != nullptr) {
      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
      }
    }
    // get d_w
    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
    if (d_w != nullptr) {
      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
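      // scatter-add: the weight row of each sampled label accumulates the
      // corresponding input row scaled by that sample's logit gradient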
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_w_matrix.chip(sample_labels_data[i], 0) +=
            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
            sample_grad_data[i];
      }
    }
    // get d_x
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
    if (d_x != nullptr) {
      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
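      // each input row accumulates the weight rows of all labels sampled
      // for it, scaled by the per-sample logit gradients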
      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
        d_x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) +=
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle