未验证 提交 3b2cd23a 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Support float & bfloat16 when using thrust (#51627)

* [AMP OP&Test] Support float & bfloat16 when using cub

* fix compile error

* fix

* fix rocm compile error
上级 7e6f89c4
...@@ -33,6 +33,26 @@ namespace cub = hipcub; ...@@ -33,6 +33,26 @@ namespace cub = hipcub;
namespace phi { namespace phi {
namespace funcs { namespace funcs {
// Maps a Paddle element type to the element type expected by the
// device-side scan primitives (cub/thrust). Most types pass through
// unchanged; the half-precision wrapper types are translated to the
// native CUDA types (callers reinterpret_cast buffers accordingly, so
// the wrappers are assumed layout-compatible with the CUDA types).
template <typename T>
struct CumTypeTrait {
  using Type = T;
};

// phi::dtype::float16 scans run on CUDA's native __half.
template <>
struct CumTypeTrait<phi::dtype::float16> {
  using Type = __half;
};

#if defined(__CUDACC__) && CUDA_VERSION >= 11000
// __nv_bfloat16 exists only in CUDA 11+; on older CUDA and on HIP the
// primary template leaves bfloat16 unmapped.
template <>
struct CumTypeTrait<phi::dtype::bfloat16> {
  using Type = __nv_bfloat16;
};
#endif
template <typename T> template <typename T>
struct IsComplex : public std::false_type {}; struct IsComplex : public std::false_type {};
......
...@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero { ...@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero {
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0; first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
} }
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx]; x_filled_one_[idx] = should_fill_one ? static_cast<T>(1) : x_[idx];
} }
private: private:
...@@ -131,6 +131,7 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -131,6 +131,7 @@ void CumprodGradKernel(const Context &dev_ctx,
const DenseTensor &dout, const DenseTensor &dout,
int dim, int dim,
DenseTensor *dx) { DenseTensor *dx) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *y = &out; const auto *y = &out;
const auto *dy = &dout; const auto *dy = &dout;
...@@ -225,12 +226,16 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -225,12 +226,16 @@ void CumprodGradKernel(const Context &dev_ctx,
.Allocate(numel * sizeof(T)); .Allocate(numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data = auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr()); reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data, CumType *dy_mul_y_data_cum = reinterpret_cast<CumType *>(dy_mul_y_data);
dy_mul_y_reversed_cumsum_data, CumType *dy_mul_y_reversed_cumsum_data_cum =
reinterpret_cast<CumType *>(dy_mul_y_reversed_cumsum_data);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_y_data_cum,
dy_mul_y_reversed_cumsum_data_cum,
outer_dim, outer_dim,
mid_dim, mid_dim,
inner_dim, inner_dim,
static_cast<T>(0), static_cast<CumType>(0.0f),
cub::Sum(), cub::Sum(),
/*reverse=*/true, /*reverse=*/true,
dev_ctx); dev_ctx);
...@@ -264,14 +269,18 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -264,14 +269,18 @@ void CumprodGradKernel(const Context &dev_ctx,
// Step 4: calculate cumprod of x_filled_one // Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data = auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>( CumType *x_filled_one_data_cum =
x_filled_one_data, reinterpret_cast<CumType *>(x_filled_one_data);
x_filled_one_cumprod_data, CumType *x_filled_one_cumprod_data_cum =
reinterpret_cast<CumType *>(x_filled_one_cumprod_data);
phi::funcs::InclusiveScan<CumType, funcs::MultiplyFunctor<CumType>>(
x_filled_one_data_cum,
x_filled_one_cumprod_data_cum,
outer_dim, outer_dim,
mid_dim, mid_dim,
inner_dim, inner_dim,
static_cast<T>(1), static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<T>(), funcs::MultiplyFunctor<CumType>(),
/*reverse=*/false, /*reverse=*/false,
dev_ctx); dev_ctx);
...@@ -286,13 +295,17 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -286,13 +295,17 @@ void CumprodGradKernel(const Context &dev_ctx,
funcs::MultiplyFunctor<T>()); funcs::MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum = auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
phi::funcs::InclusiveScan<T, cub::Sum>( CumType *dy_mul_x_filled_one_cumprod_cum =
dy_mul_x_filled_one_cumprod, reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod);
dy_mul_x_filled_one_cumprod_reversed_cumsum, CumType *dy_mul_x_filled_one_cumprod_reversed_cumsum_cum =
reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod_reversed_cumsum);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_x_filled_one_cumprod_cum,
dy_mul_x_filled_one_cumprod_reversed_cumsum_cum,
outer_dim, outer_dim,
mid_dim, mid_dim,
inner_dim, inner_dim,
static_cast<T>(0), static_cast<CumType>(0.0f),
cub::Sum(), cub::Sum(),
/*reverse=*/true, /*reverse=*/true,
dev_ctx); dev_ctx);
...@@ -311,6 +324,18 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -311,6 +324,18 @@ void CumprodGradKernel(const Context &dev_ctx,
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_HIP
// ROCm build: register cumprod_grad WITHOUT float16/bfloat16. The scan
// path maps those types to CUDA-only __half/__nv_bfloat16 (via
// CumTypeTrait), which are not available under HIP; the CUDA branch
// below registers the full type list.
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
phi::CumprodGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(cumprod_grad, PD_REGISTER_KERNEL(cumprod_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -319,5 +344,8 @@ PD_REGISTER_KERNEL(cumprod_grad, ...@@ -319,5 +344,8 @@ PD_REGISTER_KERNEL(cumprod_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
#endif
...@@ -28,6 +28,7 @@ void CumprodKernel(const Context &dev_ctx, ...@@ -28,6 +28,7 @@ void CumprodKernel(const Context &dev_ctx,
const DenseTensor &input, const DenseTensor &input,
int dim, int dim,
DenseTensor *out) { DenseTensor *out) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *x = &input; const auto *x = &input;
auto *y = out; auto *y = out;
size_t outer_dim, mid_dim, inner_dim; size_t outer_dim, mid_dim, inner_dim;
...@@ -39,19 +40,22 @@ void CumprodKernel(const Context &dev_ctx, ...@@ -39,19 +40,22 @@ void CumprodKernel(const Context &dev_ctx,
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y); auto *y_data = dev_ctx.template Alloc<T>(y);
phi::funcs::InclusiveScan(x_data, const CumType *x_ptr = reinterpret_cast<const CumType *>(x_data);
y_data, CumType *y_ptr = reinterpret_cast<CumType *>(y_data);
phi::funcs::InclusiveScan(x_ptr,
y_ptr,
outer_dim, outer_dim,
mid_dim, mid_dim,
inner_dim, inner_dim,
static_cast<T>(1), static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<T>(), funcs::MultiplyFunctor<CumType>(),
/*reverse=*/false, /*reverse=*/false,
dev_ctx); dev_ctx);
} }
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumprod, PD_REGISTER_KERNEL(cumprod,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -62,3 +66,17 @@ PD_REGISTER_KERNEL(cumprod, ...@@ -62,3 +66,17 @@ PD_REGISTER_KERNEL(cumprod,
int64_t, int64_t,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
#else
// CUDA build: additionally registers float16 and bfloat16 for cumprod;
// CumTypeTrait maps them to the native CUDA half types for the device
// scan (bfloat16 mapping requires CUDA >= 11 — see the trait's guard).
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
phi::CumprodKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册