/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

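// CPU forward kernel of softmax_with_cross_entropy.
// With use_softmax == true it computes p = softmax(z) along `axis` and the
// loss L = -sum_j y_j * log(p_j) (soft labels) or L = -log(p_label) (hard
// labels). With use_softmax == false the "Logits" input is treated as
// probabilities and only the cross-entropy part is evaluated.
// Worked example (hard label): z = (1, 2, 3) gives p ~= (0.090, 0.245, 0.665),
// so with label = 2 the loss is -log(0.665) ~= 0.41.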
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // use_softmax == false: the "Logits" input already holds softmax
    // probabilities, so only the cross-entropy part is computed below.
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      const Tensor* labels = context.Input<Tensor>("Label");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const bool soft_label = context.Attr<bool>("soft_label");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];

      PADDLE_ENFORCE_GT(
          axis_dim, 0,
          platform::errors::InvalidArgument(
              "The axis dimension should be larger than 0, but received "
              "axis dimension is %d.",
              axis_dim));

      softmax_out->mutable_data<T>(context.GetPlace());
      loss->mutable_data<T>(context.GetPlace());

      const int n = SizeToAxis(axis, softmax->dims());

      PADDLE_ENFORCE_GT(
          n, 0, platform::errors::InvalidArgument(
                    "The size of axis should be larger than 0, but received "
                    "SizeToAxis of softmax is %d.",
                    n));

      const int d = SizeFromAxis(axis, softmax->dims());
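      // n: product of the dims before `axis`; d: product of the dims from
      // `axis` onwards. The 2-D views below expose `axis` as the class
      // dimension, leaving d / axis_dim trailing positions per sample.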

      Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
      softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});

      auto& dev_ctx =
          context.template device_context<platform::CPUDeviceContext>();

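      // The loss is computed directly from the given probabilities:
      // L = -sum_j y_j * log(p_j) for soft labels, L = -log(p_label) for hard
      // labels; positions whose hard label equals ignore_index get zero loss.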
      math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          context.Attr<int>("ignore_index"), axis_dim);

      // The input already holds the softmax result, so copy it to the
      // "Softmax" output directly.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");
    const bool soft_label = context.Attr<bool>("soft_label");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];
    PADDLE_ENFORCE_GT(
        axis_dim, 0,
        platform::errors::InvalidArgument(
            "The axis dimension should be larger than 0, but received "
            "axis dimension is %d.",
            axis_dim));

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

    const int n = SizeToAxis(axis, logits->dims());
    PADDLE_ENFORCE_GT(
        n, 0, platform::errors::InvalidArgument(
                  "The size of axis should be larger than 0, but received "
                  "SizeToAxis of logits is %d.",
                  n));

    const int d = SizeFromAxis(axis, logits->dims());
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
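    // Forward pass: softmax along `axis` on the flattened 2-D views, then the
    // cross-entropy loss L = -sum_j y_j * log(softmax(z)_j).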

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
        context.Attr<int>("ignore_index"), axis_dim);
  }
};

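// CPU backward kernel of softmax_with_cross_entropy.
// With use_softmax == true the gradient w.r.t. the logits is
// dL/dz = (p - y) * dout, where p is the saved softmax output, y the label
// distribution and dout the incoming loss gradient. With use_softmax == false
// the stored probabilities are differentiated directly: dL/dp = -(y / p) * dout
// for soft labels, and for hard labels only the labeled class receives a
// non-zero gradient.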
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    const bool use_softmax = context.Attr<bool>("use_softmax");
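    // Seed logit_grad with the saved probabilities; the copy is skipped only
    // when logit_grad already aliases the "Softmax" tensor and softmax was
    // applied in the forward pass.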
    if (logit_grad != softmax || !use_softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    const bool soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];
    PADDLE_ENFORCE_GT(
        axis_dim, 0,
        platform::errors::InvalidArgument(
            "The axis dimension should be larger than 0, but received "
            "axis dimension is %d.",
            axis_dim));

    const int n = SizeToAxis(axis, logit_grad->dims());
    PADDLE_ENFORCE_GT(
        n, 0, platform::errors::InvalidArgument(
                  "The size of axis should be larger than 0, but received "
                  "SizeToAxis of logit_grad is %d.",
                  n));

    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
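    // logit_grad_2d is (n, d) while out_grad_2d is (n, d / axis_dim); the
    // broadcast over axis_dim below replicates the incoming gradient across
    // the class dimension.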
    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    if (!use_softmax) {
      // use_softmax == false: logit_grad currently holds the input
      // probabilities p (copied from "Softmax" above).
      if (soft_label) {
        auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
        logit_grad_mat.device(place) =
            (-lbl_mat / logit_grad_mat);  // dL/dp = -y / p for every sample
        logit_grad_mat.device(place) =
            out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
            logit_grad_mat;
      } else {
        // Hard labels: only the labeled class of each sample gets a non-zero
        // gradient, dL/dp_label = -dout / p_label.
        const int64_t* label_data = labels->data<int64_t>();
        T* logit_grad_data = logit_grad->data<T>();
        const T* out_grad_data = out_grad->data<T>();
        const int remain = d / axis_dim;
        for (int i = 0; i < n; ++i) {         // outer (sample) index
          for (int j = 0; j < remain; j++) {  // trailing positions after axis
            int idx = i * remain + j;  // index of this position's label; for
                                       // the 1-D case remain = 1, j = 0, idx = i
            if (label_data[idx] == ignore_index) {
              for (int k = 0; k < axis_dim; ++k) {  // zero out every class
                logit_grad_data[i * d + k * remain + j] = 0;
              }
            } else {
              // The hard label is 1 at label_data[idx] and 0 elsewhere, so
              // only the labeled class gets a non-zero gradient:
              // dL/dp_label = -dout / p_label.
              logit_grad_data[i * d + label_data[idx] * remain + j] =
                  (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
                  out_grad_data[idx];
              for (int k = 0; k < axis_dim; ++k) {  // zero the other classes
                if (k !=
                    label_data[idx]) {  // label_data[idx]: this position's label
                  logit_grad_data[i * d + k * remain + j] = 0;
                }
              }
            }
          }
        }
      }
      return;
    }
    // use_softmax == true from here on: logit_grad holds the softmax
    // probabilities p.

    if (soft_label) {
      // When soft_label is true, ignore_index is not supported.
      auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
      logit_grad_mat.device(place) =
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
          (logit_grad_mat - lbl_mat);  // dL/dz = (p - y) * dout per sample
      // 1) dL/dz_j = p_j - y_j (i.e. P - Y), where j is the class id,
      //    P = logit_grad_mat[i] holds sample i's class probabilities and
      //    Y = lbl_mat[i] holds its soft labels.
      // 2) Apply the chain rule: multiply by dout = out_grad_mat[i].
      // For higher-rank inputs, e.g. (n, c) or (n, d1, ..., dm, c), the same
      // matrix expression computes the gradient.

    } else {
      logit_grad_mat.device(place) =
          logit_grad_mat *  // element-wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const int64_t* label_data = labels->data<int64_t>();
      T* logit_grad_data = logit_grad->data<T>();
      const T* out_grad_data = out_grad->data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // outer (sample) index
        for (int j = 0; j < remain; j++) {  // trailing positions after axis
          int idx = i * remain + j;  // index of this position's label; for the
                                     // 1-D case remain = 1, j = 0, idx = i
          if (label_data[idx] == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // zero out every class
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // The hard label is 1 at label_data[idx] and 0 elsewhere, so only
            // the labeled class needs an update. With
            // idx_x = i * d + label_data[idx] * remain + j (which reduces to
            // i * d + label_data[idx] in the 1-D case), and since logit_grad
            // already holds p * dout from the broadcast above, the update
            //   logit_grad_data[idx_x] -= out_grad_data[idx]
            // yields (p - 1) * dout at the labeled class, i.e. (p - y) * dout.

            logit_grad_data[i * d + label_data[idx] * remain + j] -=
                out_grad_data[idx];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle