/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

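// Forward CPU kernel: computes Softmax = softmax(Logits) along `axis` and
// Loss = cross_entropy(Softmax, Label). When the attribute `use_softmax` is
// false, "Logits" is assumed to already contain probabilities and only the
// cross-entropy is evaluated.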
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // When use_softmax is false, the "Logits" input already holds
    // probabilities and no softmax is applied here.
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      const Tensor* labels = context.Input<Tensor>("Label");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const bool soft_label = context.Attr<bool>("soft_label");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];

      PADDLE_ENFORCE_GT(
          axis_dim, 0,
          platform::errors::InvalidArgument(
              "The axis dimention should be larger than 0, but received "
              "axis dimention is %d.",
              axis_dim));

      softmax_out->mutable_data<T>(context.GetPlace());
      loss->mutable_data<T>(context.GetPlace());

      const int n = SizeToAxis(axis, softmax->dims());

      PADDLE_ENFORCE_GT(
          n, 0, platform::errors::InvalidArgument(
                    "The size of axis should be larger than 0, but received "
                    "SizeToAxis of softmax is %d.",
                    n));

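      // Flatten to 2-D views: n is the product of the dimensions before
      // `axis` and d the product of the dimensions from `axis` onward, so
      // each row holds axis_dim classes times d / axis_dim inner positions.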
      const int d = SizeFromAxis(axis, softmax->dims());

      Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
      softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});

      auto& dev_ctx =
          context.template device_context<platform::CPUDeviceContext>();

      math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          context.Attr<int>("ignore_index"), axis_dim);

      // The input already holds the softmax result, so copy it directly to
      // the "Softmax" output.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");
    const bool soft_label = context.Attr<bool>("soft_label");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];
    PADDLE_ENFORCE_GT(
        axis_dim, 0,
        platform::errors::InvalidArgument(
            "The axis dimention should be larger than 0, but received "
            "axis dimention is %d.",
            axis_dim));

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

    const int n = SizeToAxis(axis, logits->dims());
    PADDLE_ENFORCE_GT(
        n, 0, platform::errors::InvalidArgument(
                  "The size of axis should be larger than 0, but received "
                  "SizeToAxis of logits is %d.",
                  n));

    const int d = SizeFromAxis(axis, logits->dims());
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
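    // Fused path: first compute the softmax along `axis`, then evaluate the
    // cross-entropy of the result against "Label".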
    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
        context.Attr<int>("ignore_index"), axis_dim);
  }
};

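// Backward CPU kernel. For the fused case (use_softmax == true) the gradient
// w.r.t. the logits is (softmax - label) * dLoss, for both soft and one-hot
// (hard) labels; when use_softmax is false the gradient is taken w.r.t. the
// input probabilities instead.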
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    const bool use_softmax = context.Attr<bool>("use_softmax");
    if (logit_grad != softmax || !use_softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    const bool soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];
    PADDLE_ENFORCE_GT(
        axis_dim, 0,
        platform::errors::InvalidArgument(
            "The axis dimention should be larger than 0, but received "
            "axis dimention is %d.",
            axis_dim));

    const int n = SizeToAxis(axis, logit_grad->dims());
    PADDLE_ENFORCE_GT(
        n, 0, platform::errors::InvalidArgument(
                  "The size of axis should be larger than 0, but received "
                  "SizeToAxis of logit_grad is %d.",
                  n));

    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
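    // When use_softmax is false, logit_grad currently holds the forward
    // probabilities (copied from "Softmax" above), so the gradient below is
    // taken with respect to those probabilities rather than the logits.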
    if (!use_softmax) {
      // Soft-label case: gradient of -sum(label * log(prob)) w.r.t. prob.
      if (soft_label) {
        auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
        logit_grad_mat.device(place) =
            (-lbl_mat / logit_grad_mat);  // d(loss)/d(prob) = -label / prob
        logit_grad_mat.device(place) =
            out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
            logit_grad_mat;
      } else {
        // Hard-label case: only the labeled class has a nonzero gradient.
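        // Index layout in the flattened [n, d] view: element (i, k, j) of a
        // logical [n, axis_dim, remain] tensor lives at i * d + k * remain + j.
        // For example (shapes chosen only for illustration), an input of shape
        // [2, 3, 4] with axis = 1 gives n = 2, d = 12, axis_dim = 3, remain = 4.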
        const int64_t* label_data = labels->data<int64_t>();
        T* logit_grad_data = logit_grad->data<T>();
        const T* out_grad_data = out_grad->data<T>();
        const int remain = d / axis_dim;
        for (int i = 0; i < n; ++i) {         // for each sample_1_dim
          for (int j = 0; j < remain; j++) {  // for each sample_other_dims
            int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                       // remain=1 and j=0, so, idx = i
            if (label_data[idx] == ignore_index) {
              for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
                logit_grad_data[i * d + k * remain + j] = 0;
              }
            } else {
              // The one-hot label is 1 only at class label_data[idx], so only
              // that class has a nonzero gradient: d(loss)/d(prob) = -1 / prob,
              // scaled by the upstream gradient.
              logit_grad_data[i * d + label_data[idx] * remain + j] =
                  (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
                  out_grad_data[idx];
              for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
                if (k !=
                    label_data[idx]) {  // label_data[idx]: this sample's label
                  logit_grad_data[i * d + k * remain + j] = 0;
                }
              }
            }
          }
        }
      }
      return;
    }
    // From here on use_softmax is true: compute the gradient w.r.t. the logits.

    if (soft_label) {
      // ignore_index is not supported when soft_label is true.
      auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
      logit_grad_mat.device(place) =
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
          (logit_grad_mat - lbl_mat);
      // 1) d(loss)/d(logit_j) = p_j - y_j, i.e. P - Y per sample, where
      //    P = logit_grad_mat holds the probabilities and Y = lbl_mat the
      //    soft labels;
      // 2) chain rule: multiply by the upstream gradient out_grad_mat.
      // For high-rank inputs, e.g. (n, c) or (n, d1, ..., dm, c), the whole
      // gradient is computed with these matrix operations.

    } else {
      logit_grad_mat.device(place) =
          logit_grad_mat *  // element-wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const int64_t* label_data = labels->data<int64_t>();
      T* logit_grad_data = logit_grad->data<T>();
      const T* out_grad_data = out_grad->data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // for each sample_1_dim
        for (int j = 0; j < remain; j++) {  // for each sample_other_dims
          int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                     // remain=1 and j=0, so, idx = i
          if (label_data[idx] == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // The one-hot label is 1 only at class label_data[idx]. Let
            // idx_x = i * d + label_data[idx] * remain + j (in the 1-D case
            // this is i * d + label_data[idx]). logit_grad_data[idx_x] already
            // holds p * dout from the element-wise multiply above, so
            // subtracting out_grad_data[idx] below gives
            //   logit_grad_data[idx_x] = (p - 1) * dout,
            // i.e. the softmax-with-cross-entropy gradient (p - y) * dout.

            logit_grad_data[i * d + label_data[idx] * remain + j] -=
                out_grad_data[idx];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle