diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 6179df5b0c11ad2a2e19384989029696e9d6c266..c1beb82da1072d199217d0722eaae6fcb0123490 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -58,7 +58,9 @@ void PoolCompute(const PoolParam ¶m) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); } - } else if (ksize[0] == 3 && ksize[0] == ksize[1]) { + } + + if (ksize[0] == 3 && ksize[0] == ksize[1]) { if (pooling_type == "max") { if (strides[0] == strides[1] && strides[0] == 1 && paddings[0] == paddings[1] && paddings[1] == 1) { diff --git a/test/net/test_mobilenet.cpp b/test/net/test_mobilenet.cpp index 5a3cc43a552ccec34817af2409af98e8db0ec9e5..56234c3c72b58869775238d78875c8bd3b94cf7c 100644 --- a/test/net/test_mobilenet.cpp +++ b/test/net/test_mobilenet.cpp @@ -20,7 +20,11 @@ int main() { paddle_mobile::PaddleMobile paddle_mobile; paddle_mobile.SetThreadNum(4); auto time1 = time(); - if (paddle_mobile.Load(g_mobilenet, true)) { + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_mobilenet, true); + if (isok) { auto time2 = time(); std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; @@ -39,13 +43,14 @@ int main() { for (int i = 0; i < 10; ++i) { auto vec_result = paddle_mobile.Predict(input, dims); } + DLOG << vec_result; auto time4 = time(); std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" << std::endl; } - std::cout - << "如果结果Nan请查看: test/images/test_image_1x3x224x224_float 是否存在?" - << std::endl; + std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana " + "是否存在?" + << std::endl; return 0; } diff --git a/test/test_helper.h b/test/test_helper.h index f6ad597ab122f4abda2ed255f0ec957c56d3cb46..fef175951e834a176c7987a77d53f2b5b4eecc5b 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -28,6 +28,7 @@ static const char *g_ocr = "../models/ocr"; static const char *g_mobilenet_ssd = "../models/mobilenet+ssd"; static const char *g_mobilenet_ssd_gesture = "../models/mobilenet+ssd_gesture"; static const char *g_mobilenet_combined = "../models/mobilenet_combine"; +static const char *g_mobilenet_detect = "../models/mobilenet-detect"; static const char *g_squeezenet = "../models/squeezenet"; static const char *g_googlenet = "../models/googlenet"; static const char *g_mobilenet = "../models/mobilenet";