未验证 提交 ad4824e5 编写于 作者: L limingshu 提交者: GitHub

Move GeneralDivMod from segmented_array.h to fast_divmod.h (#48934)

* first commit

* add some changes in stack kernel.

* move the location of GeneralDivMod

* fix code format error according to ci
上级 6ed8221a
...@@ -64,5 +64,32 @@ struct FastDivMod { ...@@ -64,5 +64,32 @@ struct FastDivMod {
uint32_t multiplier; uint32_t multiplier;
}; };
template <typename IndexT>
struct GeneralDivMod {
public:
explicit GeneralDivMod(IndexT d) { divmoder = phi::funcs::FastDivMod(d); }
__device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
return divmoder.Divmod(val);
}
phi::funcs::FastDivMod divmoder;
};
template <>
struct GeneralDivMod<int64_t> {
public:
using DivModT = phi::AlignedVector<int64_t, 2>;
explicit GeneralDivMod(int64_t d) { divisor = d; }
__device__ inline DivModT div_mod(int64_t val) {
DivModT data;
data[0] = val / divisor;
data[1] = val - data[0] * divisor;
return data;
}
int64_t divisor;
};
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -19,33 +19,6 @@ ...@@ -19,33 +19,6 @@
namespace phi { namespace phi {
namespace funcs { namespace funcs {
template <typename IndexT>
struct GeneralDivMod {
public:
explicit GeneralDivMod(IndexT d) { divmoder = phi::funcs::FastDivMod(d); }
__device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
return divmoder.Divmod(val);
}
phi::funcs::FastDivMod divmoder;
};
template <>
struct GeneralDivMod<int64_t> {
public:
using DivModT = phi::AlignedVector<int64_t, 2>;
explicit GeneralDivMod(int64_t d) { divisor = d; }
__device__ inline DivModT div_mod(int64_t val) {
DivModT data;
data[0] = val / divisor;
data[1] = val - data[0] * divisor;
return data;
}
int64_t divisor;
};
#if !defined(_WIN32) #if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x))) #define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else #else
...@@ -143,9 +116,8 @@ struct ConstPointerArraySetter : public ArraySetterBase<Context> { ...@@ -143,9 +116,8 @@ struct ConstPointerArraySetter : public ArraySetterBase<Context> {
const T** dev_ptr = nullptr; const T** dev_ptr = nullptr;
if (Size == SegmentedArraySize::kVariableLength) { if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = t.size() * sizeof(T*); size_t num_bytes = t.size() * sizeof(T*);
dev_ptr = dev_ptr = reinterpret_cast<const T**>(this->AllocAndCopy(
reinterpret_cast<const T**>(ArraySetterBase<Context>::AllocAndCopy( ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
} }
array.Set(ptrs, dev_ptr); array.Set(ptrs, dev_ptr);
...@@ -173,7 +145,7 @@ struct PointerArraySetter : public ArraySetterBase<Context> { ...@@ -173,7 +145,7 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
T** dev_ptr = nullptr; T** dev_ptr = nullptr;
if (Size == SegmentedArraySize::kVariableLength) { if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = t->size() * sizeof(T*); size_t num_bytes = t->size() * sizeof(T*);
dev_ptr = reinterpret_cast<T**>(ArraySetterBase<Context>::AllocAndCopy( dev_ptr = reinterpret_cast<T**>(this->AllocAndCopy(
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes)); ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
} }
...@@ -199,6 +171,7 @@ inline SegmentedArraySize CalcArraySize(int n) { ...@@ -199,6 +171,7 @@ inline SegmentedArraySize CalcArraySize(int n) {
return SegmentedArraySize::kVariableLength; return SegmentedArraySize::kVariableLength;
} }
} }
} // namespace funcs } // namespace funcs
#define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \ #define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册