/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace math {

28 29
template <typename T,
          int MajorType = Eigen::RowMajor,
30 31 32 33 34 35
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Element-wise functor that bounds its input from below at -64. It is applied
// to max-shifted logits before exp(), presumably to keep exp() away from
// underflow -- the exact motivation for -64 is not stated here.
template <typename T>
struct ValueClip {
  HOSTDEVICE T operator()(const T& x) const {
    const T lower_bound = static_cast<T>(-64.);
    // Only values strictly below the bound are replaced; anything for which
    // the comparison is false (including NaN) is returned unchanged.
    if (x < lower_bound) {
      return lower_bound;
    }
    return x;
  }
};

41
template <typename DeviceContext, typename T, bool is_test>
42 43
class SoftmaxEigen {
 public:
44 45 46 47
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* X,
                  framework::Tensor* Y) {
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;
    constexpr int kAxisDim = 1;

    auto logits = EigenMatrix<T>::From(*X);
    auto softmax = EigenMatrix<T>::From(*Y);

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_axis(kAxisDim);
    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);

    // For numerical stability, logits should be shifted by maximum number along
    // axis, calculate shifted_logits into softmax tensor for memory reuse.
    if (num_remain == 1) {
      // axis == -1, axis and class in same dimension, calculate along
      // class dimension directly for higher performance
73 74 75 76 77 78
      softmax.device(*context.eigen_device()) =
          (logits - logits.maximum(along_axis)
                        .eval()
                        .reshape(batch_by_one)
                        .broadcast(one_by_class))
              .unaryExpr(ValueClip<T>());
79 80 81 82
    } else {
      // axis != -1, class dimension split into (axis, remain), max and sum
      // should be calculated along axis dimension
      softmax.device(*context.eigen_device()) =
83 84 85 86 87 88
          (logits.reshape(batch_axis_remain) - logits.reshape(batch_axis_remain)
                                                   .maximum(along_axis)
                                                   .eval()
                                                   .reshape(batch_one_remain)
                                                   .broadcast(one_axis_one)
                                                   .reshape(batch_classes))
89 90 91 92
              .unaryExpr(ValueClip<T>());
    }

    softmax.device(*context.eigen_device()) = softmax.exp();
93
    softmax.device(*context.eigen_device()) =
94 95 96 97 98
        (softmax * softmax.reshape(batch_axis_remain)
                       .sum(along_axis)
                       .inverse()
                       .eval()
                       .broadcast(one_axis));
99
  }
100
};
101

102 103 104
template <typename DeviceContext, bool is_test>
class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
 public:
105 106 107 108
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* X,
                  framework::Tensor* Y) {
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;
    constexpr int kAxisDim = 1;

    auto logits = EigenMatrix<platform::float16>::From(*X);
    auto softmax = EigenMatrix<platform::float16>::From(*Y);

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_axis(kAxisDim);
    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);

    // For numerical stability, logits should be shifted by maximum number along
    // axis, calculate shifted_logits into softmax tensor for memory reuse.
    if (num_remain == 1) {
      // axis == -1, axis and class in same dimension, calculate along
      // class dimension directly for higher performance
      softmax.device(*context.eigen_device()) =
135 136 137
          (logits - logits.maximum(along_axis)
                        .reshape(batch_by_one)
                        .broadcast(one_by_class))
138 139 140 141 142
              .unaryExpr(ValueClip<platform::float16>());
    } else {
      // axis != -1, class dimension split into (axis, remain), max and sum
      // should be calculated along axis dimension
      softmax.device(*context.eigen_device()) =
143 144 145 146 147
          (logits.reshape(batch_axis_remain) - logits.reshape(batch_axis_remain)
                                                   .maximum(along_axis)
                                                   .reshape(batch_one_remain)
                                                   .broadcast(one_axis_one)
                                                   .reshape(batch_classes))
148 149 150 151 152
              .unaryExpr(ValueClip<platform::float16>());
    }

    softmax.device(*context.eigen_device()) = softmax.exp();
    softmax.device(*context.eigen_device()) =
153 154 155 156
        (softmax * softmax.reshape(batch_axis_remain)
                       .sum(along_axis)
                       .inverse()
                       .broadcast(one_axis));
157 158
  }
};
159

160 161 162
template <typename DeviceContext, bool is_test>
class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
 public:
163 164 165 166
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* X,
                  framework::Tensor* Y) {
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;
    constexpr int kAxisDim = 1;

    auto logits = EigenMatrix<platform::bfloat16>::From(*X);
    auto softmax = EigenMatrix<platform::bfloat16>::From(*Y);

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_axis(kAxisDim);
    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);

    // For numerical stability, logits should be shifted by maximum number along
    // axis, calculate shifted_logits into softmax tensor for memory reuse.
    if (num_remain == 1) {
      // axis == -1, axis and class in same dimension, calculate along
      // class dimension directly for higher performance
      softmax.device(*context.eigen_device()) =
193 194 195
          (logits - logits.maximum(along_axis)
                        .reshape(batch_by_one)
                        .broadcast(one_by_class))
196 197 198 199 200
              .unaryExpr(ValueClip<platform::bfloat16>());
    } else {
      // axis != -1, class dimension split into (axis, remain), max and sum
      // should be calculated along axis dimension
      softmax.device(*context.eigen_device()) =
201 202 203 204 205
          (logits.reshape(batch_axis_remain) - logits.reshape(batch_axis_remain)
                                                   .maximum(along_axis)
                                                   .reshape(batch_one_remain)
                                                   .broadcast(one_axis_one)
                                                   .reshape(batch_classes))
206 207 208 209 210
              .unaryExpr(ValueClip<platform::bfloat16>());
    }

    softmax.device(*context.eigen_device()) = softmax.exp();
    softmax.device(*context.eigen_device()) =
211 212 213 214
        (softmax * softmax.reshape(batch_axis_remain)
                       .sum(along_axis)
                       .inverse()
                       .broadcast(one_axis));
215 216 217
  }
};

218 219
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
220 221 222 223
    const DeviceContext& context,
    const int axis_dim,
    const framework::Tensor* X,
    framework::Tensor* Y) {
224
  SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
225 226
}

227 228
template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
L
Leo Chen 已提交
229
    std::is_same<DeviceContext, phi::CPUContext>::value>::type;
230

231 232 233
template <typename DeviceContext, typename T, bool is_test>
class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
 public:
234 235 236 237
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* X,
                  framework::Tensor* Y) {
238
    const auto& in_dims = X->dims();
239 240 241 242 243 244 245 246 247 248 249 250 251 252
    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;

    const int num_classes = in_dims[kClassDim];
    const int batch_size = in_dims[kBatchDim];
    const int num_remain = num_classes / axis_dim;

    if (num_remain == 1 && platform::MayIUse(platform::avx)) {
      const T* in_data = X->data<T>();
      T* out_data = Y->data<T>();
      for (int bs = 0; bs < batch_size; ++bs) {
        T max_val = *std::max_element(in_data, in_data + num_classes);
        max_val *= static_cast<T>(-1);
        vec_add_bias<T, platform::avx>(num_classes, max_val, in_data, out_data);
253 254
        vec_clip<T, platform::avx>(
            num_classes, static_cast<T>(-64), out_data, out_data);
255 256 257 258 259 260 261 262 263 264 265
        vec_exp<T>(num_classes, out_data, out_data);

        T sum = 0;
        vec_sum<T, platform::avx>(num_classes, out_data, &sum);
        sum = static_cast<T>(1) / sum;
        vec_scal<T, platform::avx>(num_classes, sum, out_data, out_data);

        in_data += num_classes;
        out_data += num_classes;
      }
    } else {
266
      SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
267 268 269 270
    }
  }
};

271
template <typename DeviceContext>
272
class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
273
 public:
274 275 276 277
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* X,
                  framework::Tensor* Y) {
Z
Zhang Jun 已提交
278
    const auto& in_dims = X->dims();
279 280
    const float* in_data = X->data<float>();
    float* out_data = Y->data<float>();
281 282
    const int kBatchDim = 0;
    const int kClassDim = 1;
283
    // 2D data. Batch x C
T
tensor-tang 已提交
284
    auto compute_softmax =
285
        jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
T
tensor-tang 已提交
286
            .At(in_dims[kClassDim]);
287 288 289 290
    compute_softmax(in_data,
                    out_data,
                    in_dims[kClassDim],
                    in_dims[kBatchDim],
291
                    in_dims[kClassDim] / axis_dim);
292 293 294 295
  }
};

template <typename DeviceContext, typename T>
296 297
class SoftmaxGradEigen {
 public:
298 299 300 301
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* y,
                  const framework::Tensor* y_grad,
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
                  framework::Tensor* x_grad) {
    auto softmax = EigenMatrix<T>::From(*y);
    auto softmax_grad = EigenMatrix<T>::From(*y_grad);
    auto logits_grad = EigenMatrix<T>::From(*x_grad);

    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;

    const int batch_size = softmax.dimension(kBatchDim);
    const int num_classes = softmax.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_class(kClassDim);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);

    auto dot = (softmax * softmax_grad)
                   .reshape(batch_axis_remain)
                   .sum(along_class)
                   .eval()
                   .broadcast(one_axis);
    logits_grad.device(*context.eigen_device()) =
        (softmax_grad - dot) * softmax;
  }
};

template <typename DeviceContext>
class SoftmaxGradEigen<DeviceContext, platform::float16> {
 public:
333 334 335 336
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* y,
                  const framework::Tensor* y_grad,
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
                  framework::Tensor* x_grad) {
    auto softmax = EigenMatrix<platform::float16>::From(*y);
    auto softmax_grad = EigenMatrix<platform::float16>::From(*y_grad);
    auto logits_grad = EigenMatrix<platform::float16>::From(*x_grad);

    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;

    const int batch_size = softmax.dimension(kBatchDim);
    const int num_classes = softmax.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_class(kClassDim);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);

355 356 357 358 359 360 361 362 363 364 365 366
    auto dot = (softmax * softmax_grad)
                   .reshape(batch_axis_remain)
                   .sum(along_class)
                   .broadcast(one_axis);
    logits_grad.device(*context.eigen_device()) =
        (softmax_grad - dot) * softmax;
  }
};

template <typename DeviceContext>
class SoftmaxGradEigen<DeviceContext, platform::bfloat16> {
 public:
367 368 369 370
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* y,
                  const framework::Tensor* y_grad,
371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
                  framework::Tensor* x_grad) {
    auto softmax = EigenMatrix<platform::bfloat16>::From(*y);
    auto softmax_grad = EigenMatrix<platform::bfloat16>::From(*y_grad);
    auto logits_grad = EigenMatrix<platform::bfloat16>::From(*x_grad);

    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;

    const int batch_size = softmax.dimension(kBatchDim);
    const int num_classes = softmax.dimension(kClassDim);
    const int num_remain = num_classes / axis_dim;

    Eigen::DSizes<int, 1> along_class(kClassDim);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
    Eigen::DSizes<int, 2> one_axis(1, axis_dim);

389 390 391 392 393 394 395 396
    auto dot = (softmax * softmax_grad)
                   .reshape(batch_axis_remain)
                   .sum(along_class)
                   .broadcast(one_axis);
    logits_grad.device(*context.eigen_device()) =
        (softmax_grad - dot) * softmax;
  }
};
397

398 399
template <typename DeviceContext, typename T, typename Enable>
void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()(
400 401 402 403
    const DeviceContext& context,
    const int axis_dim,
    const framework::Tensor* y,
    const framework::Tensor* y_grad,
404
    framework::Tensor* x_grad) {
405
  SoftmaxGradEigen<DeviceContext, T>()(context, axis_dim, y, y_grad, x_grad);
406 407 408 409 410
}

// CPU specialization of SoftmaxGradFunctor. When the softmax axis is the whole
// class dimension (num_remain == 1) and the CPU reports AVX support, each row
// computes dx = (dy - dot(dy, y)) * y with the cpu_vec helpers. Otherwise the
// Eigen implementation is used.
template <typename DeviceContext, typename T>
class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
 public:
  void operator()(const DeviceContext& context,
                  const int axis_dim,
                  const framework::Tensor* y,
                  const framework::Tensor* y_grad,
                  framework::Tensor* x_grad) {
    constexpr int kBatchDim = 0;
    constexpr int kClassDim = 1;

    const auto& out_dims = y->dims();
    const int num_classes = out_dims[kClassDim];
    const int batch_size = out_dims[kBatchDim];
    const int num_remain = num_classes / axis_dim;

    // MayIUse is only queried when the layout allows the row-wise path.
    const bool row_wise_avx =
        num_remain == 1 && platform::MayIUse(platform::avx);
    if (!row_wise_avx) {
      SoftmaxGradEigen<DeviceContext, T>()(
          context, axis_dim, y, y_grad, x_grad);
      return;
    }

    const T* y_row = y->data<T>();
    const T* dy_row = y_grad->data<T>();
    T* dx_row = x_grad->data<T>();
    for (int row = 0; row < batch_size; ++row) {
      T neg_dot;
      vec_mul_reduce<T, platform::avx>(num_classes, dy_row, y_row, &neg_dot);
      neg_dot *= static_cast<T>(-1);
      // dx = (dy - dot) * y
      vec_add_bias<T, platform::avx>(num_classes, neg_dot, dy_row, dx_row);
      vec_mul<T, platform::avx>(num_classes, y_row, dx_row, dx_row);

      y_row += num_classes;
      dy_row += num_classes;
      dx_row += num_classes;
    }
  }
};
}  // namespace math
}  // namespace operators
}  // namespace paddle