/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

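// Forward CPU kernel: applies softmax along `axis` to the "Logits" input and
// computes cross entropy against "Label", writing the probabilities to
// "Softmax" and the per-sample loss to "Loss". When the attribute
// `use_softmax` is false, "Logits" is assumed to already hold probabilities
// and only the cross-entropy term is computed.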
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // use_softmax == false: the softmax op is skipped and the "Logits" input
    // is expected to already hold probabilities (i.e. a softmax output).
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      const Tensor* labels = context.Input<Tensor>("Label");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const bool soft_label = context.Attr<bool>("soft_label");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];

      softmax_out->mutable_data<T>(context.GetPlace());
      loss->mutable_data<T>(context.GetPlace());

      const int n = SizeToAxis(axis, softmax->dims());
      const int d = SizeFromAxis(axis, softmax->dims());

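      // Flatten to 2-D views: n rows (product of the dims before `axis`) by d
      // columns (product of the dims from `axis` on); axis_dim is the number
      // of classes.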
      Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
      softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});

      auto& dev_ctx =
          context.template device_context<platform::CPUDeviceContext>();

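      // Cross entropy is evaluated directly on the given probabilities.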
      math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          context.Attr<int>("ignore_index"), axis_dim);

      // The input already holds the softmax result, so copy it to the
      // "Softmax" output directly.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");
    const bool soft_label = context.Attr<bool>("soft_label");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

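    // Flatten the input to a 2-D view of shape {n, d}, where n is the product
    // of the dimensions before `axis` and d the product of the remaining
    // dimensions; axis_dim is the size of the class dimension.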
    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
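    // Softmax along the class dimension, followed by cross entropy against
    // the labels; both functors operate on the flattened 2-D views.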
    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
        context.Attr<int>("ignore_index"), axis_dim);
  }
};

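// Backward CPU kernel: computes the gradient of the loss with respect to
// "Logits". With use_softmax == true the gradient is dLoss * (softmax - y);
// with use_softmax == false the gradient is taken directly with respect to
// the probability input.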
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    const bool use_softmax = context.Attr<bool>("use_softmax");
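    // Initialize logit_grad with the saved "Softmax" output, unless the two
    // variables already share the same tensor (in-place) and use_softmax is
    // true, in which case the data is already in place.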
    if (logit_grad != softmax || !use_softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    const bool soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");
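    // Samples whose hard label equals ignore_index receive a zero gradient;
    // ignore_index is not used in the soft-label paths.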

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
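    // When use_softmax is false, logit_grad currently holds the forward
    // probabilities X (copied from "Softmax" above) and the gradient is taken
    // with respect to X directly.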
    if (!use_softmax) {
      // Soft labels: dX = dLoss * (-Y / X), evaluated element-wise.
      if (soft_label) {
        auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
        logit_grad_mat.device(place) =
            (-lbl_mat / logit_grad_mat);  // element-wise -Y / X
        logit_grad_mat.device(place) =
            out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
            logit_grad_mat;
      } else {
        // Hard labels: only the labeled class has a nonzero gradient,
        // dX[label] = -dLoss / X[label].
        const int64_t* label_data = labels->data<int64_t>();
        T* logit_grad_data = logit_grad->data<T>();
        const T* out_grad_data = out_grad->data<T>();
        const int remain = d / axis_dim;
        for (int i = 0; i < n; ++i) {         // over the flattened outer dims
          for (int j = 0; j < remain; j++) {  // over the dims after `axis`
            int idx = i * remain + j;  // index into the label tensor; in the
                                       // 1-D case remain == 1, j == 0, idx == i
            if (label_data[idx] == ignore_index) {
              for (int k = 0; k < axis_dim; ++k) {  // zero every class
                logit_grad_data[i * d + k * remain + j] = 0;
              }
            } else {
              // The one-hot label is 1 only at class label_data[idx], so only
              // that class has a nonzero gradient; the rest are zeroed below.
              logit_grad_data[i * d + label_data[idx] * remain + j] =
                  (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
                  out_grad_data[idx];
              for (int k = 0; k < axis_dim; ++k) {
                if (k != label_data[idx]) {  // zero every non-label class
                  logit_grad_data[i * d + k * remain + j] = 0;
                }
              }
            }
          }
        }
      }
      return;
    }
    // From here on use_softmax == true: logit_grad currently holds the softmax
    // probabilities copied above.

    if (soft_label) {
      // When soft_label is true, ignore_index is not supported.
      auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
      logit_grad_mat.device(place) =
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
          (logit_grad_mat - lbl_mat);
      // 1) dLoss/dLogits = P - Y, where P = logit_grad_mat holds the softmax
      //    probabilities and Y = lbl_mat holds the soft labels.
      // 2) Chain rule: dLogits = dLoss * (P - Y), with dLoss (out_grad_mat)
      //    broadcast over the class dimension, which also covers high-rank
      //    inputs such as (n, c) or (n, d1, ..., dm, c).

    } else {
      logit_grad_mat.device(place) =
          logit_grad_mat *  // element-wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const int64_t* label_data = labels->data<int64_t>();
      T* logit_grad_data = logit_grad->data<T>();
      const T* out_grad_data = out_grad->data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // over the flattened outer dims
        for (int j = 0; j < remain; j++) {  // over the dims after `axis`
          int idx = i * remain + j;  // index into the label tensor; in the
                                     // 1-D case remain == 1, j == 0, idx == i
          if (label_data[idx] == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // zero every class
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // The one-hot label is 1 only at class label_data[idx], so only
            // that entry needs an extra update. Let
            //   idx_x = i * d + label_data[idx] * remain + j
            // (in the 1-D case remain == 1 and j == 0, so idx_x reduces to
            // i * d + label_data[idx]). Since logit_grad_mat was already
            // multiplied by out_grad_mat above, subtracting out_grad_data[idx]
            // here yields
            //   logit_grad_data[idx_x] = (p - 1) * out_grad_data[idx],
            // i.e. the hard-label gradient (p - y) * dLoss.

            logit_grad_data[i * d + label_data[idx] * remain + j] -=
                out_grad_data[idx];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle