// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#define EIGEN_USE_GPU

#include <array>

#include "paddle/fluid/platform/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {

// Fixed-size array of `Size` elements of type `T`.
//
// Element access and the default/positional constructors are annotated with
// EIGEN_DEVICE_FUNC; the std::array constructor is not. Every slot that a
// constructor receives no value for is set to `DefaultValue`.
template <typename T, int Size, T DefaultValue>
struct DeviceArray {
  // Unchecked read access: `index` is used as-is, with no range validation.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  // Unchecked write access: `index` is used as-is, with no range validation.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }

  // Sets every element to DefaultValue.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray() {
    for (int i = 0; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }

  // Positional constructors: the leading elements take the given values and
  // the remaining ones take DefaultValue.
  //
  // Each constructor writes its arguments to fixed slots, so it is only valid
  // when Size is at least the number of arguments. The static_asserts sit in
  // the constructor bodies so they are evaluated only for a constructor that
  // is actually instantiated for a given Size.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0) {
    static_assert(Size >= 1,
                  "DeviceArray(a0) requires Size to be at least 1");
    data[0] = a0;
    for (int i = 1; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1) {
    static_assert(Size >= 2,
                  "DeviceArray(a0, a1) requires Size to be at least 2");
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1, T a2) {
    static_assert(Size >= 3,
                  "DeviceArray(a0, a1, a2) requires Size to be at least 3");
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }

  // Element-wise copy of a std::array of the same length. This constructor
  // carries no EIGEN_DEVICE_FUNC annotation.
  EIGEN_STRONG_INLINE DeviceArray(const std::array<T, Size>& sa) {
    for (int i = 0; i < Size; i++) {
      data[i] = sa[i];
    }
  }

  // Underlying storage, publicly accessible.
  T data[Size];
};

// Three-component extent built on DeviceArray<int, 3, 1>: a component that is
// not supplied takes the value 1.
struct Dim3 : DeviceArray<int, 3, 1> {
  using Base = DeviceArray<int, 3, 1>;

  // All three components are 1.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3() : Base() {}

  // Components supplied explicitly, in order.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}

  // Element-wise copy of a std::array; carries no EIGEN_DEVICE_FUNC
  // annotation.
  EIGEN_STRONG_INLINE Dim3(const std::array<int, 3>& array) : Base(array) {}
};

// Three-component index built on DeviceArray<int, 3, 0>: a component that is
// not supplied takes the value 0.
struct Index3 : DeviceArray<int, 3, 0> {
  using Base = DeviceArray<int, 3, 0>;

  // All three components are 0.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3() : Base() {}

  // Components supplied explicitly, in order.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// Flat index with real dimension
85 86 87 88
template <typename IDX_T = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
                                                            const Dim3& dims) {
  IDX_T flat_index = index[0];
89 90 91 92 93 94 95
  for (int i = 1; i < 3; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// Convert index to tensor index with dimension.
96
template <typename IDX_T = int>
97
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
98
ConvertTensorIndex(IDX_T index, const Dim3& dims) {
99 100
  Index3 tensor_index;
  for (int i = 2; i >= 0; i--) {
101 102
    IDX_T new_index = index / dims[i];
    tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
103 104 105 106 107 108 109
    index = new_index;
  }
  return tensor_index;
}

// Divides `x` by `deviser` and rounds the quotient up (ceil == true) or down
// (ceil == false).
//
// Both `x` and `deviser` must be strictly greater than 0. Each is checked with
// PADDLE_ENFORCE_GT, which is handed a platform::errors::InvalidArgument
// describing the offending value.
template <typename IntType, bool ceil>
IntType CeilOrFloor(IntType x, IntType deviser) {
  PADDLE_ENFORCE_GT(
      deviser,
      0,
      platform::errors::InvalidArgument("deviser should be greater than 0, "
                                        "but received is:%d",
                                        deviser));

  PADDLE_ENFORCE_GT(
      x,
      0,
      platform::errors::InvalidArgument("input should be greater than 0, "
                                        "but received is:%d",
                                        x));

  // NOTE(review): from here on both operands are treated as positive, which
  // relies on PADDLE_ENFORCE_GT not returning when its check fails -- confirm
  // against enforce.h.
  const IntType quotient = x / deviser;

  if (!ceil) {
    // With positive operands the truncated quotient is already the floor.
    return quotient;
  }

  // Round up by one exactly when the division left a remainder, i.e. when
  // quotient * deviser falls short of x.
  const bool has_remainder = x > quotient * deviser;
  return quotient + static_cast<IntType>(has_remainder);
}

}  // namespace framework
}  // namespace paddle