未验证 提交 57e368b8 编写于 作者: Z Zhang Zheng 提交者: GitHub

Revert "[AMP OP&Test] Support float & bfloat16 when using thrust (#51627)" (#51897)

This reverts commit 3b2cd23a.
上级 202c06a2
......@@ -33,26 +33,6 @@ namespace cub = hipcub;
namespace phi {
namespace funcs {
template <typename T>
class CumTypeTrait {
public:
using Type = T;
};
template <>
class CumTypeTrait<phi::dtype::float16> {
public:
using Type = __half;
};
#if defined(__CUDACC__) && CUDA_VERSION >= 11000
template <>
class CumTypeTrait<phi::dtype::bfloat16> {
public:
using Type = __nv_bfloat16;
};
#endif
template <typename T>
struct IsComplex : public std::false_type {};
......
......@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero {
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
}
x_filled_one_[idx] = should_fill_one ? static_cast<T>(1) : x_[idx];
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
}
private:
......@@ -131,7 +131,6 @@ void CumprodGradKernel(const Context &dev_ctx,
const DenseTensor &dout,
int dim,
DenseTensor *dx) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *y = &out;
const auto *dy = &dout;
......@@ -226,19 +225,15 @@ void CumprodGradKernel(const Context &dev_ctx,
.Allocate(numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
CumType *dy_mul_y_data_cum = reinterpret_cast<CumType *>(dy_mul_y_data);
CumType *dy_mul_y_reversed_cumsum_data_cum =
reinterpret_cast<CumType *>(dy_mul_y_reversed_cumsum_data);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_y_data_cum,
dy_mul_y_reversed_cumsum_data_cum,
outer_dim,
mid_dim,
inner_dim,
static_cast<CumType>(0.0f),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data,
dy_mul_y_reversed_cumsum_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
// Step 3: calculate the gradient value except the first zero position.
// The gradient value of the first zero position is filled with out[idx-1],
......@@ -269,18 +264,14 @@ void CumprodGradKernel(const Context &dev_ctx,
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
CumType *x_filled_one_data_cum =
reinterpret_cast<CumType *>(x_filled_one_data);
CumType *x_filled_one_cumprod_data_cum =
reinterpret_cast<CumType *>(x_filled_one_cumprod_data);
phi::funcs::InclusiveScan<CumType, funcs::MultiplyFunctor<CumType>>(
x_filled_one_data_cum,
x_filled_one_cumprod_data_cum,
phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
x_filled_one_data,
x_filled_one_cumprod_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<CumType>(),
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
......@@ -295,17 +286,13 @@ void CumprodGradKernel(const Context &dev_ctx,
funcs::MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
CumType *dy_mul_x_filled_one_cumprod_cum =
reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod);
CumType *dy_mul_x_filled_one_cumprod_reversed_cumsum_cum =
reinterpret_cast<CumType *>(dy_mul_x_filled_one_cumprod_reversed_cumsum);
phi::funcs::InclusiveScan<CumType, cub::Sum>(
dy_mul_x_filled_one_cumprod_cum,
dy_mul_x_filled_one_cumprod_reversed_cumsum_cum,
phi::funcs::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum,
outer_dim,
mid_dim,
inner_dim,
static_cast<CumType>(0.0f),
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
......@@ -324,7 +311,6 @@ void CumprodGradKernel(const Context &dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
......@@ -335,17 +321,3 @@ PD_REGISTER_KERNEL(cumprod_grad,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
phi::CumprodGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
......@@ -28,7 +28,6 @@ void CumprodKernel(const Context &dev_ctx,
const DenseTensor &input,
int dim,
DenseTensor *out) {
using CumType = typename funcs::CumTypeTrait<T>::Type;
const auto *x = &input;
auto *y = out;
size_t outer_dim, mid_dim, inner_dim;
......@@ -40,22 +39,19 @@ void CumprodKernel(const Context &dev_ctx,
const auto *x_data = x->data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y);
const CumType *x_ptr = reinterpret_cast<const CumType *>(x_data);
CumType *y_ptr = reinterpret_cast<CumType *>(y_data);
phi::funcs::InclusiveScan(x_ptr,
y_ptr,
phi::funcs::InclusiveScan(x_data,
y_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<CumType>(1.0f),
funcs::MultiplyFunctor<CumType>(),
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
......@@ -66,17 +62,3 @@ PD_REGISTER_KERNEL(cumprod,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
phi::CumprodKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册