Unverified commit 84273aaa, authored by Zhang Ting, committed by GitHub

fix cumsum compilation error for GPU architecture that does not support fast FP16 (#47277)
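The build break came from instantiating the Thrust scan path for phi::dtype::float16: the old code reinterpret_cast the buffers to __half (via CumTypeTrait) and called thrust::inclusive_scan / thrust::exclusive_scan on them, which does not compile for GPU architectures without fast FP16 arithmetic. The fix routes the Thrust path through a pair of ThrustCumsumKernel overloads gated by std::enable_if, so the __half-based code is never instantiated for float16. Below is a minimal, self-contained sketch of that dispatch pattern; the `float16` struct and `Scan` names are illustrative stand-ins, not part of the patch:

```cpp
#include <iostream>
#include <type_traits>

struct float16 {};  // stand-in for phi::dtype::float16

// Generic overload: SFINAE removes it from overload resolution when
// T is float16, so its body (the Thrust calls in the real kernel) is
// never instantiated for FP16 types.
template <typename T>
typename std::enable_if<!std::is_same<T, float16>::value>::type
Scan(const T* /*in*/) {
  std::cout << "Thrust path instantiated\n";
}

// float16 overload: deliberately empty, mirroring the empty
// ThrustCumsumKernel overload added by this patch.
template <typename T>
typename std::enable_if<std::is_same<T, float16>::value>::type
Scan(const T* /*in*/) {
  std::cout << "float16 no-op\n";
}

int main() {
  float f = 0.f;
  float16 h;
  Scan(&f);  // picks the generic overload
  Scan(&h);  // picks the empty float16 overload
  return 0;
}
```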

Parent 28ed27a6
@@ -34,18 +34,6 @@ namespace cub = hipcub;
namespace phi {
-template <typename T>
-class CumTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class CumTypeTrait<phi::dtype::float16> {
- public:
-  using Type = __half;
-};
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(
    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
@@ -228,6 +216,51 @@ __global__ void BlockScanKernel(T* d_out,
}
}
+template <typename Context, typename T>
+typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value>::type
+ThrustCumsumKernel(const Context& dev_ctx,
+                   const T* in_data,
+                   T* out_data,
+                   int64_t size,
+                   bool reverse,
+                   bool exclusive) {
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
+  const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
+#endif
+  if (reverse) {
+    thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+        thrust::device_pointer_cast(in_data) + size);
+    thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+        thrust::device_pointer_cast(out_data) + size);
+    if (exclusive) {
+      thrust::exclusive_scan(
+          policy, reversed_in, reversed_in + size, reversed_out);
+    } else {
+      thrust::inclusive_scan(
+          policy, reversed_in, reversed_in + size, reversed_out);
+    }
+  } else {
+    if (exclusive) {
+      thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+    } else {
+      thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+    }
+  }
+  return;
+}
+
+template <typename Context, typename T>
+typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
+ThrustCumsumKernel(const Context& dev_ctx,
+                   const phi::dtype::float16* in_data,
+                   phi::dtype::float16* out_data,
+                   int64_t size,
+                   bool reverse,
+                   bool exclusive) {}
template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
                const DenseTensor& x,
@@ -260,37 +293,8 @@ void ScanKernel(const Context& dev_ctx,
  // length of the 'axis' dimension.
  if (!std::is_same<T, phi::dtype::float16>::value &&
      std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
-#ifdef __HIPCC__
-    const auto& policy = thrust::hip::par.on(dev_ctx.stream());
-#else
-    const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
-#endif
-    using CumType = typename CumTypeTrait<T>::Type;
-    CumType* out_data_ptr = reinterpret_cast<CumType*>(out_data);
-    const CumType* in_data_ptr = reinterpret_cast<const CumType*>(in_data);
-    if (reverse) {
-      thrust::reverse_iterator<thrust::device_ptr<const CumType>> reversed_in(
-          thrust::device_pointer_cast(in_data_ptr) + size);
-      thrust::reverse_iterator<thrust::device_ptr<CumType>> reversed_out(
-          thrust::device_pointer_cast(out_data_ptr) + size);
-      if (exclusive) {
-        thrust::exclusive_scan(
-            policy, reversed_in, reversed_in + size, reversed_out);
-      } else {
-        thrust::inclusive_scan(
-            policy, reversed_in, reversed_in + size, reversed_out);
-      }
-    } else {
-      if (exclusive) {
-        thrust::exclusive_scan(
-            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
-      } else {
-        thrust::inclusive_scan(
-            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
-      }
-    }
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
    return;
  }
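Two implementation notes on the extracted helper, with a sketch below. First, a reverse cumsum is obtained by running the scan through thrust::reverse_iterator on both input and output, so no explicit flip or temporary copy of the data is needed. Second, because ScanKernel now guards the branch with !std::is_same<T, phi::dtype::float16>::value, float16 inputs skip the Thrust path entirely and fall through to the BlockScanKernel path later in ScanKernel. The following standalone sketch (illustrative, not part of the patch; compile with nvcc) shows the reverse-iterator trick:

```cpp
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/scan.h>

int main() {
  // in = [1, 2, 3, 4]
  thrust::device_vector<float> in(4);
  in[0] = 1; in[1] = 2; in[2] = 3; in[3] = 4;
  thrust::device_vector<float> out(4);

  // Scanning through reverse iterators yields a right-to-left cumsum:
  // out[i] = in[i] + in[i+1] + ... + in[n-1].
  thrust::inclusive_scan(in.rbegin(), in.rend(), out.rbegin());

  for (int i = 0; i < 4; ++i) {
    printf("%g ", static_cast<float>(out[i]));  // prints: 10 9 7 4
  }
  printf("\n");
  return 0;
}
```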