Unverified commit 03ca04fe, authored by niuliling123, committed by GitHub

Modified reduce op to store temp_data with MPType (#55709)

Parent d3c9c079
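The change switches the intermediate buffer of multi-pass reductions from the output type Ty to its higher-precision math type MPType (e.g. float for float16/bfloat16 inputs), so partial results are no longer rounded down between passes. A minimal standalone sketch of the promotion idea, with a hypothetical trait name standing in for phi::dtype::MPTypeTrait:

#include <cuda_fp16.h>

// Toy analogue of phi::dtype::MPTypeTrait: map a storage type to the
// wider type used for accumulation. Names here are illustrative only.
template <typename T>
struct MpTypeTraitSketch {
  using Type = T;  // default: accumulate in the storage type itself
};

template <>
struct MpTypeTraitSketch<__half> {
  using Type = float;  // half-precision storage accumulates in float
};

// The patch threads this promoted type through ReduceConfig and the
// kernels:  using MPType = typename MpTypeTraitSketch<Ty>::Type;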
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };
 // reduce config
-template <typename Ty>
+template <typename Ty, typename MPType>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  Ty* output_data;
+  MPType* tmp_data;
   dim3 block;
   dim3 grid;
@@ -288,11 +288,9 @@ struct ReduceConfig {
                      const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(phi::make_ddim(
-          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
-      output_data = dev_ctx.Alloc<Ty>(tmp);
-    } else {
-      output_data = y_data;
+      tmp->Resize(
+          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
+      tmp_data = dev_ctx.Alloc<MPType>(tmp);
     }
   }
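Two details in this hunk: the else branch that aliased the temporary pointer to y_data is gone (the first pass now always writes tmp_data when a second pass follows, and otherwise writes y_data directly), and the Resize argument is now a plain element count; the old code folded sizeof(Ty) into the dimension even though Alloc<Ty> already scales by element size. A sketch of the intended sizing, with a hypothetical helper name:

#include <cstdint>

// Hypothetical sizing helper mirroring the new Resize call: the first
// pass emits one MPType partial per (left index, grid.y, grid.z) cell;
// dev_ctx.Alloc<MPType> scales by sizeof(MPType) on its own.
inline int64_t TmpElemCount(int64_t left_num, uint32_t grid_y, uint32_t grid_z) {
  return left_num * grid_z * grid_y;  // element count, not bytes
}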
@@ -583,7 +581,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean) {
+                                bool is_mean,
+                                MPType* tmp_data,
+                                bool need_store_tmp = false) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,9 +686,15 @@ __global__ void ReduceAnyKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::details::WriteData<Ty>(
-        y + store_offset + i, &result, static_cast<int>(need_store));
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::details::WriteData<Ty>(
+          y + store_offset + i, &result, static_cast<int>(need_store));
+    } else {
+      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
+                                      &reduce_var,
+                                      static_cast<int>(need_store));
+    }
   }
 }
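The store logic now forks on need_store_tmp: only the final pass casts the accumulator down to the output type, while an intermediate pass writes the raw MPType value so no precision is lost between passes. A simplified standalone illustration of the pattern, with a hypothetical WriteOne helper in place of kps::details::WriteData:

#include <cuda_fp16.h>

// Hypothetical helper: guarded single-element store.
template <typename T>
__device__ void WriteOne(T* dst, const T* src, int need_store) {
  if (need_store) *dst = *src;
}

// Ty = storage type (e.g. __half), MPType = accumulation type (e.g. float).
// Intermediate passes keep the raw MPType value; only the final pass
// downcasts to Ty.
template <typename Ty, typename MPType>
__device__ void StoreReduceResult(Ty* y, MPType* tmp_data, MPType reduce_var,
                                  int offset, bool need_store,
                                  bool need_store_tmp) {
  if (!need_store_tmp) {
    Ty result = static_cast<Ty>(reduce_var);  // downcast once, at the end
    WriteOne<Ty>(y + offset, &result, static_cast<int>(need_store));
  } else {
    WriteOne<MPType>(tmp_data + offset, &reduce_var,
                     static_cast<int>(need_store));
  }
}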
@@ -707,7 +713,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean) {
+                                      bool is_mean,
+                                      MPType* tmp_data,
+                                      bool need_store_tmp = false) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -739,9 +747,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, false>(
-        y + store_offset + idx, &result, block.BlockDimX());
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, false>(
+          y + store_offset + idx, &result, block.BlockDimX());
+    } else {
+      kps::WriteData<MPType, 1, 1, false>(
+          tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
+    }
   }

   if (idx < left_num) {
@@ -763,8 +776,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, true>(
+          y + store_offset + idx, &result, dim.rem_x);
+    } else {
+      kps::WriteData<MPType, 1, 1, true>(
+          tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
+    }
   }
 }
@@ -779,7 +798,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config,
+                               ReduceConfig<Ty, MPType> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -806,7 +825,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -816,7 +835,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -845,7 +866,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -855,7 +876,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   }

   if (config.should_reduce_again) {
@@ -879,23 +902,25 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<Ty,
+    ReduceHigherDimKernel<MPType,
                           Ty,
                           MPType,
                           ReduceOp,
-                          kps::IdentityFunctor<Ty, MPType>>
+                          kps::IdentityFunctor<MPType, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.output_data,
+            config.tmp_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<Ty, MPType>(),
+            kps::IdentityFunctor<MPType, MPType>(),
             init,
             config.grid.y,
             config.left_num,
             config.grid.y,
             dim,
             config.reduce_num,
-            is_mean);
+            is_mean,
+            config.tmp_data,
+            false);
   }
 }
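With should_reduce_again set, the flow is now: pass 1 reduces the Tx input into MPType partials in config.tmp_data (need_store_tmp = true, mean division deferred), and pass 2 runs ReduceHigherDimKernel with MPType as its input type over tmp_data, downcasting to Ty only when writing y_data (need_store_tmp = false). A toy end-to-end analogue of the two passes in plain CUDA, with illustrative names, assuming blockDim.x == 256 (a power of two) and a <<<1, 1>>> final pass:

#include <cuda_fp16.h>

// Pass 1: reduce __half input into float partials (one per block),
// mirroring ReduceAnyKernel with need_store_tmp = true.
__global__ void PartialSumKernel(const __half* x, float* tmp, int n) {
  __shared__ float cache[256];
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    v += __half2float(x[i]);  // accumulate in float (the MPType)
  }
  cache[threadIdx.x] = v;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) cache[threadIdx.x] += cache[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) tmp[blockIdx.x] = cache[0];  // raw float, no downcast
}

// Pass 2: reduce the float partials and downcast exactly once,
// mirroring the second ReduceHigherDimKernel launch above.
__global__ void FinalSumKernel(const float* tmp, __half* y, int m) {
  float v = 0.f;
  for (int i = 0; i < m; ++i) v += tmp[i];
  *y = __float2half(v);
}

The real kernels generalize this with index calculators and vectorized IO, but the precision contract is the same: tmp holds MPType, y holds Ty.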
@@ -1008,7 +1033,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }

-  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
+  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
@@ -1051,7 +1077,6 @@ void ReduceKernel(const KPDevice& dev_ctx,
   }
 #endif

-  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1081,7 +1106,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             reducer.initial(),
@@ -1090,7 +1115,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
             config.blocking_size,
             dim,
             config.reduce_num,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);

     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
@@ -1106,23 +1133,25 @@ void ReduceKernel(const KPDevice& dev_ctx,
       auto grid_size = grid;
       auto block_size = block;
 #endif
-      ReduceHigherDimKernel<Ty,
+      ReduceHigherDimKernel<MPType,
                             Ty,
                             MPType,
                             ReduceOp<MPType>,
-                            kps::IdentityFunctor<Ty, MPType>>
+                            kps::IdentityFunctor<MPType, MPType>>
           <<<grid_size, block_size, 0, stream>>>(
-              config.output_data,
+              config.tmp_data,
               y_data,
               reducer,
-              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
+              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
               reducer.initial(),
               config.grid.y,
               config.left_num,
               config.grid.y,
               dim2,
               config.reduce_num,
-              is_mean);
+              is_mean,
+              config.tmp_data,
+              false);
     }
     return;
   }
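Why the intermediate precision matters: in float16 the spacing between representable values above 2048 is 2, so a running fp16 sum of ones stalls at 2048 (2048 + 1 rounds back to 2048 under round-to-nearest-even). A host-side check, simulating an fp16 accumulator by round-tripping through __half:

#include <cstdio>
#include <cuda_fp16.h>

int main() {
  __half acc = __float2half(0.f);
  for (int i = 0; i < 4096; ++i) {
    // Round-trip through __half to simulate an fp16 accumulator.
    acc = __float2half(__half2float(acc) + 1.f);
  }
  // Prints 2048, not 4096: the sum stops growing at 2048 in fp16.
  printf("fp16 sum of 4096 ones: %g\n", __half2float(acc));
  return 0;
}

Keeping the per-pass partials in float (the MPType) avoids exactly this class of error in multi-pass reductions over float16/bfloat16 data.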