array.h 3.7 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
S
sneaxiy 已提交
18 19
#include "paddle/fluid/framework/unroll_array_ops.h"
#include "paddle/fluid/platform/enforce.h"
X
Xin Pan 已提交
20 21 22

namespace paddle {
namespace framework {
S
sneaxiy 已提交
23

X
Xin Pan 已提交
24 25 26
// Fixed-size array holding N elements of type T in a plain in-object C array.
// Every member function is annotated with HOSTDEVICE.
template <typename T, size_t N>
class Array {
 public:
  // Element count, exposed as a compile-time constant.
  static constexpr size_t kSize = N;

  // Default constructor. The body is empty, so data_ is default-initialized.
  HOSTDEVICE inline Array() {}

  // Constructs the array from exactly N values. Any other argument count is
  // rejected at compile time by the static_assert. The values are forwarded,
  // together with data_, to UnrollVarArgsAssign<T>::Run (presumably stored in
  // argument order -- confirm in unroll_array_ops.h).
  template <typename... Args>
  HOSTDEVICE inline explicit Array(const T &val, Args... args) {
    static_assert(N == sizeof...(Args) + 1, "Invalid argument");
    UnrollVarArgsAssign<T>::Run(data_, val, args...);
  }

  // Passes data_ and `val` to UnrollFillConstant<N>::Run (presumably sets
  // every element to `val` -- confirm in unroll_array_ops.h).
  HOSTDEVICE inline void Fill(const T &val) {
    UnrollFillConstant<N>::Run(data_, val);
  }

  // Returns a read-only pointer to the first element.
  HOSTDEVICE inline const T *Get() const { return data_; }

  // Returns a writable pointer to the first element.
  HOSTDEVICE inline T *GetMutable() { return data_; }

  // Unchecked element access.
  //
  // Writing "return data_[i]" would cause the compilation warning/error
  // "array subscript is above array bound" in the Python 35 CI. It seems to
  // be a false warning of GCC when the index is not bounds-checked, so the
  // element is reached through the private advance() helper instead. For
  // better performance, operator[] performs no bounds check, like the STL.
  // Use at() when a bounds check is wanted.
  HOSTDEVICE inline T &operator[](size_t i) { return *advance(data_, i); }

  // Const overload of the unchecked element access above.
  HOSTDEVICE inline const T &operator[](size_t i) const {
    return *advance(data_, i);
  }

  // Bounds-checked element access. The check (i < N) is compiled only when
  // __CUDA_ARCH__ is not defined; otherwise this is identical to operator[].
  HOSTDEVICE inline T &at(size_t i) {
#ifndef __CUDA_ARCH__
    PADDLE_ENFORCE_LT(
        i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif
    return (*this)[i];
  }

  // Const overload of at(). Reports the same typed OutOfRange error as the
  // non-const overload so both code paths fail in the same way.
  HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__
    PADDLE_ENFORCE_LT(
        i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif
    return (*this)[i];
  }

  // Returns the number of elements, which is always N.
  HOSTDEVICE constexpr size_t size() const { return N; }

  // Equality is decided by UnrollCompare<N>::Run over the two storages
  // (presumably element-wise -- confirm in unroll_array_ops.h).
  HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
    return UnrollCompare<N>::Run(data_, other.data_);
  }

  // Negation of operator==.
  HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
    return !(*this == other);
  }

 private:
  // Returns ptr + i. Used by operator[] in place of a direct subscript; see
  // the comment on operator[] for the reason.
  template <typename U>
  HOSTDEVICE static inline U *advance(U *ptr, size_t i) {
    return ptr + i;
  }

  T data_[N];
};

S
sneaxiy 已提交
90 91 92 93 94
// Specialization for zero-length arrays. It declares no data member (so no
// zero-sized C array is needed) and mirrors the member functions of the
// primary template.
template <typename T>
class Array<T, 0> {
 public:
  // Element count, exposed as a compile-time constant.
  static constexpr size_t kSize = 0;

  // Default constructor; there is nothing to initialize.
  HOSTDEVICE inline Array() {}

  // No-op: there are no elements to fill. The parameter is left unnamed to
  // avoid unused-parameter warnings.
  HOSTDEVICE inline void Fill(const T &) {}

  // Always returns nullptr, since there is no storage.
  HOSTDEVICE inline constexpr T *Get() const { return nullptr; }

  // Add constexpr to GetMutable() cause warning in MAC
  HOSTDEVICE inline T *GetMutable() { return nullptr; }

  // Element access is never valid on an empty array.
  // When __CUDA_ARCH__ is defined, a reference to a function-local static
  // dummy object is returned; otherwise PADDLE_THROW is invoked.
  HOSTDEVICE inline T &operator[](size_t) {
#ifdef __CUDA_ARCH__
    // Brace-initialized on purpose: "static T obj();" would declare a
    // function rather than an object, and a function cannot bind to T&.
    static T obj{};
    return obj;
#else
    PADDLE_THROW("Array<T, 0> has no element");
#endif
  }

  // Const overload of operator[]; same behavior as the non-const overload.
  HOSTDEVICE inline const T &operator[](size_t) const {
#ifdef __CUDA_ARCH__
    // Brace-initialized for the same reason as in the non-const overload.
    static const T obj{};
    return obj;
#else
    PADDLE_THROW("Array<T, 0> has no element");
#endif
  }

  // Forwards to operator[], so it behaves exactly as described above.
  HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }

  // Const overload of at(); forwards to the const operator[].
  HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }

  // Returns the number of elements, which is always 0.
  HOSTDEVICE constexpr size_t size() const { return 0; }

  // Two empty arrays always compare equal.
  HOSTDEVICE constexpr bool operator==(const Array<T, 0> &) const {
    return true;
  }

  // Two empty arrays never compare unequal.
  HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &) const {
    return false;
  }
};

X
Xin Pan 已提交
137 138
}  // namespace framework
}  // namespace paddle