未验证 提交 aa0098f6 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of avgpool2d with NHWC layout (#49231)

上级 6cf9f9dd
......@@ -29,20 +29,22 @@ namespace phi {
// Decides whether a pooling whose output is 1x1 in both spatial dimensions can
// be handled as a reduction over the input's two spatial axes.
//
// input:        tensor being pooled; only its two spatial extents are read.
// output:       pooled result; only its two spatial extents are read.
// channel_last: true when the spatial axes are dims 1 and 2 (the NHWC layout
//               this change targets); false when they are dims 2 and 3.
// reduce_dim:   out-parameter. The two spatial axis indices are appended only
//               when the output is 1x1; otherwise it is left untouched.
//
// Returns the product of the input's two spatial extents when the output is
// 1x1, and 0 otherwise (callers test `reduce_num > 0` to pick the reduce path).
inline int GetReduceNum(const DenseTensor& input,
                        const DenseTensor* output,
                        const bool channel_last,
                        std::vector<int>* reduce_dim) {
  // The spatial axes are adjacent; only their starting index depends on the
  // layout, so both layouts share the same code path below.
  const int height_axis = channel_last ? 1 : 2;
  const int width_axis = height_axis + 1;

  const int output_height = output->dims()[height_axis];
  const int output_width = output->dims()[width_axis];
  if ((output_height != 1) || (output_width != 1)) {
    // Not a 1x1 output: signal "no reduction" without touching reduce_dim.
    return 0;
  }

  reduce_dim->push_back(height_axis);
  reduce_dim->push_back(width_axis);
  const int reduce_num =
      input.dims()[height_axis] * input.dims()[width_axis];
  return reduce_num;
}
......@@ -109,7 +111,7 @@ void PoolRawKernel(const Context& ctx,
} else if (pooling_type == "avg") {
std::vector<int> reduce_dim;
int reduce_num = GetReduceNum(x, out, data_format, &reduce_dim);
int reduce_num = GetReduceNum(x, out, channel_last, &reduce_dim);
if (reduce_num > 0 &&
adaptive) { // for adaptive_avg_pool2d && output_size == 1
#if defined(__HIPCC__) || defined(__NVCC__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册