From 8505700f1bd7bf758bd54d9f5ba59c99e38233dd Mon Sep 17 00:00:00 2001 From: Houjiang Chen Date: Mon, 15 Apr 2019 17:34:46 +0800 Subject: [PATCH] Add fetch API for java, refine android log (#1558) * Clear no persistable tensor array before predicting, fix crash when predicting with gpu debugging mode * Fix code style * Add fetch API for java, refine android log --- src/common/log.h | 12 ++-- src/io/jni/PML.java | 2 + src/io/jni/paddle_mobile_jni.cpp | 60 ++++++++++++------- src/io/jni/paddle_mobile_jni.h | 4 ++ .../kernel/arm/convolution/conv_common.cpp | 1 + tools/android-cmake/android.toolchain.cmake | 2 +- 6 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/common/log.h b/src/common/log.h index 30d1d495c7..dde50b6170 100644 --- a/src/common/log.h +++ b/src/common/log.h @@ -36,16 +36,20 @@ static const char *ANDROID_LOG_TAG = #define ANDROIDLOGI(...) \ __android_log_print(ANDROID_LOG_INFO, ANDROID_LOG_TAG, __VA_ARGS__); \ - printf("%s\n", __VA_ARGS__); + fprintf(stderr, "%s\n", __VA_ARGS__); \ + fflush(stderr) #define ANDROIDLOGW(...) \ __android_log_print(ANDROID_LOG_WARNING, ANDROID_LOG_TAG, __VA_ARGS__); \ - printf("%s\n", __VA_ARGS__); + fprintf(stderr, "%s\n", __VA_ARGS__); \ + fflush(stderr) #define ANDROIDLOGD(...) \ __android_log_print(ANDROID_LOG_DEBUG, ANDROID_LOG_TAG, __VA_ARGS__); \ - printf("%s\n", __VA_ARGS__) + fprintf(stderr, "%s\n", __VA_ARGS__); \ + fflush(stderr) #define ANDROIDLOGE(...) \ __android_log_print(ANDROID_LOG_ERROR, ANDROID_LOG_TAG, __VA_ARGS__); \ - printf("%s\n", __VA_ARGS__) + fprintf(stderr, "%s\n", __VA_ARGS__); \ + fflush(stderr) #else #define ANDROIDLOGI(...) #define ANDROIDLOGW(...) diff --git a/src/io/jni/PML.java b/src/io/jni/PML.java index cfacf46135..3f162dcf9e 100644 --- a/src/io/jni/PML.java +++ b/src/io/jni/PML.java @@ -44,6 +44,8 @@ public class PML { */ public static native float[] predictImage(float[] buf, int[] ddims); + public static native float[] fetch(String varName); + public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues); // predict with variable length input diff --git a/src/io/jni/paddle_mobile_jni.cpp b/src/io/jni/paddle_mobile_jni.cpp index 63511a2226..ee336889a2 100644 --- a/src/io/jni/paddle_mobile_jni.cpp +++ b/src/io/jni/paddle_mobile_jni.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #ifdef ANDROID -#include "paddle_mobile_jni.h" +#include "io/jni/paddle_mobile_jni.h" #include #include #include @@ -193,11 +193,9 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( env->DeleteLocalRef(ddims); env->ReleaseFloatArrayElements(buf, dataPointer, 0); env->DeleteLocalRef(buf); - } catch (paddle_mobile::PaddleMobileException &e) { ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); } - #else jsize ddim_size = env->GetArrayLength(ddims); if (ddim_size != 4) { @@ -231,18 +229,43 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( #endif ANDROIDLOGI("predictImage finished"); + return result; +} + +JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env, + jclass thiz, + jstring varName) { + jfloatArray result = NULL; + +#ifdef ENABLE_EXCEPTION + try { + auto output = + getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName)); + int count = output->numel(); + result = env->NewFloatArray(count); + env->SetFloatArrayRegion(result, 0, count, output->data()); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + } +#else + auto output = + getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName)); + int count = output->numel(); + result = env->NewFloatArray(count); + env->SetFloatArrayRegion(result, 0, count, output->data()); +#endif return result; } inline int yuv_to_rgb(int y, int u, int v, float *r, float *g, float *b) { - int r1 = (int)(y + 1.370705 * (v - 128)); - int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128)); - int b1 = (int)(y + 1.732446 * (u - 128)); + int r1 = (int)(y + 1.370705 * (v - 128)); // NOLINT + int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128)); // NOLINT + int b1 = (int)(y + 1.732446 * (u - 128)); // NOLINT - r1 = (int)fminf(255, fmaxf(0, r1)); - g1 = (int)fminf(255, fmaxf(0, g1)); - b1 = (int)fminf(255, fmaxf(0, b1)); + r1 = (int)fminf(255, fmaxf(0, r1)); // NOLINT + g1 = (int)fminf(255, fmaxf(0, g1)); // NOLINT + b1 = (int)fminf(255, fmaxf(0, b1)); // NOLINT *r = r1; *g = g1; *b = b1; @@ -299,14 +322,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( framework::DDim ddim = framework::make_ddim( {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); int length = framework::product(ddim); - float matrix[length]; + float matrix[length]; // NOLINT jbyte *yuv = env->GetByteArrayElements(yuv_, NULL); float *meansPointer = nullptr; if (nullptr != meanValues) { meansPointer = env->GetFloatArrayElements(meanValues, NULL); } - convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], - ddim[2], meansPointer); + convert_nv21_to_matrix(reinterpret_cast(yuv), matrix, imgwidth, + imgHeight, ddim[3], ddim[2], meansPointer); int count = 0; framework::Tensor input; input.Resize(ddim); @@ -335,14 +358,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( framework::DDim ddim = framework::make_ddim( {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); int length = framework::product(ddim); - float matrix[length]; + float matrix[length]; // NOLINT jbyte *yuv = env->GetByteArrayElements(yuv_, NULL); float *meansPointer = nullptr; if (nullptr != meanValues) { meansPointer = env->GetFloatArrayElements(meanValues, NULL); } - convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], - ddim[2], meansPointer); + convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, // NOLINT + imgHeight, ddim[3], ddim[2], meansPointer); int count = 0; framework::Tensor input; input.Resize(ddim); @@ -408,13 +431,12 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, ANDROIDLOGI("setThreadCount %d", threadCount); #ifdef ENABLE_EXCEPTION try { - getPaddleMobileInstance()->SetThreadNum((int)threadCount); + getPaddleMobileInstance()->SetThreadNum(static_cast(threadCount)); } catch (paddle_mobile::PaddleMobileException &e) { ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); } #else - getPaddleMobileInstance()->SetThreadNum((int)threadCount); - + getPaddleMobileInstance()->SetThreadNum(static_cast(threadCount)); #endif } @@ -425,13 +447,11 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env, #ifdef ENABLE_EXCEPTION try { getPaddleMobileInstance()->Clear(); - } catch (paddle_mobile::PaddleMobileException &e) { ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); } #else getPaddleMobileInstance()->Clear(); - #endif } diff --git a/src/io/jni/paddle_mobile_jni.h b/src/io/jni/paddle_mobile_jni.h index a2825f96fe..16d6768723 100644 --- a/src/io/jni/paddle_mobile_jni.h +++ b/src/io/jni/paddle_mobile_jni.h @@ -54,6 +54,10 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified( JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims); +JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env, + jclass thiz, + jstring varName); + /** * object detection for anroid */ diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 5a5c04c656..3989dfe74f 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -28,6 +28,7 @@ void InitBaseConvKernel(ConvParam *param) { bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && param->Input()->dims()[1] == param->Output()->dims()[1]; + if (param->Filter()->type() == type_id().hash_code()) { #ifndef __aarch64__ if (depth3x3 && param->Strides()[0] < 3 && diff --git a/tools/android-cmake/android.toolchain.cmake b/tools/android-cmake/android.toolchain.cmake index 55b90ba652..b897a473d9 100644 --- a/tools/android-cmake/android.toolchain.cmake +++ b/tools/android-cmake/android.toolchain.cmake @@ -444,7 +444,7 @@ endif() # Generic flags. list(APPEND ANDROID_COMPILER_FLAGS - -g +# -g -DANDROID -ffunction-sections -funwind-tables -- GitLab