未验证 提交 ad4824e5 编写于 作者: L limingshu 提交者: GitHub

Move GeneralDivMod from segmented_array.h to fast_divmod.h (#48934)

* first commit

* add some changes in stack kernel.

* move the location of GeneralDivMod

* fix code format errors reported by CI
上级 6ed8221a
......@@ -64,5 +64,32 @@ struct FastDivMod {
uint32_t multiplier;
};
// Generic divide-and-modulo helper that delegates to phi::funcs::FastDivMod.
//
// The constructor carries no execution-space qualifier, while div_mod() is
// marked __device__: the object is built from a divisor and div_mod() is
// callable from device code only.
template <typename IndexT>
struct GeneralDivMod {
  // Result type of div_mod(). Exposed under the same name as in the int64_t
  // specialization so that generic code can spell
  // `typename GeneralDivMod<IndexT>::DivModT` for every IndexT.
  using DivModT = phi::funcs::FastDivMod::DivModT;

  // Initializes the wrapped FastDivMod directly from `d` instead of
  // default-constructing it and assigning afterwards.
  explicit GeneralDivMod(IndexT d) : divmoder(d) {}

  // Forwards `val` to FastDivMod::Divmod and returns its result unchanged.
  // NOTE(review): presumably element 0 is the quotient and element 1 the
  // remainder, as in the int64_t specialization -- confirm against FastDivMod.
  __device__ inline DivModT div_mod(IndexT val) {
    return divmoder.Divmod(val);
  }

  // Underlying divider; kept public because it is part of the struct's
  // existing interface.
  phi::funcs::FastDivMod divmoder;
};
// int64_t specialization of GeneralDivMod: performs an ordinary 64-bit
// integer division with the stored divisor rather than going through
// FastDivMod.
template <>
struct GeneralDivMod<int64_t> {
  // Two-element result: index 0 receives the quotient, index 1 the remainder.
  using DivModT = phi::AlignedVector<int64_t, 2>;

  // Stores the divisor; no validation is performed here.
  explicit GeneralDivMod(int64_t d) : divisor(d) {}

  // Computes val / divisor once and derives the remainder from it by
  // multiply-and-subtract, so only a single division is issued.
  __device__ inline DivModT div_mod(int64_t val) {
    DivModT result;
    result[0] = val / divisor;
    const int64_t quotient = result[0];
    result[1] = val - quotient * divisor;
    return result;
  }

  // Divisor applied by div_mod().
  int64_t divisor;
};
} // namespace funcs
} // namespace phi
......@@ -19,33 +19,6 @@
namespace phi {
namespace funcs {
// Generic divide-and-modulo helper that delegates to phi::funcs::FastDivMod.
//
// The constructor carries no execution-space qualifier, while div_mod() is
// marked __device__: the object is built from a divisor and div_mod() is
// callable from device code only.
template <typename IndexT>
struct GeneralDivMod {
  // Result type of div_mod(). Exposed under the same name as in the int64_t
  // specialization so that generic code can spell
  // `typename GeneralDivMod<IndexT>::DivModT` for every IndexT.
  using DivModT = phi::funcs::FastDivMod::DivModT;

  // Initializes the wrapped FastDivMod directly from `d` instead of
  // default-constructing it and assigning afterwards.
  explicit GeneralDivMod(IndexT d) : divmoder(d) {}

  // Forwards `val` to FastDivMod::Divmod and returns its result unchanged.
  // NOTE(review): presumably element 0 is the quotient and element 1 the
  // remainder, as in the int64_t specialization -- confirm against FastDivMod.
  __device__ inline DivModT div_mod(IndexT val) {
    return divmoder.Divmod(val);
  }

  // Underlying divider; kept public because it is part of the struct's
  // existing interface.
  phi::funcs::FastDivMod divmoder;
};
// int64_t specialization of GeneralDivMod: performs an ordinary 64-bit
// integer division with the stored divisor rather than going through
// FastDivMod.
template <>
struct GeneralDivMod<int64_t> {
  // Two-element result: index 0 receives the quotient, index 1 the remainder.
  using DivModT = phi::AlignedVector<int64_t, 2>;

  // Stores the divisor; no validation is performed here.
  explicit GeneralDivMod(int64_t d) : divisor(d) {}

  // Computes val / divisor once and derives the remainder from it by
  // multiply-and-subtract, so only a single division is issued.
  __device__ inline DivModT div_mod(int64_t val) {
    DivModT result;
    result[0] = val / divisor;
    const int64_t quotient = result[0];
    result[1] = val - quotient * divisor;
    return result;
  }

  // Divisor applied by div_mod().
  int64_t divisor;
};
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
......@@ -143,9 +116,8 @@ struct ConstPointerArraySetter : public ArraySetterBase<Context> {
const T** dev_ptr = nullptr;
if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = t.size() * sizeof(T*);
dev_ptr =
reinterpret_cast<const T**>(ArraySetterBase<Context>::AllocAndCopy(
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
dev_ptr = reinterpret_cast<const T**>(this->AllocAndCopy(
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
}
array.Set(ptrs, dev_ptr);
......@@ -173,7 +145,7 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
T** dev_ptr = nullptr;
if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = t->size() * sizeof(T*);
dev_ptr = reinterpret_cast<T**>(ArraySetterBase<Context>::AllocAndCopy(
dev_ptr = reinterpret_cast<T**>(this->AllocAndCopy(
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
}
......@@ -199,6 +171,7 @@ inline SegmentedArraySize CalcArraySize(int n) {
return SegmentedArraySize::kVariableLength;
}
}
} // namespace funcs
#define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册