diff --git a/paddle/testing/CMakeLists.txt b/paddle/testing/CMakeLists.txt index 4208132b98051858d3f49ecd265b9c459a94d31e..fe288ec2bf1d14a5b628ab69b5bdacfc2c429c51 100644 --- a/paddle/testing/CMakeLists.txt +++ b/paddle/testing/CMakeLists.txt @@ -4,3 +4,4 @@ if(WITH_TESTING) cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init device_context memory gtest gflags) endif() cc_test(small_vector_test SRCS small_vector_test.cc DEPS gtest gflags) +cc_test(array_ref_test SRCS array_ref_test.cc DEPS gtest gflags) diff --git a/paddle/testing/array_ref_test.cc b/paddle/testing/array_ref_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33a09c499246d123d43fea36ef1e0c0faa841236 --- /dev/null +++ b/paddle/testing/array_ref_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/utils/array_ref.h" + +#include +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(array_ref, array_ref) { + paddle::ArrayRef a; + CHECK_EQ(a.size(), size_t(0)); + CHECK_EQ(a.data(), static_cast(nullptr)); + + paddle::ArrayRef b(paddle::none); + CHECK_EQ(b.size(), size_t(0)); + CHECK_EQ(b.data(), static_cast(nullptr)); + + int v = 1; + paddle::ArrayRef c(v); + CHECK_EQ(c.size(), size_t(1)); + CHECK_EQ(c.data(), &v); + CHECK_EQ(c.equals(paddle::makeArrayRef(v)), true); + + int v1[5] = {1, 2, 3, 4, 5}; + paddle::ArrayRef d(v1, 5); + CHECK_EQ(d.size(), size_t(5)); + CHECK_EQ(d.data(), v1); + CHECK_EQ(d.equals(paddle::makeArrayRef(v1, 5)), true); + + paddle::ArrayRef e(&v1[0], &v1[4]); + CHECK_EQ(e.size(), size_t(4)); + CHECK_EQ(e.data(), v1); + CHECK_EQ(e.equals(paddle::makeArrayRef(&v1[0], &v1[4])), true); + + paddle::SmallVector small_vector{1, 2, 3}; + paddle::ArrayRef f(small_vector); + CHECK_EQ(f.size(), size_t(3)); + CHECK_EQ(f.data(), small_vector.data()); + CHECK_EQ(f.equals(paddle::makeArrayRef(small_vector)), true); + + std::vector vector{1, 2, 3}; + paddle::ArrayRef g(vector); + CHECK_EQ(g.size(), size_t(3)); + CHECK_EQ(g.data(), vector.data()); + CHECK_EQ(g.equals(paddle::makeArrayRef(vector)), true); + + std::initializer_list list = {1, 2, 3}; + paddle::ArrayRef h(list); + CHECK_EQ(h.size(), size_t(3)); + CHECK_EQ(h.data(), list.begin()); + + paddle::ArrayRef i(h); + CHECK_EQ(i.size(), size_t(3)); + CHECK_EQ(i.data(), list.begin()); + CHECK_EQ(i.equals(h), true); + CHECK_EQ(i.equals(paddle::makeArrayRef(h)), true); + + auto slice = i.slice(1, 2); + CHECK_EQ(slice.size(), size_t(2)); + CHECK_EQ(slice[0], 2); + CHECK_EQ(slice[1], 3); + + auto drop = i.drop_front(2); + CHECK_EQ(drop.size(), size_t(1)); + CHECK_EQ(drop[0], 3); + + paddle::ArrayRef nums = {1, 2, 3, 4, 5, 6, 7, 8}; + auto front = nums.take_front(3); + CHECK_EQ(front.size(), size_t(3)); + for (size_t i = 0; i < 3; ++i) { + CHECK_EQ(front[i], nums[i]); + } + auto back = nums.take_back(3); + CHECK_EQ(back.size(), size_t(3)); + for (size_t i = 0; i < 3; ++i) { + CHECK_EQ(back[i], nums[i + 5]); + } +} diff --git a/paddle/utils/array_ref.h b/paddle/utils/array_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..9b39e9775f97a305c2a7479162cbd4b53835aec7 --- /dev/null +++ b/paddle/utils/array_ref.h @@ -0,0 +1,337 @@ +// This file copy from llvm/ADT/ArrayRef.h, version: 12.0.0 +// Modified the following points +// 1. remove hash_value functions +// 2. replace with the llvm::NoneType with paddle::none_t +// 3. remove drop_while, drop_until, take_while, take_until methods + +//===- ArrayRef.h - Array Reference Wrapper ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PADDLE_UTILS_ARRAY_REF_H_ +#define PADDLE_UTILS_ARRAY_REF_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/utils/none.h" +#include "paddle/utils/small_vector.h" + +namespace paddle { + +/// ArrayRef - Represent a constant reference to an array (0 or more elements +/// consecutively in memory), i.e. a start pointer and a length. It allows +/// various APIs to take consecutive elements easily and conveniently. +/// +/// This class does not own the underlying data, it is expected to be used in +/// situations where the data resides in some other buffer, whose lifetime +/// extends past that of the ArrayRef. For this reason, it is not in general +/// safe to store an ArrayRef. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +template +class ArrayRef { + public: + using iterator = const T *; + using const_iterator = const T *; + using size_type = size_t; + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + const T *Data = nullptr; + + /// The number of elements. + size_type Length = 0; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty ArrayRef. + /*implicit*/ ArrayRef() = default; + + /// Construct an empty ArrayRef from None. + /*implicit*/ ArrayRef(none_t) {} + + /// Construct an ArrayRef from a single element. + /*implicit*/ ArrayRef(const T &OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an ArrayRef from a pointer and length. + /*implicit*/ ArrayRef(const T *data, size_t length) + : Data(data), Length(length) {} + + /// Construct an ArrayRef from a range. + ArrayRef(const T *begin, const T *end) : Data(begin), Length(end - begin) {} + + /// Construct an ArrayRef from a SmallVector. This is templated in order to + /// avoid instantiating SmallVectorTemplateCommon whenever we + /// copy-construct an ArrayRef. + template + /*implicit*/ ArrayRef(const SmallVectorTemplateCommon &Vec) + : Data(Vec.data()), Length(Vec.size()) {} + + /// Construct an ArrayRef from a std::vector. + template + /*implicit*/ ArrayRef(const std::vector &Vec) + : Data(Vec.data()), Length(Vec.size()) {} + + /// Construct an ArrayRef from a std::array + template + /*implicit*/ constexpr ArrayRef(const std::array &Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct an ArrayRef from a C array. + template + /*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} + + /// Construct an ArrayRef from a std::initializer_list. + /*implicit*/ ArrayRef(const std::initializer_list &Vec) + : Data(Vec.begin() == Vec.end() ? (T *)nullptr : Vec.begin()), + Length(Vec.size()) {} + + /// Construct an ArrayRef from ArrayRef. This uses SFINAE to + /// ensure that only ArrayRefs of pointers can be converted. + template + ArrayRef(const ArrayRef &A, + std::enable_if_t::value> + * = nullptr) + : Data(A.data()), Length(A.size()) {} + + /// Construct an ArrayRef from a SmallVector. This is + /// templated in order to avoid instantiating SmallVectorTemplateCommon + /// whenever we copy-construct an ArrayRef. + template + /*implicit*/ ArrayRef( + const SmallVectorTemplateCommon &Vec, + std::enable_if_t::value> * = + nullptr) + : Data(Vec.data()), Length(Vec.size()) {} + + /// Construct an ArrayRef from std::vector. This uses SFINAE + /// to ensure that only vectors of pointers can be converted. + template + ArrayRef( + const std::vector &Vec, + std::enable_if_t::value> * = 0) + : Data(Vec.data()), Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + iterator begin() const { return Data; } + iterator end() const { return Data + Length; } + + reverse_iterator rbegin() const { return reverse_iterator(end()); } + reverse_iterator rend() const { return reverse_iterator(begin()); } + + /// empty - Check if the array is empty. + bool empty() const { return Length == 0; } + + const T *data() const { return Data; } + + /// size - Get the array size. + size_t size() const { return Length; } + + /// front - Get the first element. + const T &front() const { + assert(!empty()); + return Data[0]; + } + + /// back - Get the last element. + const T &back() const { + assert(!empty()); + return Data[Length - 1]; + } + + // copy - Allocate copy in Allocator and return ArrayRef to it. + template + ArrayRef copy(Allocator &A) { + T *Buff = A.template Allocate(Length); + std::uninitialized_copy(begin(), end(), Buff); + return ArrayRef(Buff, Length); + } + + /// equals - Check for element-wise equality. + bool equals(ArrayRef RHS) const { + if (Length != RHS.Length) return false; + return std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Chop off the first N elements of the array, and keep M + /// elements in the array. + ArrayRef slice(size_t N, size_t M) const { + assert(N + M <= size() && "Invalid specifier"); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + ArrayRef slice(size_t N) const { return slice(N, size() - N); } + + /// Drop the first \p N elements of the array. + ArrayRef drop_front(size_t N = 1) const { + assert(size() >= N && "Dropping more elements than exist"); + return slice(N, size() - N); + } + + /// Drop the last \p N elements of the array. + ArrayRef drop_back(size_t N = 1) const { + assert(size() >= N && "Dropping more elements than exist"); + return slice(0, size() - N); + } + + /// Return a copy of *this with only the first \p N elements. + ArrayRef take_front(size_t N = 1) const { + if (N >= size()) return *this; + return drop_back(size() - N); + } + + /// Return a copy of *this with only the last \p N elements. + ArrayRef take_back(size_t N = 1) const { + if (N >= size()) return *this; + return drop_front(size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + const T &operator[](size_t Index) const { + assert(Index < Length && "Invalid index!"); + return Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t::value, ArrayRef> &operator=( + U &&Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t::value, ArrayRef> &operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { return std::vector(Data, Data + Length); } + + /// @} + /// @name Conversion operators + /// @{ + operator std::vector() const { + return std::vector(Data, Data + Length); + } + + /// @} +}; + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template +ArrayRef makeArrayRef(const T &OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template +ArrayRef makeArrayRef(const T *data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template +ArrayRef makeArrayRef(const T *begin, const T *end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVectorImpl &Vec) { + return Vec; +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVector &Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::vector. +template +ArrayRef makeArrayRef(const std::vector &Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::array. +template +ArrayRef makeArrayRef(const std::array &Arr) { + return Arr; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template +ArrayRef makeArrayRef(const ArrayRef &Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template +ArrayRef &makeArrayRef(ArrayRef &Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template +ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +/// @} +/// @name ArrayRef Comparison Operators +/// @{ + +template +inline bool operator==(ArrayRef LHS, ArrayRef RHS) { + return LHS.equals(RHS); +} + +template +inline bool operator==(SmallVectorImpl &LHS, ArrayRef RHS) { + return ArrayRef(LHS).equals(RHS); +} + +template +inline bool operator!=(ArrayRef LHS, ArrayRef RHS) { + return !(LHS == RHS); +} + +template +inline bool operator!=(SmallVectorImpl &LHS, ArrayRef RHS) { + return !(LHS == RHS); +} + +} // end namespace paddle + +#endif // PADDLE_UTILS_ARRAY_REF_H_