未验证 提交 aa0098f6 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of avgpool2d with NHWC layout (#49231)

上级 6cf9f9dd
...@@ -29,21 +29,23 @@ namespace phi { ...@@ -29,21 +29,23 @@ namespace phi {
inline int GetReduceNum(const DenseTensor& input, inline int GetReduceNum(const DenseTensor& input,
const DenseTensor* output, const DenseTensor* output,
const std::string data_format, const bool channel_last,
std::vector<int>* reduce_dim) { std::vector<int>* reduce_dim) {
// data_format only can be NCHW
bool channel_last = (data_format == "NHWC");
if (channel_last) {
return 0;
}
int reduce_num = 0; int reduce_num = 0;
const int output_height = output->dims()[2]; const int output_height =
const int output_width = output->dims()[3]; channel_last ? output->dims()[1] : output->dims()[2];
const int output_width = channel_last ? output->dims()[2] : output->dims()[3];
if ((output_height == 1) && (output_width == 1)) { if ((output_height == 1) && (output_width == 1)) {
if (channel_last) {
reduce_dim->push_back(1);
reduce_dim->push_back(2);
reduce_num = input.dims()[1] * input.dims()[2];
} else {
reduce_dim->push_back(2); reduce_dim->push_back(2);
reduce_dim->push_back(3); reduce_dim->push_back(3);
reduce_num = input.dims()[2] * input.dims()[3]; reduce_num = input.dims()[2] * input.dims()[3];
} }
}
return reduce_num; return reduce_num;
} }
...@@ -109,7 +111,7 @@ void PoolRawKernel(const Context& ctx, ...@@ -109,7 +111,7 @@ void PoolRawKernel(const Context& ctx,
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
std::vector<int> reduce_dim; std::vector<int> reduce_dim;
int reduce_num = GetReduceNum(x, out, data_format, &reduce_dim); int reduce_num = GetReduceNum(x, out, channel_last, &reduce_dim);
if (reduce_num > 0 && if (reduce_num > 0 &&
adaptive) { // for adaptive_avg_pool2d && output_size == 1 adaptive) { // for adaptive_avg_pool2d && output_size == 1
#if defined(__HIPCC__) || defined(__NVCC__) #if defined(__HIPCC__) || defined(__NVCC__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册