// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/kernels/funcs/fast_divmod.h" namespace phi { namespace funcs { #if !defined(_WIN32) #define PADDLE_ALIGN(x) __attribute__((aligned(x))) #else #define PADDLE_ALIGN(x) #endif enum class SegmentedArraySize { kVariableLength = 0, kFixed4 = 4, kFixed8 = 8, kFixed16 = 16, kFixed32 = 32, kFixed64 = 64, }; template struct PADDLE_ALIGN(256) ConstPointerArray { public: const T* data[static_cast(Size)]; void Set(const std::vector& ptrs, const T** dev_ptr = nullptr) { for (auto i = 0; i < ptrs.size(); ++i) { data[i] = ptrs[i]; } } }; template struct PADDLE_ALIGN(256) ConstPointerArray { public: const T** data{nullptr}; void Set(const std::vector& ptrs, const T** dev_ptr = nullptr) { data = dev_ptr; } }; template struct PADDLE_ALIGN(256) PointerArray { public: T* data[static_cast(Size)]; void Set(const std::vector& ptrs, T** dev_ptr = nullptr) { for (auto i = 0; i < ptrs.size(); ++i) { data[i] = ptrs[i]; } } }; template struct PADDLE_ALIGN(256) PointerArray { public: T** data{nullptr}; void Set(const std::vector& ptrs, T** dev_ptr = nullptr) { data = dev_ptr; } }; #undef PADDLE_ALIGN template struct ArraySetterBase { protected: void* AllocAndCopy(const Context& ctx, void* src, size_t num_bytes) { allocation = paddle::memory::Alloc( ctx.GetPlace(), num_bytes, phi::Stream(reinterpret_cast(ctx.stream()))); paddle::memory::Copy(ctx.GetPlace(), allocation->ptr(), phi::CPUPlace(), src, num_bytes, ctx.stream()); return allocation->ptr(); } phi::Allocator::AllocationPtr allocation{nullptr}; }; template struct ConstPointerArraySetter : public ArraySetterBase { public: ConstPointerArray array; ConstPointerArraySetter(const Context& ctx, const std::vector& t) { ptrs.resize(t.size()); for (int i = 0; i < t.size(); ++i) { ptrs[i] = t[i]->data(); } const T** dev_ptr = nullptr; if (Size == SegmentedArraySize::kVariableLength) { size_t num_bytes = t.size() * sizeof(T*); dev_ptr = reinterpret_cast(this->AllocAndCopy( ctx, reinterpret_cast(ptrs.data()), num_bytes)); } array.Set(ptrs, dev_ptr); } private: std::vector ptrs; }; template struct PointerArraySetter : public ArraySetterBase { public: PointerArray array; PointerArraySetter(const Context& ctx, std::vector* t) { ptrs.resize(t->size()); for (int i = 0; i < t->size(); ++i) { if (t->at(i) && (t->at(i)->numel() > 0)) { ptrs[i] = ctx.template Alloc(t->at(i)); } else { ptrs[i] = nullptr; } } T** dev_ptr = nullptr; if (Size == SegmentedArraySize::kVariableLength) { size_t num_bytes = t->size() * sizeof(T*); dev_ptr = reinterpret_cast(this->AllocAndCopy( ctx, reinterpret_cast(ptrs.data()), num_bytes)); } array.Set(ptrs, dev_ptr); } private: std::vector ptrs; }; inline SegmentedArraySize CalcArraySize(int n) { if (n <= 4) { return SegmentedArraySize::kFixed4; } else if (n <= 8) { return SegmentedArraySize::kFixed8; } else if (n <= 16) { return SegmentedArraySize::kFixed16; } else if (n <= 32) { return SegmentedArraySize::kFixed32; } else if (n <= 64) { return SegmentedArraySize::kFixed64; } else { return SegmentedArraySize::kVariableLength; } } } // namespace funcs #define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \ case (size): { \ constexpr auto kArraySize = (size); \ __VA_ARGS__; \ } break #define _SEGMENTED_ARRAY_KERNEL_DEFAULT(size, ...) \ default: { \ constexpr auto kArraySize = (size); \ __VA_ARGS__; \ } break #define SEGMENTED_ARRAY_KERNEL_HELPER(...) \ _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed4, \ ##__VA_ARGS__); \ _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed8, \ ##__VA_ARGS__); \ _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed16, \ ##__VA_ARGS__); \ _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed32, \ ##__VA_ARGS__); \ _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed64, \ ##__VA_ARGS__); \ _SEGMENTED_ARRAY_KERNEL_DEFAULT(funcs::SegmentedArraySize::kVariableLength, \ ##__VA_ARGS__); } // namespace phi