Unverified commit 03ca04fe, authored by niuliling123, committed by GitHub

Modified reduce op to store temp_data with MPType (#55709)

Parent d3c9c079
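This commit makes two-pass reductions keep their first-pass partial results in the higher-precision compute type MPType (e.g. float when Ty is float16) instead of rounding them back to Ty between passes. The following standalone CUDA sketch is illustrative only (none of these kernels or names are Paddle code); it contrasts storing __half partials against float partials for the same input, assuming nothing beyond cuda_fp16.h and the CUDA runtime.

// Illustrative sketch (not the Paddle kernel): why the partials of a
// two-pass reduction should be stored in MPType (here float) rather than
// in the output type (__half).
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// First pass, old behavior: the float accumulator is rounded to __half
// before the second pass, losing precision in every block.
__global__ void PartialSumHalf(const __half* x, __half* tmp, int per_block) {
  const __half* slice = x + blockIdx.x * per_block;
  float acc = 0.f;  // accumulation was already done in MPType
  for (int i = 0; i < per_block; ++i) acc += __half2float(slice[i]);
  tmp[blockIdx.x] = __float2half(acc);  // rounding to __half happens here
}

// First pass, new behavior: the partial stays in float; only the final
// pass casts back to the output type.
__global__ void PartialSumFloat(const __half* x, float* tmp, int per_block) {
  const __half* slice = x + blockIdx.x * per_block;
  float acc = 0.f;
  for (int i = 0; i < per_block; ++i) acc += __half2float(slice[i]);
  tmp[blockIdx.x] = acc;  // full MPType precision kept for pass two
}

int main() {
  const int blocks = 64, per_block = 1000, n = blocks * per_block;
  std::vector<__half> h(n, __float2half(0.1f));
  __half *d_x, *d_tmp_h;
  float* d_tmp_f;
  cudaMalloc(&d_x, n * sizeof(__half));
  cudaMalloc(&d_tmp_h, blocks * sizeof(__half));
  cudaMalloc(&d_tmp_f, blocks * sizeof(float));
  cudaMemcpy(d_x, h.data(), n * sizeof(__half), cudaMemcpyHostToDevice);
  // One thread per block keeps the sketch trivially correct.
  PartialSumHalf<<<blocks, 1>>>(d_x, d_tmp_h, per_block);
  PartialSumFloat<<<blocks, 1>>>(d_x, d_tmp_f, per_block);
  std::vector<__half> t_h(blocks);
  std::vector<float> t_f(blocks);
  cudaMemcpy(t_h.data(), d_tmp_h, blocks * sizeof(__half),
             cudaMemcpyDeviceToHost);
  cudaMemcpy(t_f.data(), d_tmp_f, blocks * sizeof(float),
             cudaMemcpyDeviceToHost);
  // Host-side "second pass" stands in for ReduceHigherDimKernel.
  float old_way = 0.f, new_way = 0.f;
  for (int b = 0; b < blocks; ++b) {
    old_way += __half2float(t_h[b]);
    new_way += t_f[b];
  }
  // On this data each __half partial rounds 99.9756 up to 100.0, so the
  // old-style total comes out about 1.5 higher than the float-partial total.
  printf("half partials: %.4f  float partials: %.4f\n", old_way, new_way);
  cudaFree(d_x); cudaFree(d_tmp_h); cudaFree(d_tmp_f);
  return 0;
}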
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };

 // reduce config
-template <typename Ty>
+template <typename Ty, typename MPType>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  Ty* output_data;
+  MPType* tmp_data;
   dim3 block;
   dim3 grid;
@@ -288,11 +288,9 @@ struct ReduceConfig {
                          const KPDevice& dev_ctx,
                          phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(phi::make_ddim(
-          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
-      output_data = dev_ctx.Alloc<Ty>(tmp);
-    } else {
-      output_data = y_data;
+      tmp->Resize(
+          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
+      tmp_data = dev_ctx.Alloc<MPType>(tmp);
     }
   }
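Two fixes land in this hunk besides the type switch. First, the temporary tensor's extent no longer folds sizeof(Ty) into the element count: the old Resize produced left_num * grid.z * grid.y * sizeof(Ty) elements where left_num * grid.z * grid.y suffice. With Ty = float, for example, 1024 needed partials became a 4096-element float tensor (16 KiB) where 1024 MPType = float elements (4 KiB) are enough; Alloc<MPType> now accounts for the element width itself. Second, the config no longer aliases the temporary buffer with y_data when no second pass is needed; the kernels below always receive y_data and tmp_data separately and choose the destination via the new need_store_tmp flag.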
@@ -583,7 +581,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean) {
+                                bool is_mean,
+                                MPType* tmp_data,
+                                bool need_store_tmp = false) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,9 +686,15 @@ __global__ void ReduceAnyKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::details::WriteData<Ty>(
-        y + store_offset + i, &result, static_cast<int>(need_store));
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::details::WriteData<Ty>(
+          y + store_offset + i, &result, static_cast<int>(need_store));
+    } else {
+      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
+                                      &reduce_var,
+                                      static_cast<int>(need_store));
+    }
   }
 }
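The same store dispatch is added to ReduceHigherDimKernel below. Abstracted as a hypothetical helper (not in the patch; names assumed), the pattern amounts to:

// Minimal sketch (hypothetical helper) of the store dispatch the kernels
// now share: when another pass will follow (need_store_tmp == true), the
// MPType accumulator is written to tmp_data as-is; otherwise it is cast
// to the output type exactly once.
template <typename Ty, typename MPType>
__device__ void StorePartialOrResult(Ty* y, MPType* tmp_data, int offset,
                                     MPType reduce_var, bool need_store,
                                     bool need_store_tmp) {
  if (!need_store) return;
  if (need_store_tmp) {
    tmp_data[offset] = reduce_var;            // keep full MPType precision
  } else {
    y[offset] = static_cast<Ty>(reduce_var);  // single final downcast
  }
}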
@@ -707,7 +713,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean) {
+                                      bool is_mean,
+                                      MPType* tmp_data,
+                                      bool need_store_tmp = false) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -739,9 +747,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, false>(
-        y + store_offset + idx, &result, block.BlockDimX());
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, false>(
+          y + store_offset + idx, &result, block.BlockDimX());
+    } else {
+      kps::WriteData<MPType, 1, 1, false>(
+          tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
+    }
   }

   if (idx < left_num) {
@@ -763,8 +776,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, true>(
+          y + store_offset + idx, &result, dim.rem_x);
+    } else {
+      kps::WriteData<MPType, 1, 1, true>(
+          tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
+    }
   }
 }
@@ -779,7 +798,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config,
+                               ReduceConfig<Ty, MPType> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -806,7 +825,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -816,7 +835,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -845,7 +866,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid_num, block_num, 0, stream>>>(
            x_data,
-            config.output_data,
+            y_data,
            reducer,
            transform,
            init,
@@ -855,7 +876,9 @@ static void LaunchReduceKernel(const Tx* x_data,
            reduce_index_calculator,
            left_index_calculator,
            dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   }

   if (config.should_reduce_again) {
@@ -879,23 +902,25 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<Ty,
+    ReduceHigherDimKernel<MPType,
                           Ty,
                           MPType,
                           ReduceOp,
-                          kps::IdentityFunctor<Ty, MPType>>
+                          kps::IdentityFunctor<MPType, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.output_data,
+            config.tmp_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<Ty, MPType>(),
+            kps::IdentityFunctor<MPType, MPType>(),
             init,
             config.grid.y,
             config.left_num,
             config.grid.y,
             dim,
             config.reduce_num,
-            is_mean);
+            is_mean,
+            config.tmp_data,
+            false);
   }
 }
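With the partials now held in MPType, the second-pass launch changes to match: the kernel's input type parameter becomes MPType, the IdentityFunctor maps MPType to MPType, the input pointer is config.tmp_data, and need_store_tmp is passed as false so this pass performs the single downcast to Ty. For Ty = phi::dtype::float16 the type flow is (illustrative):

// Illustrative type flow for Ty = phi::dtype::float16, MPType = float:
//   pass 1: ReduceAnyKernel reads float16, accumulates in float, and writes
//           float partials to config.tmp_data (need_store_tmp = true)
//   pass 2: ReduceHigherDimKernel<MPType, Ty, ...> reads those float partials
//           and writes float16 results to y_data (need_store_tmp = false)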
@@ -1008,7 +1033,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }

-  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
+  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
@@ -1051,7 +1077,6 @@ void ReduceKernel(const KPDevice& dev_ctx,
   }
 #endif

-  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1081,7 +1106,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
         <<<grid_num, block_num, 0, stream>>>(
            x_data,
-            config.output_data,
+            y_data,
            reducer,
            transform,
            reducer.initial(),
@@ -1090,7 +1115,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
            config.blocking_size,
            dim,
            config.reduce_num,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);

    if (config.should_reduce_again) {
      dim3 block = dim3(config.block.x, 1, 1);
@@ -1106,23 +1133,25 @@ void ReduceKernel(const KPDevice& dev_ctx,
      auto grid_size = grid;
      auto block_size = block;
 #endif
-      ReduceHigherDimKernel<Ty,
+      ReduceHigherDimKernel<MPType,
                             Ty,
                             MPType,
                             ReduceOp<MPType>,
-                            kps::IdentityFunctor<Ty, MPType>>
+                            kps::IdentityFunctor<MPType, MPType>>
          <<<grid_size, block_size, 0, stream>>>(
-              config.output_data,
+              config.tmp_data,
              y_data,
              reducer,
-              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
+              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
              reducer.initial(),
              config.grid.y,
              config.left_num,
              config.grid.y,
              dim2,
              config.reduce_num,
-              is_mean);
+              is_mean,
+              config.tmp_data,
+              false);
    }
    return;
  }
......