/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

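// CPU kernel for softmax_with_cross_entropy. It flattens "Logits" to a 2-D
// [n, d] view around `axis`, applies softmax along the class dimension
// (unless use_softmax is false and the input already holds probabilities),
// and then computes the soft- or hard-label cross-entropy loss.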
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // use_softmax == false: the "Logits" input already holds probabilities
    // (softmax was applied upstream), so only cross-entropy is computed here.
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      const Tensor* labels = context.Input<Tensor>("Label");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const bool soft_label = context.Attr<bool>("soft_label");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];

      softmax_out->mutable_data<T>(context.GetPlace());
      loss->mutable_data<T>(context.GetPlace());

      const int n = SizeToAxis(axis, softmax->dims());
      const int d = SizeFromAxis(axis, softmax->dims());

      Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
      softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});

      auto& dev_ctx =
          context.template device_context<platform::CPUDeviceContext>();

      math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          context.Attr<int>("ignore_index"), axis_dim);

      // The input already holds softmax probabilities, so copy it directly
      // to the "Softmax" output.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");
    const bool soft_label = context.Attr<bool>("soft_label");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

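    // View all tensors as 2-D [n, d] matrices: n collapses the dims before
    // `axis`, d collapses the dims from `axis` on, and the loss keeps one
    // value per sample and trailing dim, i.e. shape [n, d / axis_dim].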
    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
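    // Softmax along the class axis, then cross-entropy of the result against
    // "Label" (soft or hard labels; ignore_index applies to hard labels).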
    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
        context.Attr<int>("ignore_index"), axis_dim);
  }
};

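// CPU kernel for the softmax_with_cross_entropy gradient. Starting from the
// saved "Softmax" output and the upstream loss gradient, it produces
// d(loss)/d(Logits) for both soft and hard labels (or the gradient w.r.t.
// the input probabilities when use_softmax is false).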
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));

    const Tensor* softmax = context.Input<Tensor>("Softmax");
    const bool use_softmax = context.Attr<bool>("use_softmax");

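    // Seed logit_grad with the saved softmax output, unless the two are the
    // same tensor (in-place) and softmax is in use.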
    if (logit_grad != softmax || !use_softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }

    const bool soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});

    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
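    // use_softmax == false: the forward pass consumed probabilities directly,
    // so the gradient here is w.r.t. those probabilities rather than logits:
    // d(loss)/dp = -y / p (soft labels) or -1 / p at the labeled class (hard
    // labels), chained with the upstream loss gradient.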
    if (!use_softmax) {
      // soft-label case: d(loss)/dp = -y / p, chained with the upstream
      // gradient.
      if (soft_label) {
        auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
        // logit_grad currently holds the probabilities p.
        logit_grad_mat.device(place) =
            (-lbl_mat / logit_grad_mat);
        logit_grad_mat.device(place) =
            out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
            logit_grad_mat;
      }
      // hard-label case: only the labeled class has a non-zero gradient.
      else {
        const int64_t* label_data = labels->data<int64_t>();
        T* logit_grad_data = logit_grad->data<T>();
        const T* out_grad_data = out_grad->data<T>();
        const int remain = d / axis_dim;
        for (int i = 0; i < n; ++i) {         // for each sample
          for (int j = 0; j < remain; j++) {  // for each trailing dim
            int idx = i * remain + j;  // index of this sample's label; for the
                                       // 1-D case remain == 1, j == 0, so idx == i
            if (label_data[idx] == ignore_index) {
              for (int k = 0; k < axis_dim; ++k) {  // zero the whole class dim
                logit_grad_data[i * d + k * remain + j] = 0;
              }
            } else {
              // With hard labels only the labeled class has label 1 (all
              // others are 0), so only that class gets a non-zero gradient,
              // (-1 / p) * dy; the remaining classes are zeroed below.
              logit_grad_data[i * d + label_data[idx] * remain + j] =
                  (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
                  out_grad_data[idx];
              for (int k = 0; k < axis_dim; ++k) {  // zero all other classes
                if (k != label_data[idx]) {
                  logit_grad_data[i * d + k * remain + j] = 0;
                }
              }
            }
          }
        }
      }
      return;
    }

    // use_softmax == true from here on: logit_grad currently holds the
    // softmax probabilities produced by the forward pass.

    if (soft_label) {
      // When soft_label is true, ignore_index is not supported.
      auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
      logit_grad_mat.device(place) =
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
          (logit_grad_mat - lbl_mat);
      // 1) d(loss)/d(logit) = p_j - y_j (i.e. P - Y per sample), where j is
      //    the class id, P = logit_grad_mat[i] holds the probabilities and
      //    Y = lbl_mat[i] holds the soft labels of sample i.
      // 2) Chain rule: multiply by the upstream gradient dy = out_grad_mat[i].
      // For higher-rank inputs, e.g. (n, c) or (n, d1, ..., dm, c), the whole
      // gradient is computed in a single matrix expression.

    } else {
      logit_grad_mat.device(place) =
          logit_grad_mat *  // element-wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const int64_t* label_data = labels->data<int64_t>();
      T* logit_grad_data = logit_grad->data<T>();
      const T* out_grad_data = out_grad->data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // for each sample
        for (int j = 0; j < remain; j++) {  // for each trailing dim
          int idx = i * remain + j;  // index of this sample's label; for the
                                     // 1-D case remain == 1, j == 0, so idx == i
          if (label_data[idx] == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // zero the whole class dim
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // With hard labels only the labeled class has label 1 (all others
            // are 0), so only that class needs an update here.
            // Let idx_x = i * d + label_data[idx] * remain + j (for the 1-D
            // case this is simply i * d + label_data[idx]).
            // logit_grad_data was already multiplied by out_grad_mat above, so
            //   logit_grad_data[idx_x] -= out_grad_data[idx]
            // yields (p - 1) * dy at the labeled class, i.e. (p - y) * dy.

            logit_grad_data[i * d + label_data[idx] * remain + j] -=
                out_grad_data[idx];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle