/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

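// Forward reference, per sample, with j indexing the classes along `axis`:
//   softmax:      p_j = exp(z_j) / sum_k exp(z_k)
//   soft labels:  loss = -sum_j y_j * log(p_j)
//   hard labels:  loss = -log(p_label); samples whose label equals
//                 ignore_index contribute zero loss
// When the attribute use_softmax is false, the "Logits" input is expected to
// already contain the probabilities p and only the cross-entropy part runs.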
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // use_softmax == false: the softmax op is skipped because the "Logits"
    // input already holds softmax probabilities.
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      const Tensor* labels = context.Input<Tensor>("Label");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const bool soft_label = context.Attr<bool>("soft_label");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];

      softmax_out->mutable_data<T>(context.GetPlace());
      loss->mutable_data<T>(context.GetPlace());

      const int n = SizeToAxis(axis, softmax->dims());
      const int d = SizeFromAxis(axis, softmax->dims());

      Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
      softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});

      auto& dev_ctx =
          context.template device_context<platform::CPUDeviceContext>();

      math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          context.Attr<int>("ignore_index"), axis_dim);

      // The input already holds softmax probabilities, so copy it to the
      // Softmax output directly.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");
    const bool soft_label = context.Attr<bool>("soft_label");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
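    // Shape sketch (illustrative): for logits of shape (N, D1, C) with the
    // class axis last, n = N * D1, d = C and axis_dim = C, so every row of
    // these 2-D views holds one sample's class scores.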

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
        context.Attr<int>("ignore_index"), axis_dim);
  }
};

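// Backward reference, per sample, with j indexing the classes along `axis`
// and dLoss the incoming gradient of "Loss":
//   use_softmax == true  (gradient w.r.t. the logits):
//     dz_j = (p_j - y_j) * dLoss
//   use_softmax == false (gradient w.r.t. the probabilities):
//     soft labels:  dp_j = -y_j / p_j * dLoss
//     hard labels:  dp_label = -(1 / p_label) * dLoss, all other classes 0
// For hard labels, y is the implicit one-hot encoding of "Label"; rows whose
// label equals ignore_index receive a zero gradient.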
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    const bool use_softmax = context.Attr<bool>("use_softmax");
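    // Every gradient formula below starts from the softmax probabilities, so
    // make sure logit_grad holds a copy of Softmax unless the two variables
    // already alias each other.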
    if (logit_grad != softmax || !use_softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    const bool soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
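    // d / axis_dim (named `remain` below) is the product of the dimensions
    // after the class axis; when the class axis is last it equals 1.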
    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    if (!use_softmax) {
      // use_softmax == false, case 1: soft labels
      if (soft_label) {
        auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
        logit_grad_mat.device(place) =
            (-lbl_mat / logit_grad_mat);  // dLoss/dp_j = -y_j / p_j per sample
        logit_grad_mat.device(place) =
            out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
            logit_grad_mat;
      }
      // use_softmax == false, case 2: hard (integer) labels
      else {
        const int64_t* label_data = labels->data<int64_t>();
        T* logit_grad_data = logit_grad->data<T>();
        const T* out_grad_data = out_grad->data<T>();
        const int remain = d / axis_dim;
        for (int i = 0; i < n; ++i) {         // each outer (pre-axis) index
          for (int j = 0; j < remain; j++) {  // each inner (post-axis) index
            int idx = i * remain + j;  // index into the label tensor; for the
                                       // 1-D case remain == 1, j == 0, idx == i
            if (label_data[idx] == ignore_index) {
              for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
                logit_grad_data[i * d + k * remain + j] = 0;
              }
            } else {
              // Only the ground-truth class has label 1 (all others are 0),
              // so only that class receives a non-zero gradient.
              logit_grad_data[i * d + label_data[idx] * remain + j] =
                  (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
                  out_grad_data[idx];
              for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
                if (k !=
                    label_data[idx]) {  // label_data[idx]: this sample's label
                  logit_grad_data[i * d + k * remain + j] = 0;
                }
              }
            }
          }
        }
      }
      return;
    }
    
    // From here on use_softmax == true: logit_grad currently holds the
    // softmax probabilities prepared above.

    if (soft_label) {
      // when soft_label = True, ignore_index is not supported
      auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
      logit_grad_mat.device(place) =
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
          (logit_grad_mat - lbl_mat);  // (p - y) per sample
      // 1) dLoss/dLogits_j = p_j - y_j, i.e. P - Y, where j is the class id,
      //    P = logit_grad_mat[i] holds the softmax probabilities and
      //    Y = lbl_mat[i] holds the soft labels of sample i.
      // 2) Apply the chain rule: multiply by dLoss = out_grad_mat[i].
      // For higher-rank inputs, e.g. (n, c) or (n, d1, ..., dm, c), the same
      // gradient is expressed as a single matrix operation.

    } else {
      logit_grad_mat.device(place) =
          logit_grad_mat *  // element-wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const int64_t* label_data = labels->data<int64_t>();
      T* logit_grad_data = logit_grad->data<T>();
      const T* out_grad_data = out_grad->data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // each outer (pre-axis) index
        for (int j = 0; j < remain; j++) {  // each inner (post-axis) index
          int idx = i * remain + j;  // index into the label tensor; for the
                                     // 1-D case remain == 1, j == 0, idx == i
          if (label_data[idx] == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // Only the ground-truth class has label 1 (all others are 0), so
            // only that class's entry needs the extra "- y" update.
            // For the 1-D case remain == 1 and j == 0, so the index
            // i * d + label_data[idx] * remain + j reduces to
            // i * d + label_data[idx].
            // Let idx_x = i * d + label_data[idx] * remain + j; the update is
            //   logit_grad_data[idx_x] -= out_grad_data[idx].
            // Because logit_grad_mat was already multiplied by out_grad_mat
            // above, this yields
            //   logit_grad_data[idx_x] = (p - 1) * out_grad_data[idx],
            // i.e. dLoss/dLogit * dLoss = (p - y) * dLoss with y = 1.

            logit_grad_data[i * d + label_data[idx] * remain + j] -=
                out_grad_data[idx];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
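
// Usage sketch (illustrative assumption, not part of this header): the CPU
// kernels above are typically registered in the op's .cc file along these
// lines:
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
//                          ops::SoftmaxWithCrossEntropyKernel<float>,
//                          ops::SoftmaxWithCrossEntropyKernel<double>);
//   REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
//                          ops::SoftmaxWithCrossEntropyGradKernel<float>,
//                          ops::SoftmaxWithCrossEntropyGradKernel<double>);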