// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#define EIGEN_USE_GPU

#include <array>

#include "paddle/phi/core/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace phi {
namespace funcs {

// Fixed-size array of `Size` elements of type T.
//
// The accessors and the scalar constructors are marked EIGEN_DEVICE_FUNC;
// the std::array constructor is not, so it is presumably host-only.
// Any element a constructor does not set explicitly is filled with
// `DefaultValue`.
template <typename T, int Size, T DefaultValue>
struct DeviceArray {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  // Fills every element with DefaultValue.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray() {
    for (int i = 0; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  // data[0] = a0; the remaining elements get DefaultValue.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0) {
    // The static_asserts below only fire if the constructor is actually used,
    // so smaller arrays remain valid as long as they avoid these overloads.
    static_assert(Size >= 1, "DeviceArray(a0) requires Size >= 1");
    data[0] = a0;
    for (int i = 1; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  // data[0..1] = {a0, a1}; the remaining elements get DefaultValue.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1) {
    static_assert(Size >= 2, "DeviceArray(a0, a1) requires Size >= 2");
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  // data[0..2] = {a0, a1, a2}; the remaining elements get DefaultValue.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1, T a2) {
    static_assert(Size >= 3, "DeviceArray(a0, a1, a2) requires Size >= 3");
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < Size; i++) {
      data[i] = DefaultValue;
    }
  }
  // Element-wise copy from a std::array of the same size (not EIGEN_DEVICE_FUNC).
  EIGEN_STRONG_INLINE DeviceArray(const std::array<T, Size>& sa) {
    for (int i = 0; i < Size; i++) {
      data[i] = sa[i];
    }
  }
  T data[Size];
};

// Three extents stored in a DeviceArray whose unset components default to 1.
// Used as the `dims` argument of FlatTensorIndex / ConvertTensorIndex.
struct Dim3 : DeviceArray<int, 3, 1> {
  using Base = DeviceArray<int, 3, 1>;

  // All three extents are 1.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3() : Base() {}

  // Explicit extents for axes 0, 1 and 2.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3(int d0, int d1, int d2)
      : Base(d0, d1, d2) {}

  // Copies the extents from a std::array; not marked EIGEN_DEVICE_FUNC.
  EIGEN_STRONG_INLINE Dim3(const std::array<int, 3>& dims) : Base(dims) {}
};

// Three coordinates stored in a DeviceArray whose unset components default
// to 0. Used as the `index` argument/result of FlatTensorIndex /
// ConvertTensorIndex.
struct Index3 : DeviceArray<int, 3, 0> {
  using Base = DeviceArray<int, 3, 0>;

  // All three coordinates are 0.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3() : Base() {}

  // Explicit coordinates for axes 0, 1 and 2.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3(int i0, int i1, int i2)
      : Base(i0, i1, i2) {}
};

// Flat index with real dimension
85 86 87 88
template <typename IndexType = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType
FlatTensorIndex(const Index3& index, const Dim3& dims) {
  IndexType flat_index = index[0];
L
limingshu 已提交
89 90
#pragma unroll
  for (int i = 1; i < 3; ++i) {
91 92 93 94 95 96
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// Convert index to tensor index with dimension.
97
template <typename IndexType = int>
98
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
99
ConvertTensorIndex(IndexType index, const Dim3& dims) {
100
  Index3 tensor_index;
L
limingshu 已提交
101 102
#pragma unroll
  for (int i = 2; i >= 0; --i) {
103
    IndexType new_index = index / dims[i];
104
    tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
105 106 107 108 109 110 111
    index = new_index;
  }
  return tensor_index;
}

template <typename IntType, bool ceil>
IntType CeilOrFloor(IntType x, IntType deviser) {
112
  PADDLE_ENFORCE_GT(
113 114
      deviser,
      0,
115 116 117
      phi::errors::InvalidArgument("deviser should be greater than 0, "
                                   "but received is:%d",
                                   deviser));
118 119

  PADDLE_ENFORCE_GT(
120 121
      x,
      0,
122 123 124
      phi::errors::InvalidArgument("input should be greater than 0, "
                                   "but received is:%d",
                                   x));
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144

  const IntType round_to_zero = x / deviser;
  const IntType inte_result = round_to_zero * deviser;

  if (ceil) {
    const bool do_adjustment =
        (round_to_zero >= 0) && (deviser > 0 && x > inte_result);
    const IntType adjustment = static_cast<IntType>(do_adjustment);
    const IntType ceil_val = round_to_zero + adjustment;
    return ceil_val;
  } else {
    const bool do_adjustment =
        (round_to_zero <= 0) && (deviser > 0 && x < inte_result);

    const IntType adjustment = static_cast<IntType>(do_adjustment);
    const IntType floor_val = round_to_zero - adjustment;
    return floor_val;
  }
}

}  // namespace funcs
}  // namespace phi