未验证 提交 3b2cd23a 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Support float & bfloat16 when using thrust (#51627)

* [AMP OP&Test] Support float & bfloat16 when using cub

* fix compile error

* fix

* fix rocm compile error
上级 7e6f89c4
......@@ -33,6 +33,26 @@ namespace cub = hipcub;
namespace phi {
namespace funcs {
template <typename T>
class CumTypeTrait {
public:
using Type = T;
};
template <>
class CumTypeTrait<phi::dtype::float16> {
public:
using Type = __half;
};
#if defined(__CUDACC__) && CUDA_VERSION >= 11000
template <>
class CumTypeTrait<phi::dtype::bfloat16> {
public:
using Type = __nv_bfloat16;
};
#endif
template <typename T>
struct IsComplex : public std::false_type {};
......
......@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero {
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
}
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
x_filled_one_[idx] = should_fill_one ? static_cast<T>(1) : x_[idx];
}
private:
......@@ -131,6 +131,7 @@ void CumprodGradKernel(const Context &dev_ctx,
const DenseTensor &dout,
int dim,
DenseTensor *dx) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *y = &out;
const auto *dy = &dout;
......@@ -225,15 +226,19 @@ void CumprodGradKernel(const Context &dev_ctx,
.Allocate(numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data,
dy_mul_y_reversed_cumsum_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
CumType *dy_mul_y_data_cum = reinterpret_cast<CumType *>(dy_mul_y_data);
CumType *dy_mul_y_reversed_cumsum_data_cum =
reinterpret_cast<CumType *>(dy_mul_y_reversed_cumsum_data);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_y_data_cum,
dy_mul_y_reversed_cumsum_data_cum,
outer_dim,
mid_dim,
inner_dim,
static_cast<CumType>(0.0f),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
// Step 3: calculate the gradient value except the first zero position.
// The gradient value of the first zero position is filled with out[idx-1],
......@@ -264,14 +269,18 @@ void CumprodGradKernel(const Context &dev_ctx,
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
x_filled_one_data,
x_filled_one_cumprod_data,
CumType *x_filled_one_data_cum =
reinterpret_cast<CumType *>(x_filled_one_data);
CumType *x_filled_one_cumprod_data_cum =
reinterpret_cast<CumType *>(x_filled_one_cumprod_data);
phi::funcs::InclusiveScan<CumType, funcs::MultiplyFunctor<CumType>>(
x_filled_one_data_cum,
x_filled_one_cumprod_data_cum,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<CumType>(),
/*reverse=*/false,
dev_ctx);
......@@ -286,13 +295,17 @@ void CumprodGradKernel(const Context &dev_ctx,
funcs::MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
phi::funcs::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum,
CumType *dy_mul_x_filled_one_cumprod_cum =
reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod);
CumType *dy_mul_x_filled_one_cumprod_reversed_cumsum_cum =
reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod_reversed_cumsum);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_x_filled_one_cumprod_cum,
dy_mul_x_filled_one_cumprod_reversed_cumsum_cum,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
static_cast<CumType>(0.0f),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
......@@ -311,6 +324,7 @@ void CumprodGradKernel(const Context &dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
......@@ -321,3 +335,17 @@ PD_REGISTER_KERNEL(cumprod_grad,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
phi::CumprodGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
......@@ -28,6 +28,7 @@ void CumprodKernel(const Context &dev_ctx,
const DenseTensor &input,
int dim,
DenseTensor *out) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *x = &input;
auto *y = out;
size_t outer_dim, mid_dim, inner_dim;
......@@ -39,19 +40,22 @@ void CumprodKernel(const Context &dev_ctx,
const auto *x_data = x->data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y);
phi::funcs::InclusiveScan(x_data,
y_data,
const CumType *x_ptr = reinterpret_cast<const CumType *>(x_data);
CumType *y_ptr = reinterpret_cast<CumType *>(y_data);
phi::funcs::InclusiveScan(x_ptr,
y_ptr,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<CumType>(),
/*reverse=*/false,
dev_ctx);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
......@@ -62,3 +66,17 @@ PD_REGISTER_KERNEL(cumprod,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
phi::CumprodKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册