array.h 4.3 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
S
sneaxiy 已提交
18 19
#include "paddle/fluid/framework/unroll_array_ops.h"
#include "paddle/fluid/platform/enforce.h"
X
Xin Pan 已提交
20 21 22

namespace paddle {
namespace framework {
S
sneaxiy 已提交
23

X
Xin Pan 已提交
24 25 26
template <typename T, size_t N>
class Array {
 public:
  static constexpr size_t kSize = N;

  // The constructor body is empty, so `data_` is default-initialized:
  // for trivial T the element values are indeterminate until assigned.
  HOSTDEVICE inline Array() {}

  // Element-wise constructor. Exactly N values must be passed; the count is
  // verified at compile time by the static_assert below.
  template <typename... Args>
  HOSTDEVICE inline explicit Array(const T &first, Args... rest) {
    static_assert(sizeof...(Args) + 1 == kSize, "Invalid argument");
    UnrollVarArgsAssign<T>::Run(data_, first, rest...);
  }

  // Assigns `value` to every one of the N elements.
  HOSTDEVICE inline void Fill(const T &value) {
    UnrollFillConstant<N>::Run(data_, value);
  }

  // Raw pointer to the first element (read-only view).
  HOSTDEVICE inline const T *Get() const { return data_; }

  // Raw pointer to the first element (mutable view).
  HOSTDEVICE inline T *GetMutable() { return data_; }

  // Unchecked element access, mirroring the STL: no bounds check is done here
  // for performance; call at() when a checked access is wanted.
  // Elements are reached through Offset() rather than `data_[i]` because the
  // direct subscript produced a (false) GCC "array subscript is above array
  // bound" diagnostic in a CI build.
  HOSTDEVICE inline T &operator[](size_t i) { return *Offset(data_, i); }

  HOSTDEVICE inline const T &operator[](size_t i) const {
    return *Offset(data_, i);
  }

  // Element access with a bounds check (i < N) that raises an OutOfRange
  // error. The check is compiled out when __CUDA_ARCH__ or __HIPCC__ is
  // defined, in which case this behaves like operator[].
  HOSTDEVICE inline T &at(size_t i) {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
    PADDLE_ENFORCE_LT(
        i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif
    return *Offset(data_, i);
  }

  HOSTDEVICE inline const T &at(size_t i) const {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
    PADDLE_ENFORCE_LT(
        i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif
    return *Offset(data_, i);
  }

  // Number of elements; always the template argument N.
  HOSTDEVICE constexpr size_t size() const { return kSize; }

  // Element-wise equality over all N elements.
  HOSTDEVICE inline bool operator==(const Array<T, N> &rhs) const {
    return UnrollCompare<N>::Run(data_, rhs.data_);
  }

  HOSTDEVICE inline bool operator!=(const Array<T, N> &rhs) const {
    return !UnrollCompare<N>::Run(data_, rhs.data_);
  }

 private:
  // Plain pointer arithmetic (`base + offset`); exists only so indexing does
  // not use a direct array subscript (see the note on operator[]).
  template <typename U>
  HOSTDEVICE static inline U *Offset(U *base, size_t offset) {
    return base + offset;
  }

  T data_[N];
};

S
sneaxiy 已提交
91 92 93 94 95
template <typename T>
class Array<T, 0> {
 public:
  static constexpr size_t kSize = 0;

  HOSTDEVICE inline Array() {}

  // No-op: a zero-size array has no elements to fill.
  HOSTDEVICE inline void Fill(const T &) {}

  // Always nullptr: there is no storage.
  // NOTE(review): returns `T *` from a const method, whereas the primary
  // template returns `const T *`; kept as-is for caller compatibility.
  HOSTDEVICE inline constexpr T *Get() const { return nullptr; }

  // Add constexpr to GetMutable() cause warning in MAC
  HOSTDEVICE inline T *GetMutable() { return nullptr; }

  // There is no element to return. On host builds this raises an Unavailable
  // error. When __HIPCC__ or __CUDA_ARCH__ is defined, a reference to a
  // function-local, value-initialized dummy is returned instead (presumably
  // because PADDLE_THROW cannot be used there — confirm).
  // The dummy must be brace-initialized: `static T obj();` would be parsed as
  // a block-scope function declaration (most vexing parse), which is
  // ill-formed with `static` and cannot be returned as `T &`.
  HOSTDEVICE inline T &operator[](size_t) {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    static T obj{};
    return obj;
#else
    PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }

  // Const overload; same behavior as the non-const operator[] above.
  HOSTDEVICE inline const T &operator[](size_t) const {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    static const T obj{};
    return obj;
#else
    PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }

  // Forwards to operator[]; any index is out of range for a zero-size array.
  HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }

  HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }

  HOSTDEVICE constexpr size_t size() const { return 0; }

  // Two zero-size arrays are always equal.
  HOSTDEVICE constexpr bool operator==(const Array<T, 0> &) const {
    return true;
  }

  HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &) const {
    return false;
  }
};

X
Xin Pan 已提交
148 149
}  // namespace framework
}  // namespace paddle