提交 ddb97486 编写于 作者: 开心的小妮's avatar 开心的小妮

[LITE] Fix bug of global max pooling

上级 dfbc4b50
......@@ -218,23 +218,8 @@ void pooling_global(const void* din, void* dout, int num, int chout, int hout,
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int cnt = size_channel_in / 8;
#if 0
LOG(INFO) << "size_channel_in:" << size_channel_in;
LOG(INFO) << "cnt:" << cnt;
LOG(INFO) << "num:" << num;
LOG(INFO) << "chout:" << chout;
LOG(INFO) << "hout:" << hout;
LOG(INFO) << "wout:" << wout;
LOG(INFO) << "chin:" << chin;
LOG(INFO) << "hin:" << hin;
LOG(INFO) << "win:" << win;
LOG(INFO) << "pooling_type " << pooling_type;
#endif
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout;
const float* data_in_batch = data_in + n * chin * size_channel_in;
......@@ -254,24 +239,12 @@ void pooling_global(const void* din, void* dout, int num, int chout, int hout,
data_in_channel += 8;
}
#else
int num = cnt;
if (num > 0) {
asm volatile(
"max_loop: @main loop\n"
"vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, "
"data_in_channel\n"
"vmax.f32 %q[vmax], %q[vmax], q0 @max vmax, "
"vmax, data_in_channel\n"
"vld1.f32 {d2-d3}, [%[data_in_channel]]! @ load 2nd 4 "
"data"
"vmax.f32 %q[vmax], %q[vmax], q1 @ compare 2nd "
"4 datas\n"
"subs %[num], #1 @subs num, 1\n"
"bne max_loop @bne num\n"
: [data_in_channel] "+r"(data_in_channel), [num] "+r"(num),
[vmax] "+w"(vmax)
:
: "cc", "memory", "q0", "q1");
for (; i < cnt; i++) {
float32x4_t vdin1 = vld1q_f32(data_in_channel);
vmax = vmaxq_f32(vdin1, vmax);
float32x4_t vdin2 = vld1q_f32(data_in_channel + 4);
vmax = vmaxq_f32(vmax, vdin2);
data_in_channel += 8;
}
#endif // __aarch64__
float32x2_t vmax_tmp =
......
......@@ -25,6 +25,43 @@ namespace lite {
namespace kernels {
namespace arm {
int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
const auto x_dims = param_->x->dims();
std::vector<int>& ksize = param_->ksize;
if (param_->global_pooling) {
ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
param_->paddings[i] = 0;
ksize[i] = static_cast<int>(x_dims[i + 2]);
}
}
std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
if (param_->adaptive) {
output_shape.insert(output_shape.end(), param_->ksize.begin(),
param_->ksize.end());
} else {
for (size_t i = 0; i < param_->ksize.size(); ++i) {
output_shape.push_back(
PoolOutputSize(x_dims[i + 2], param_->ksize[i], param_->paddings[i],
param_->strides[i], param_->ceil_mode));
}
}
return output_shape;
}
void pool_compute_ref(const operators::PoolParam& param) {
auto& in_dims = param.x->dims();
auto& out_dims = param.output->dims();
......@@ -66,33 +103,28 @@ void pool_compute_ref(const operators::PoolParam& param) {
if (global_pooling == true) {
ksize[0] = in_h;
ksize[1] = in_w;
}
#if 0
for (int i = 0; i < ksize.size(); ++i) {
LOG(INFO) << "ksize[" << i << "]:" << ksize[i];
}
for (int i = 0; i < strides.size(); ++i) {
LOG(INFO) << "strides[" << i << "]:" << strides[i];
}
for (int i = 0; i < paddings.size(); ++i) {
LOG(INFO) << "paddings[" << i << "]:" << paddings[i];
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * in_c * in_h * in_w + c * in_h * in_w;
float res = src[0];
if (pooling_type == "max") {
for (int i = 1; i < in_h * in_w; ++i) {
float cur_val = src[i];
res = cur_val > res ? cur_val : res;
}
} else if (pooling_type == "avg") {
for (int i = 1; i < in_h * in_w; ++i) {
float cur_val = src[i];
res += cur_val;
}
res /= (in_h * in_w);
}
dst_ptr[n * in_c * out_h * out_w + c] = res;
}
}
return;
}
LOG(INFO) << "in nchw:" << in_n << ", " << in_c << ", " << in_h << ", "
<< in_w;
LOG(INFO) << "size_in_n:" << size_in_n;
LOG(INFO) << "size_out_c:" << size_out_c;
LOG(INFO) << "out_h:" << out_h;
LOG(INFO) << "out_w:" << out_w;
LOG(INFO) << "size_out_n:" << size_out_n;
LOG(INFO) << "size_out_c:" << size_out_c;
LOG(INFO) << "window_h:" << window_h;
LOG(INFO) << "window_w:" << window_w;
LOG(INFO) << "stride_h:" << stride_h;
LOG(INFO) << "stride_w:" << stride_w;
LOG(INFO) << "pad_h:" << pad_h;
LOG(INFO) << "pad_w:" << pad_w;
#endif
for (int ind_n = 0; ind_n < in_n; ++ind_n) {
for (int ind_c = 0; ind_c < in_c; ++ind_c) {
......@@ -179,21 +211,21 @@ TEST(pool_arm, compute) {
for (auto pooling_type : {"avg", "max"}) {
for (auto global_pooling : {true}) {
for (auto stride : {2}) {
for (auto pad : {0}) {
// for (auto ksize: {3}) { // TODO(yuanshuai): ksize enable 2, 3
for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) {
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11 /* ,1024 */}) { // speedup for ci
for (auto h : {3, 1, 11, 4, 1}) {
for (auto w : {1, 3, 4, 12, 1}) {
VLOG(3) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " stride:" << stride << " pad:" << pad
<< " pooling_type:" << pooling_type
<< " global_pooling:" << global_pooling;
for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) {
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w // << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " pooling_type:" << pooling_type
<< " global_pooling:" << global_pooling;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, 1, 1})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, 1, 1})));
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); ++i) {
x_data[i] = i;
......@@ -203,6 +235,8 @@ TEST(pool_arm, compute) {
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
// param.ksize = {ksize, ksize}; //TODO(yuanshuai): ksize
// enable
param.ksize = {h, w};
param.global_pooling = global_pooling;
param.strides = {stride, stride};
......@@ -212,41 +246,40 @@ TEST(pool_arm, compute) {
param.ceil_mode = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
// compute
pool.SetParam(param);
pool.Run();
#if 0
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " end";
std::cout << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " end" << std::endl;
for (int i = 0; i < param.ksize.size(); ++i) {
std::cout << " ksize[" << i << "]:" << param.ksize[i];
}
std::cout << "\n";
for (int i = 0; i < param.strides.size(); ++i) {
std::cout << " strides[" << i << "]:" << param.strides[i];
}
std::cout << "\n";
for (int i = 0; i < param.paddings.size(); ++i) {
std::cout << " paddings[" << i << "]:" << param.paddings[i];
}
std::cout << "\n";
#endif
// compute ref
// output_ref.Resize(output.dims());
param.output = &output_ref;
pool_compute_ref(param);
VLOG(3) << "pool_compute_ref(param) end";
// compare
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i],
1); // 1e-5);
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
float tmp = output_data[i] - output_ref_data[i];
tmp = tmp < 0 ? -tmp : tmp;
if (tmp > 1e-5) {
std::cout << "output_data[0]:" << output_data[0]
<< std::endl;
std::cout << "output_ref_data[0]:" << output_ref_data[0]
<< std::endl;
std::cout
<< "x.dims().production():" << x.dims().production()
<< std::endl;
for (int ii = 0; ii < x.dims().production(); ++ii) {
std::cout << x_data[ii] << " ";
}
std::cout;
exit(0);
}
}
VLOG(3) << "compare pass";
......@@ -256,6 +289,7 @@ TEST(pool_arm, compute) {
}
} // pad
} // stride
//} // ksize TODO(yuanshuai): ksize enable
} // global_pooling
} // pooling_type
}
......
......@@ -214,7 +214,7 @@ function test_arm {
echo "android do not need armv7hf"
return 0
fi
# TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then
echo "skip android v7 test yet"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册