Unverified Commit 7e1155ed authored by niuliling123, committed by GitHub

Add is_mean param for mean op (#40757)

Parent 521cded2
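This commit threads an is_mean flag from the mean op down through the shared reduce implementation, so a mean is computed as a raw sum with a single division fused into the kernel epilogue, instead of scaling every input element by 1/n through a DivideFunctor transform. A minimal standalone CUDA sketch of that idea (illustrative names, not Paddle code; single-block reduction for simplicity):

#include <cstdio>
#include <cuda_runtime.h>

// Sum n floats with one block, optionally dividing once at the very end.
__global__ void ReduceSumOrMean(const float* x, float* y, int n, bool is_mean) {
  __shared__ float partial[256];
  float acc = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc += x[i];  // accumulate a raw sum; no per-element scaling
  }
  partial[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // tree-reduce the block
    if (threadIdx.x < s) partial[threadIdx.x] += partial[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    float r = partial[0];
    if (is_mean) r /= static_cast<float>(n);  // one division, not n of them
    *y = r;
  }
}

int main() {
  const int n = 1000;
  float h_x[n], h_y = 0.0f;
  for (int i = 0; i < n; ++i) h_x[i] = 1.0f;
  float *d_x, *d_y;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_y, sizeof(float));
  cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);
  ReduceSumOrMean<<<1, 256>>>(d_x, d_y, n, /*is_mean=*/true);
  cudaMemcpy(&h_y, d_y, sizeof(float), cudaMemcpyDeviceToHost);
  printf("mean = %f\n", h_y);  // prints 1.0
  cudaFree(d_x);
  cudaFree(d_y);
  return 0;
}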
@@ -65,9 +65,10 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
     for (decltype(rank) i = 0; i < rank; ++i) {
       reduce_dims.push_back(i);
     }
-    TensorReduceImpl<T, T, kernel_primitives::AddFunctor, Div>(
-        context.cuda_device_context(), *input, output, Div(numel), reduce_dims,
-        stream);
+    TensorReduceImpl<T, T, kernel_primitives::AddFunctor,
+                     kps::IdentityFunctor<T>>(
+        context.cuda_device_context(), *input, output,
+        kps::IdentityFunctor<T>(), reduce_dims, stream, true);
   }
 };
...
@@ -33,12 +33,12 @@ void TensorReduceImpl(const platform::CUDADeviceContext& dev_ctx,
                       const framework::Tensor& x, framework::Tensor* y,
                       const TransformOp& transform,
                       const std::vector<int>& origin_reduce_dims,
-                      gpuStream_t stream) {
+                      gpuStream_t stream, bool is_mean = false) {
   y->mutable_data<Ty>(x.place());
   phi::funcs::ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
       static_cast<const phi::GPUContext&>(dev_ctx), x, y, transform,
-      origin_reduce_dims);
+      origin_reduce_dims, is_mean);
 }
 }  // namespace operators
...
@@ -453,25 +453,20 @@ struct ReduceConfig {
   void SetReduceType() {
     int rank = x_dim.size();
     int reduce_rank = reduce_dim.size();
-    bool is_last_dim =
-        (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
-    if (rank == reduce_rank || is_last_dim) {
 #ifdef PADDLE_WITH_XPU_KP
-      reduce_type = static_cast<int>(ReduceType::kReduceAny);
+    bool not_higher = x_dim[0] > 1;
 #else
-      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
+    int device_id = paddle::platform::GetCurrentDeviceId();
+    int max_grid_z = phi::backends::gpu::GetGpuMaxGridDimSize(device_id)[2];
+    bool not_higher = x_dim[0] >= max_grid_z;
 #endif
+    if (reduce_last_dim && (reduce_rank == 1)) {
+      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
     } else if (reduce_rank == 1) {
-      // ReduceFirstDim and reduceSecondDim
-#ifdef PADDLE_WITH_XPU_KP
-      if (reduce_dim[0] == 0) {
-        reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
-      } else {
-        reduce_type = static_cast<int>(ReduceType::kReduceAny);
-      }
-#else
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
-#endif
+      if (rank == 3 && not_higher) {
+        reduce_type = static_cast<int>(ReduceType::kReduceAny);
+      }
     } else {
       reduce_type = static_cast<int>(ReduceType::kReduceAny);
     }
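The rewritten SetReduceType above replaces the old rank == reduce_rank / last-dim special-casing with a reduce_last_dim flag plus a not_higher check: on CUDA, not_higher is true when the leading dimension reaches the grid's maximum z extent, in which case a rank-3 single-dim reduction falls back to the generic kernel. A host-side paraphrase of the new dispatch (a sketch; the enum values follow the diff, the free function is hypothetical):

enum class ReduceType { kReduceLastDim, kReduceHigherDim, kReduceAny };

ReduceType SelectReduceType(int rank, int reduce_rank, bool reduce_last_dim,
                            bool not_higher /* x_dim[0] >= max grid z */) {
  if (reduce_last_dim && reduce_rank == 1) {
    return ReduceType::kReduceLastDim;
  }
  if (reduce_rank == 1) {
    // Prefer the higher-dim kernel, but a rank-3 input whose leading
    // dimension hits the grid-z limit goes to the generic kernel instead.
    return (rank == 3 && not_higher) ? ReduceType::kReduceAny
                                     : ReduceType::kReduceHigherDim;
  }
  return ReduceType::kReduceAny;  // multi-dim reductions stay generic
}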
@@ -648,7 +643,8 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 bool reduce_last_dim,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
-                                const kps::DimConfig dim) {
+                                const kps::DimConfig dim,
+                                bool is_mean) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -752,7 +748,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
     kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
         &reduce_var, &reduce_var, reducer, reduce_last_dim);
-
+    if (is_mean) {
+      reduce_var = reduce_var / static_cast<MPType>(reduce_num);
+    }
     Ty result = static_cast<Ty>(reduce_var);
     kps::details::WriteData<Ty>(
         y + store_offset + i, &result, static_cast<int>(need_store));
@@ -772,7 +770,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int reduce_num,
                                       int left_num,
                                       int blocking_size,
-                                      const kps::DimConfig dim) {
+                                      const kps::DimConfig dim,
+                                      int mean_div,
+                                      bool is_mean) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -806,6 +806,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                  kps::details::ReduceMode::kLocalMode>(
         &reduce_var, &reduce_compute, reducer, false);
     }
+    if (is_mean) {
+      reduce_var = reduce_var / static_cast<MPType>(mean_div);
+    }
     Ty result = static_cast<Ty>(reduce_var);
     kps::WriteData<Ty, 1, 1, 1, false>(
         y + store_offset + idx, &result, block.BlockDimX());
@@ -831,6 +834,10 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                  kps::details::ReduceMode::kLocalMode>(
         &reduce_var, &reduce_compute, reducer, false);
     }
+
+    if (is_mean) {
+      reduce_var = reduce_var / static_cast<MPType>(mean_div);
+    }
     Ty result = static_cast<Ty>(reduce_var);
     kps::WriteData<Ty, 1, 1, 1, true>(
         y + store_offset + idx, &result, dim.rem_x);
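ReduceHigherDimKernel takes the divisor as a separate mean_div argument rather than reusing reduce_num: when it runs as the second pass (see the launches further down), its reduce_num is only the number of partial results (config.grid.y), while the mean must still divide by the original element count. A CPU-side sketch of that two-pass arithmetic (illustrative, not Paddle code):

#include <cstdio>

int main() {
  const float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  const int n = 8, chunks = 2, chunk_len = n / chunks;
  float partial[chunks];
  for (int c = 0; c < chunks; ++c) {  // pass 1: raw per-chunk sums, no division
    partial[c] = 0.0f;
    for (int i = 0; i < chunk_len; ++i) partial[c] += x[c * chunk_len + i];
  }
  float acc = 0.0f;
  for (int c = 0; c < chunks; ++c) acc += partial[c];  // pass 2: reduce partials
  const int mean_div = n;  // divide by the ORIGINAL count, not by `chunks`
  std::printf("mean = %f\n", acc / mean_div);  // prints 4.5
  return 0;
}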
@@ -848,7 +855,8 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config) {
+                               ReduceConfig<Ty> config,
+                               bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
     int stride_left = config.reduce_num;
@@ -887,7 +895,8 @@ static void LaunchReduceKernel(const Tx* x_data,
         config.reduce_last_dim,
         reduce_index_calculator,
         left_index_calculator,
-        dim);
+        dim,
+        is_mean && (!config.should_reduce_again));
   } else {
     int reduce_rank = config.reduce_strides.size();
@@ -930,7 +939,8 @@ static void LaunchReduceKernel(const Tx* x_data,
         config.reduce_last_dim,
         reduce_index_calculator,
         left_index_calculator,
-        dim);
+        dim,
+        is_mean && (!config.should_reduce_again));
   }
   if (config.should_reduce_again) {
@@ -950,15 +960,18 @@ static void LaunchReduceKernel(const Tx* x_data,
         kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
     dim.SetRem(config.left_num % block.x, 0, 0);
 #ifdef PADDLE_WITH_XPU_KP
-    grid = 8;
-    block = 64;
+    int grid_size = 8;
+    int block_size = 64;
+#else
+    auto grid_size = grid;
+    auto block_size = block;
 #endif
     ReduceHigherDimKernel<
         Ty,
         Ty,
         MPType,
         ReduceOp,
-        kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
+        kps::IdentityFunctor<Ty, MPType>><<<grid_size, block_size, 0, stream>>>(
         config.output_data,
         y_data,
         reducer,
@@ -967,7 +980,9 @@ static void LaunchReduceKernel(const Tx* x_data,
         config.grid.y,
         config.left_num,
         config.grid.y,
-        dim);
+        dim,
+        config.reduce_num,
+        is_mean);
   }
 }
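Note the is_mean && (!config.should_reduce_again) passed to the first-pass kernels above, paired with a plain is_mean on the second-pass launch: when a second reduction is required, the first pass must emit raw partial sums and only the final pass divides, otherwise the output would be scaled twice. The rule, as a tiny hypothetical helper:

struct MeanPassFlags {
  bool first_pass_divides;   // fused kernels divide only if they are final
  bool second_pass_divides;  // otherwise the re-reduce pass divides
};

MeanPassFlags PlanMeanDivision(bool is_mean, bool should_reduce_again) {
  return {is_mean && !should_reduce_again, is_mean && should_reduce_again};
}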
@@ -1034,7 +1049,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
                   const phi::DenseTensor& x,
                   phi::DenseTensor* y,
                   const TransformOp& transform,
-                  const std::vector<int>& origin_reduce_dims) {
+                  const std::vector<int>& origin_reduce_dims,
+                  bool is_mean = false) {
 #ifdef PADDLE_WITH_XPU_KP
   auto stream = dev_ctx.x_context()->xpu_stream;
 #else
@@ -1069,8 +1085,18 @@ void ReduceKernel(const KPDevice& dev_ctx,
   bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
 #ifndef PADDLE_WITH_XPU_KP
   if (use_cub_reduce) {
-    CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
-        x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
+    if (is_mean) {
+      using Div = kps::DivideFunctor<Tx>;
+      CubTensorReduceImpl<Tx, Ty, ReduceOp, Div>(x_data,
+                                                 y_data,
+                                                 Div(config.reduce_num),
+                                                 config.reduce_num,
+                                                 dev_ctx,
+                                                 stream);
+    } else {
+      CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
+          x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
+    }
     return;
   }
 #endif
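For the full-reduction fast path, the fused epilogue is not available inside CUB, so the hunk above keeps the old strategy when is_mean is set: the division is folded into the input transform (kps::DivideFunctor<Tx>). A self-contained sketch of the same pattern written directly against CUB (illustrative wrapper, not Paddle code; assumes CUB is available and omits error checking):

#include <cub/cub.cuh>

struct DivideBy {
  float denom;
  __host__ __device__ float operator()(float v) const { return v / denom; }
};

float MeanWithCub(const float* d_in, int n, cudaStream_t stream) {
  // Every element is scaled by 1/n on load, then summed: mean = sum(x / n).
  cub::TransformInputIterator<float, DivideBy, const float*> it(
      d_in, DivideBy{static_cast<float>(n)});
  float* d_out;
  cudaMalloc(&d_out, sizeof(float));
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call sizes the temporary workspace; second call runs the reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, it, d_out, n, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, it, d_out, n, stream);
  float result;
  cudaMemcpyAsync(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_temp);
  cudaFree(d_out);
  return result;
}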
@@ -1115,7 +1141,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
         config.reduce_num,
         config.left_num,
         config.blocking_size,
-        dim);
+        dim,
+        config.reduce_num,
+        is_mean && (!config.should_reduce_again));
     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
@@ -1125,15 +1153,19 @@ void ReduceKernel(const KPDevice& dev_ctx,
       dim2.SetRem(config.left_num % config.block.x, 0, 0);
 #ifdef PADDLE_WITH_XPU_KP
-      grid = 8;
-      block = 64;
+      int grid_size = 8;
+      int block_size = 64;
+#else
+      auto grid_size = grid;
+      auto block_size = block;
 #endif
       ReduceHigherDimKernel<
           Ty,
           Ty,
           MPType,
           ReduceOp<MPType>,
-          kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
+          kps::IdentityFunctor<Ty,
+                               MPType>><<<grid_size, block_size, 0, stream>>>(
           config.output_data,
           y_data,
           reducer,
@@ -1142,7 +1174,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
           config.grid.y,
           config.left_num,
           config.grid.y,
-          dim2);
+          dim2,
+          config.reduce_num,
+          is_mean);
     }
     return;
   }
@@ -1151,7 +1185,14 @@ void ReduceKernel(const KPDevice& dev_ctx,
   // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
   // function will be used
   LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>(
-      x_data, y_data, reducer, transform, reducer.initial(), stream, config);
+      x_data,
+      y_data,
+      reducer,
+      transform,
+      reducer.initial(),
+      stream,
+      config,
+      is_mean);
 }
 }  // namespace funcs
...
@@ -30,7 +30,8 @@ void Reduce(const KPDevice& dev_ctx,
             const std::vector<int64_t>& dims,
             bool keep_dim,
             DataType out_dtype,
-            DenseTensor* out) {
+            DenseTensor* out,
+            bool is_mean = false) {
   std::vector<int> reduce_dims =
       phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
@@ -57,12 +58,18 @@ void Reduce(const KPDevice& dev_ctx,
           tmp_tensor,
           out,
           TransformOp<data_t, MPType>(reduce_num),
-          reduce_dims);
+          reduce_dims,
+          is_mean);
     }));
   } else {
     using MPType = typename kps::details::MPTypeTrait<T>::Type;
     phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
-        dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
+        dev_ctx,
+        x,
+        out,
+        TransformOp<T, MPType>(reduce_num),
+        reduce_dims,
+        is_mean);
   }
 }
 }  // namespace phi
...
@@ -27,8 +27,8 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool reduce_all,
                    DenseTensor* out) {
   auto out_dtype = x.dtype();
-  phi::Reduce<T, kps::AddFunctor, kps::DivideFunctor>(
-      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
+  phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
+      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out, true);
 }

 template <typename T, typename Context>
...
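At the API surface, MeanRawKernel now passes kps::IdentityFunctor plus is_mean = true where it previously passed kps::DivideFunctor. The two formulations agree because sum(x)/n == sum(x/n), but the new one performs a single division instead of one per element. A tiny CPU check of that equivalence (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
  const double n = static_cast<double>(x.size());
  double transform_then_sum = 0.0, plain_sum = 0.0;
  for (double v : x) {
    transform_then_sum += v / n;  // old path: divide on every load
    plain_sum += v;               // new path: accumulate the raw sum
  }
  double sum_then_divide = plain_sum / n;  // new path: one epilogue division
  std::printf("%f %f\n", transform_then_sum, sum_then_divide);  // both 2.5
  return 0;
}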