// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>

#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/unroll_array_ops.h"
namespace pten {
namespace framework {
// Fixed-size array of N elements of type T that is usable from both host and
// CUDA/HIP device code (every member is marked HOSTDEVICE).
template <typename T, size_t N>
class Array {
 public:
  // Element count, available as a compile-time constant.
  static constexpr size_t kSize = N;

  // Does not initialize data_, so for trivial T the element values are
  // indeterminate until written.
  HOSTDEVICE inline Array() {}

  // Element-wise constructor: requires exactly N values in total (`val` plus
  // N - 1 further arguments), enforced at compile time. Storage is written by
  // UnrollVarArgsAssign<T> (presumably in argument order -- see
  // unroll_array_ops.h).
  template <typename... Args>
  HOSTDEVICE inline explicit Array(const T &val, Args... args) {
    static_assert(N == sizeof...(Args) + 1, "Invalid argument");
    UnrollVarArgsAssign<T>::Run(data_, val, args...);
  }

  // Writes `val` into the storage via UnrollFillConstant<N> (expected to
  // cover all N slots).
  HOSTDEVICE inline void Fill(const T &val) {
    UnrollFillConstant<N>::Run(data_, val);
  }

  // Raw pointers to the first element of the underlying storage.
  HOSTDEVICE inline const T *Get() const { return data_; }
  HOSTDEVICE inline T *GetMutable() { return data_; }

  // Unchecked element access, like std::array::operator[]; no bounds check is
  // performed for speed -- use at() when checking is wanted. Access goes
  // through advance() rather than `data_[i]` because the latter triggered a
  // GCC "array subscript is above array bound" diagnostic (believed to be a
  // false positive) in the "Python 35" CI job.
  HOSTDEVICE inline T &operator[](size_t i) { return *advance(data_, i); }
  HOSTDEVICE inline const T &operator[](size_t i) const {
    return *advance(data_, i);
  }

  // Bounds-checked element access. When compiled neither for CUDA device code
  // (__CUDA_ARCH__) nor by the HIP compiler (__HIPCC__), an index i >= N makes
  // PADDLE_ENFORCE_LT report an OutOfRange error; otherwise the check is
  // compiled out and this behaves exactly like operator[].
  HOSTDEVICE inline const T &at(size_t i) const {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
    PADDLE_ENFORCE_LT(
        i, N, pten::errors::OutOfRange("Array index out of bounds."));
#endif
    return (*this)[i];
  }

  // Mutable overload; shares the check above by delegating to the const one.
  HOSTDEVICE inline T &at(size_t i) {
    const Array &self = *this;
    return const_cast<T &>(self.at(i));
  }

  HOSTDEVICE constexpr size_t size() const { return kSize; }

  // Element-wise comparison, delegated to UnrollCompare<N>.
  HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
    return UnrollCompare<N>::Run(data_, other.data_);
  }

  HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
    return !UnrollCompare<N>::Run(data_, other.data_);
  }

 private:
  // Returns `ptr + i`; see the operator[] comment for why this indirection
  // exists instead of plain subscripting.
  template <typename U>
  HOSTDEVICE static inline U *advance(U *ptr, size_t i) {
    return ptr + i;
  }

  T data_[N];
};

// Specialization for a zero-length array: it owns no storage.
//  - Get()/GetMutable() return nullptr and Fill() does nothing.
//  - Element access has no valid element to return. In host-only builds it
//    reports an Unavailable error via PADDLE_THROW. Under CUDA device
//    compilation or the HIP compiler it instead returns a reference to a
//    value-initialized function-local static placeholder.
template <typename T>
class Array<T, 0> {
 public:
  static constexpr size_t kSize = 0;

  HOSTDEVICE inline Array() {}

  // Nothing to fill; the argument is intentionally unused.
  HOSTDEVICE inline void Fill(const T & /*val*/) {}

  HOSTDEVICE inline constexpr T *Get() const { return nullptr; }

  // Not constexpr: marking GetMutable() constexpr caused a warning on macOS.
  HOSTDEVICE inline T *GetMutable() { return nullptr; }

  HOSTDEVICE inline T &operator[](size_t) {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    // PADDLE_THROW is not used on these compilation paths, so a placeholder
    // object is returned instead. It must be brace-initialized: writing
    // "static T obj();" declares a block-scope function (the "most vexing
    // parse"). Compilers reject that with "function declared in block scope
    // cannot have 'static' storage class", and it could not bind to T& anyway.
    static T obj{};
    return obj;
#else
    PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }

  HOSTDEVICE inline const T &operator[](size_t) const {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    // See the non-const overload: brace-init avoids the most vexing parse.
    static const T obj{};
    return obj;
#else
    PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }

  // No bounds to check beyond what operator[] already does.
  HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }

  HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }

  HOSTDEVICE constexpr size_t size() const { return 0; }

  // All empty arrays of the same element type compare equal.
  HOSTDEVICE constexpr bool operator==(const Array<T, 0> & /*other*/) const {
    return true;
  }

  HOSTDEVICE constexpr bool operator!=(const Array<T, 0> & /*other*/) const {
    return false;
  }
};

}  // namespace framework
}  // namespace pten