// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/funcs/fast_divmod.h"

namespace phi {
namespace funcs {

// PADDLE_ALIGN(x) requests x-byte alignment for the struct it decorates,
// using the GCC/Clang `__attribute__((aligned(x)))` extension. On Windows it
// expands to nothing, so the structs below only get their natural alignment
// there. The macro is #undef'd right after the pointer-array structs so it
// does not leak out of this header.
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x)
#endif

// Capacity classes for ConstPointerArray / PointerArray.
// For every kFixedN enumerator the numeric value is the number of pointer
// slots stored inline in the struct. kVariableLength (0) selects the
// specializations that hold a single pointer to an externally staged table.
enum class SegmentedArraySize : int {
  kVariableLength = 0,  // pointer table lives in a separate buffer
  kFixed4 = 4,          // up to 4 inline pointers
  kFixed8 = 8,          // up to 8 inline pointers
  kFixed16 = 16,        // up to 16 inline pointers
  kFixed32 = 32,        // up to 32 inline pointers
  kFixed64 = 64,        // up to 64 inline pointers
};

// Fixed-capacity array of read-only pointers stored inline in the struct.
// The capacity is the numeric value of `Size` (4/8/16/32/64).
template <typename T, SegmentedArraySize Size>
struct PADDLE_ALIGN(256) ConstPointerArray {
 public:
  const T* data[static_cast<int>(Size)];

  // Copies the host pointers in `ptrs` into `data`, in order.
  // `dev_ptr` exists only for interface parity with the kVariableLength
  // specialization and is ignored here.
  // `Size` is presumably chosen with CalcArraySize so that ptrs.size() never
  // exceeds the capacity; the copy is clamped anyway so a violated
  // precondition cannot write past the end of `data`.
  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
    static_cast<void>(dev_ptr);  // unused for inline (fixed-size) storage
    constexpr size_t kCapacity = static_cast<size_t>(Size);
    const size_t count = ptrs.size() < kCapacity ? ptrs.size() : kCapacity;
    for (size_t i = 0; i < count; ++i) {
      data[i] = ptrs[i];
    }
  }
};

// Variable-length variant: instead of inline slots it records a single
// pointer to a pointer table that was staged elsewhere and passed in as
// `dev_ptr`.
template <typename T>
struct PADDLE_ALIGN(256)
    ConstPointerArray<T, SegmentedArraySize::kVariableLength> {
 public:
  const T** data = nullptr;

  // Only `dev_ptr` is recorded; the host list `ptrs` is not read. When
  // `dev_ptr` is omitted, `data` becomes nullptr.
  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
    this->data = dev_ptr;
  }
};

// Fixed-capacity array of mutable pointers stored inline in the struct.
// The capacity is the numeric value of `Size` (4/8/16/32/64).
template <typename T, SegmentedArraySize Size>
struct PADDLE_ALIGN(256) PointerArray {
 public:
  T* data[static_cast<int>(Size)];

  // Copies the host pointers in `ptrs` into `data`, in order.
  // `dev_ptr` exists only for interface parity with the kVariableLength
  // specialization and is ignored here.
  // `Size` is presumably chosen with CalcArraySize so that ptrs.size() never
  // exceeds the capacity; the copy is clamped anyway so a violated
  // precondition cannot write past the end of `data`.
  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
    static_cast<void>(dev_ptr);  // unused for inline (fixed-size) storage
    constexpr size_t kCapacity = static_cast<size_t>(Size);
    const size_t count = ptrs.size() < kCapacity ? ptrs.size() : kCapacity;
    for (size_t i = 0; i < count; ++i) {
      data[i] = ptrs[i];
    }
  }
};

// Variable-length variant: instead of inline slots it records a single
// pointer to a pointer table that was staged elsewhere and passed in as
// `dev_ptr`.
template <typename T>
struct PADDLE_ALIGN(256)
    PointerArray<T, SegmentedArraySize::kVariableLength> {
 public:
  T** data = nullptr;

  // Only `dev_ptr` is recorded; the host list `ptrs` is not read. When
  // `dev_ptr` is omitted, `data` becomes nullptr.
  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
    this->data = dev_ptr;
  }
};

#undef PADDLE_ALIGN

// Shared base for the *ArraySetter helpers: copies a host buffer into a
// freshly allocated buffer on the context's place and keeps the handle to
// that allocation.
template <typename Context>
struct ArraySetterBase {
 protected:
  // Allocates `num_bytes` on ctx.GetPlace() (tagged with ctx.stream()),
  // copies `num_bytes` from the host buffer `src` into it using
  // ctx.stream(), and returns the destination pointer.
  // NOTE(review): because the copy is issued on a stream it may complete
  // asynchronously w.r.t. the host, so `src` should stay alive until the
  // stream has consumed it - confirm against paddle::memory::Copy.
  // `src` is only read, hence `const void*`; callers passing `void*` still
  // compile unchanged.
  void* AllocAndCopy(const Context& ctx, const void* src, size_t num_bytes) {
    allocation = paddle::memory::Alloc(
        ctx.GetPlace(),
        num_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    paddle::memory::Copy(ctx.GetPlace(),
                         allocation->ptr(),
                         phi::CPUPlace(),
                         src,
                         num_bytes,
                         ctx.stream());
    return allocation->ptr();
  }

  // Handle to the most recent buffer from AllocAndCopy. Presumably an owning
  // (unique_ptr-style) AllocationPtr, so the buffer is released when this
  // object is destroyed or when AllocAndCopy replaces it - confirm.
  phi::Allocator::AllocationPtr allocation{nullptr};
};

// Gathers the data pointers of a list of read-only tensors into a
// ConstPointerArray<T, Size>. For fixed sizes the pointers are copied inline
// into `array`. For kVariableLength the pointer table is first copied into a
// buffer obtained from AllocAndCopy, and `array.data` points at that buffer.
template <typename Context, typename T, SegmentedArraySize Size>
struct ConstPointerArraySetter : public ArraySetterBase<Context> {
 public:
  ConstPointerArray<T, Size> array;

  // ctx: device context supplying the place and stream for the table copy.
  // t:   input tensors; every entry is dereferenced, so none may be null.
  ConstPointerArraySetter(const Context& ctx,
                          const std::vector<const DenseTensor*>& t) {
    ptrs.reserve(t.size());
    for (const DenseTensor* tensor : t) {
      ptrs.push_back(tensor->data<T>());
    }

    const T** dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      const size_t num_bytes = ptrs.size() * sizeof(const T*);
      dev_ptr = static_cast<const T**>(
          this->AllocAndCopy(ctx, static_cast<void*>(ptrs.data()), num_bytes));
    }

    array.Set(ptrs, dev_ptr);
  }

 private:
  // Host-side pointer table. Kept as a member rather than a local,
  // presumably so the source of the stream copy in AllocAndCopy outlives the
  // constructor - confirm before changing its lifetime.
  std::vector<const T*> ptrs;
};

// Allocates storage for a list of output tensors and gathers their data
// pointers into a PointerArray<T, Size>. For fixed sizes the pointers are
// copied inline into `array`. For kVariableLength the pointer table is first
// copied into a buffer obtained from AllocAndCopy, and `array.data` points at
// that buffer.
template <typename Context, typename T, SegmentedArraySize Size>
struct PointerArraySetter : public ArraySetterBase<Context> {
 public:
  PointerArray<T, Size> array;

  // ctx: device context used for ctx.Alloc<T> and for the table copy.
  // t:   output tensors. A non-null tensor with numel() > 0 gets its storage
  //      allocated via ctx.Alloc<T>; null or empty tensors map to nullptr.
  PointerArraySetter(const Context& ctx, std::vector<DenseTensor*>* t) {
    ptrs.reserve(t->size());
    for (DenseTensor* tensor : *t) {
      const bool has_elements = (tensor != nullptr) && (tensor->numel() > 0);
      ptrs.push_back(has_elements ? ctx.template Alloc<T>(tensor) : nullptr);
    }

    T** dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      const size_t num_bytes = ptrs.size() * sizeof(T*);
      dev_ptr = static_cast<T**>(
          this->AllocAndCopy(ctx, static_cast<void*>(ptrs.data()), num_bytes));
    }

    array.Set(ptrs, dev_ptr);
  }

 private:
  // Host-side pointer table. Kept as a member rather than a local,
  // presumably so the source of the stream copy in AllocAndCopy outlives the
  // constructor - confirm before changing its lifetime.
  std::vector<T*> ptrs;
};

// Returns the smallest fixed capacity (4, 8, 16, 32 or 64) that can hold `n`
// pointers. Any n <= 4, including zero or negative values, yields kFixed4.
// Counts above 64 fall back to kVariableLength.
inline SegmentedArraySize CalcArraySize(int n) {
  // Candidates in ascending order; the first one large enough wins.
  constexpr SegmentedArraySize kCandidates[] = {
      SegmentedArraySize::kFixed4,
      SegmentedArraySize::kFixed8,
      SegmentedArraySize::kFixed16,
      SegmentedArraySize::kFixed32,
      SegmentedArraySize::kFixed64,
  };
  for (const auto candidate : kCandidates) {
    if (n <= static_cast<int>(candidate)) {
      return candidate;
    }
  }
  return SegmentedArraySize::kVariableLength;
}

}  // namespace funcs

// Expands to a `case (size):` label whose body runs the user-supplied code
// in __VA_ARGS__. Inside that body `kArraySize` is a constexpr copy of `size`
// (a SegmentedArraySize), so the code can pass it as a template argument.
// Each case gets its own brace scope, so the `kArraySize` declarations of
// sibling cases do not collide. The trailing `break` has no `;` - the caller
// supplies it.
// NOTE(review): identifiers beginning with `_` followed by an uppercase
// letter are reserved in C++; consider renaming if no external code uses
// these helpers.
#define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \
  case (size): {                                \
    constexpr auto kArraySize = (size);         \
    __VA_ARGS__;                                \
  } break

// Same as _SEGMENTED_ARRAY_KERNEL_CASE, but emits the `default:` label.
#define _SEGMENTED_ARRAY_KERNEL_DEFAULT(size, ...) \
  default: {                                       \
    constexpr auto kArraySize = (size);            \
    __VA_ARGS__;                                   \
  } break

// Expands to the full set of switch labels: one case per fixed capacity
// (kFixed4 ... kFixed64) and a `default` that uses kVariableLength. It is
// presumably meant to sit inside `switch (CalcArraySize(n)) { ... }` -
// confirm at call sites. The expansion names `funcs::SegmentedArraySize`
// without a `phi::` prefix, so it must be expanded where `funcs` resolves to
// `phi::funcs`. `##__VA_ARGS__` is a GNU preprocessor extension.
#define SEGMENTED_ARRAY_KERNEL_HELPER(...)                                    \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed4,            \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed8,            \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed16,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed32,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed64,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_DEFAULT(funcs::SegmentedArraySize::kVariableLength, \
                                  ##__VA_ARGS__);

}  // namespace phi