/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/transpose_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

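// Moves `axis` to the last dimension of `x` and `out` by building the
// permutation `perm` (a swap of `axis` with the trailing dimension), so the
// 2-D softmax functors below can always operate on the trailing axis.
// `perm` is an output parameter: the caller reuses it to transpose the
// result back, which works because a swap permutation is its own inverse.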
template <typename DeviceContext, typename T>
static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
                                      Tensor* x_trans, Tensor* out_trans,
                                      const int axis, std::vector<int>* perm,
                                      const framework::ExecutionContext& ctx) {
  auto dim_x = x.dims();
  int rank = dim_x.size();

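  // `axis` already is the trailing dimension; no transpose is needed.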
  if (axis == -1 || axis == rank - 1) {
    *x_trans = x;
    *out_trans = out;
    return;
  }

  auto& dev_ctx = ctx.template device_context<DeviceContext>();
  std::vector<int> shape;
  // Build the swap permutation: `axis` goes to the end, the last dimension
  // takes its place, and every other dimension stays put.
  for (int i = 0; i < rank - 1; i++) {
    if (i == axis) {
      perm->push_back(rank - 1);
      shape.push_back(dim_x[rank - 1]);
    } else {
      perm->push_back(i);
      shape.push_back(dim_x[i]);
    }
  }
  perm->push_back(axis);
  shape.push_back(dim_x[axis]);

  x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
  out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
  TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, *perm);
  TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, *perm);
}

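// Forward kernel: computes softmax of `X` along the `axis` attribute and
// writes the result to `Out`.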
template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<Tensor>("X");
    auto* Out = context.Output<Tensor>("Out");
    const int axis = context.Attr<int>("axis");

    // allocate memory on device.
    Out->mutable_data<T>(context.GetPlace());

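    // Move `axis` to the trailing dimension (a no-op when it already is last).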
    Tensor X_trans, Out_trans;
    std::vector<int> perm;
    TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
                                         &perm, context);

    int rank = X->dims().size();
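    // Flatten all leading dimensions into rows; softmax runs along columns.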
    Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
    Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);

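    // The third template argument of SoftmaxFunctor appears to select an
    // inference-mode specialization when built with PADDLE_ON_INFERENCE.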
#ifdef PADDLE_ON_INFERENCE
    math::SoftmaxFunctor<DeviceContext, T, true>()(
        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#else
    math::SoftmaxFunctor<DeviceContext, T, false>()(
        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#endif

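    // Transpose the result back to the original layout; `perm` is its own
    // inverse, so the same permutation restores the input dimension order.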
    if (axis != -1 && axis != rank - 1) {
      auto& dev_ctx = context.template device_context<DeviceContext>();
      TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
    }
  }
};

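// Backward kernel: computes the softmax gradient. Note that this path
// reshapes directly over the trailing dimension and, unlike the forward
// kernel, does not apply the `axis` transpose.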
template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* Out = context.Input<Tensor>("Out");
    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));

    // allocate memory on device.
    dX->mutable_data<T>(context.GetPlace());

    int rank = Out->dims().size();
    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);

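    // Standard softmax gradient per row: dX = (dOut - sum(dOut * Out)) * Out,
    // where the sum runs over the softmax'd (trailing) dimension.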
    math::SoftmaxGradFunctor<DeviceContext, T>()(
        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
        &dX_2d);
  }
};

}  // namespace operators
}  // namespace paddle