From eaa57bcd7cdf4a4d90bca72b1b116f266fbebc69 Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Mon, 20 Aug 2018 23:20:53 +0800 Subject: [PATCH] update jni with lock and exception logics && add java interface close #807 --- src/jni/PML.java | 63 ++++++++++ src/jni/paddle_mobile_jni.cpp | 219 ++++++++++++++++++++++++++++++---- src/jni/paddle_mobile_jni.h | 4 +- 3 files changed, 260 insertions(+), 26 deletions(-) create mode 100644 src/jni/PML.java diff --git a/src/jni/PML.java b/src/jni/PML.java new file mode 100644 index 0000000000..717d9ebb97 --- /dev/null +++ b/src/jni/PML.java @@ -0,0 +1,63 @@ +package com.baidu.paddle; + +public class PML { + /** + * load seperated model + * + * @param modelDir model dir + * @return isloadsuccess + */ + public static native boolean load(String modelDir); + + /** + * load combined model + * + * @param modelPath model file path + * @param paramPath param file path + * @return isloadsuccess + */ + public static native boolean loadCombined(String modelPath, String paramPath); + + /** + * load model and qualified params + * + * @param modelDir qualified model dir + * @return isloadsuccess + */ + public static native boolean loadQualified(String modelDir); + + /** + * load model and qualified combined params + * + * @param modelPath model file path + * @param paramPath qualified param path + * @return isloadsuccess + */ + public static native boolean loadCombinedQualified(String modelPath, String paramPath); + + /** + * predict image + * + * @param buf of pretreated image (as your model like) + * @param ddims format of your input + * @return result + */ + public static native float[] predictImage(float[] buf, int[] ddims); + + + public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues); + + /** + * clear model data + */ + public static native void clear(); + + /** + * setThread num when u enable openmp + * + * @param threadCount threadCount + */ + public static native void setThread(int threadCount); + + +} diff --git a/src/jni/paddle_mobile_jni.cpp b/src/jni/paddle_mobile_jni.cpp index c8ed491672..1b909532e9 100644 --- a/src/jni/paddle_mobile_jni.cpp +++ b/src/jni/paddle_mobile_jni.cpp @@ -20,6 +20,12 @@ limitations under the License. */ #include "framework/tensor.h" #include "io/paddle_mobile.h" +#ifdef ENABLE_EXCEPTION + +#include "common/enforce.h" + +#endif + #ifdef __cplusplus extern "C" { #endif @@ -33,17 +39,10 @@ using std::string; extern const char *ANDROID_LOG_TAG = "paddle_mobile LOG built on " __DATE__ " " __TIME__; -static PaddleMobile *shared_paddle_mobile_instance = nullptr; +paddle_mobile::PaddleMobile paddle_mobile; +static std::mutex shared_mutex; -// toDo mutex lock -// static std::mutex shared_mutex; - -PaddleMobile *getPaddleMobileInstance() { - if (nullptr == shared_paddle_mobile_instance) { - shared_paddle_mobile_instance = new PaddleMobile(); - } - return shared_paddle_mobile_instance; -} +PaddleMobile *getPaddleMobileInstance() { return &paddle_mobile; } string jstring2cppstring(JNIEnv *env, jstring jstr) { const char *cstr = env->GetStringUTFChars(jstr, 0); @@ -55,43 +54,144 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) { JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, jclass thiz, jstring modelPath) { + std::lock_guard lock(shared_mutex); ANDROIDLOGI("load invoked"); bool optimize = true; - return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), - optimize); + bool isLoadOk = false; + +#ifdef ENABLE_EXCEPTION + try { + isLoadOk = getPaddleMobileInstance()->Load( + jstring2cppstring(env, modelPath), optimize); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + isLoadOk = false; + } +#else + isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), + optimize); +#endif + return static_cast(isLoadOk); } JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified( JNIEnv *env, jclass thiz, jstring modelPath) { + std::lock_guard lock(shared_mutex); + ANDROIDLOGI("loadQualified invoked"); bool optimize = true; bool qualified = true; - return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), - optimize, qualified); + bool isLoadOk = false; + +#ifdef ENABLE_EXCEPTION + try { + isLoadOk = getPaddleMobileInstance()->Load( + jstring2cppstring(env, modelPath), optimize, qualified); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + isLoadOk = false; + } +#else + isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), + optimize, qualified); +#endif + + return static_cast(isLoadOk); } JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined( JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) { + std::lock_guard lock(shared_mutex); ANDROIDLOGI("loadCombined invoked"); bool optimize = true; - return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), - jstring2cppstring(env, paramPath), - optimize); + bool isLoadOk = false; + +#ifdef ENABLE_EXCEPTION + try { + isLoadOk = getPaddleMobileInstance()->Load( + jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath), + optimize); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + isLoadOk = false; + } +#else + isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), + jstring2cppstring(env, paramPath), + optimize); +#endif + return static_cast(isLoadOk); } JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified( JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) { + std::lock_guard lock(shared_mutex); ANDROIDLOGI("loadCombinedQualified invoked"); bool optimize = true; bool qualified = true; - return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), - jstring2cppstring(env, paramPath), - optimize, qualified); + bool isLoadOk = false; + +#ifdef ENABLE_EXCEPTION + try { + isLoadOk = getPaddleMobileInstance()->Load( + jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath), + optimize, qualified); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + isLoadOk = false; + } +#else + isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), + jstring2cppstring(env, paramPath), + optimize, qualified); +#endif + return static_cast(isLoadOk); } JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) { + std::lock_guard lock(shared_mutex); + ANDROIDLOGI("predictImage invoked"); + jfloatArray result = NULL; + +#ifdef ENABLE_EXCEPTION + ANDROIDLOGE("ENABLE_EXCEPTION!"); + + try { + jsize ddim_size = env->GetArrayLength(ddims); + if (ddim_size != 4) { + ANDROIDLOGE("ddims size not equal to 4"); + } + jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL); + framework::DDim ddim = framework::make_ddim( + {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); + int length = framework::product(ddim); + int count = 0; + float *dataPointer = nullptr; + if (nullptr != buf) { + dataPointer = env->GetFloatArrayElements(buf, NULL); + } + framework::Tensor input; + input.Resize(ddim); + auto input_ptr = input.mutable_data(); + for (int i = 0; i < length; i++) { + input_ptr[i] = dataPointer[i]; + } + auto output = getPaddleMobileInstance()->Predict(input); + count = output->numel(); + result = env->NewFloatArray(count); + env->SetFloatArrayRegion(result, 0, count, output->data()); + env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); + env->DeleteLocalRef(ddims); + env->ReleaseFloatArrayElements(buf, dataPointer, 0); + env->DeleteLocalRef(buf); + + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + } + +#else jsize ddim_size = env->GetArrayLength(ddims); if (ddim_size != 4) { ANDROIDLOGE("ddims size not equal to 4"); @@ -100,7 +200,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( framework::DDim ddim = framework::make_ddim( {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); int length = framework::product(ddim); - jfloatArray result = NULL; int count = 0; float *dataPointer = nullptr; if (nullptr != buf) { @@ -112,12 +211,19 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( for (int i = 0; i < length; i++) { input_ptr[i] = dataPointer[i]; } - auto output = shared_paddle_mobile_instance->Predict(input); + auto output = getPaddleMobileInstance()->Predict(input); count = output->numel(); result = env->NewFloatArray(count); env->SetFloatArrayRegion(result, 0, count, output->data()); env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); + env->DeleteLocalRef(ddims); + env->ReleaseFloatArrayElements(buf, dataPointer, 0); + env->DeleteLocalRef(buf); + env->DeleteLocalRef(dataPointer); +#endif + ANDROIDLOGI("predictImage finished"); + return result; } @@ -170,7 +276,48 @@ void convert_nv21_to_matrix(uint8_t *nv21, float *matrix, int width, int height, JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( JNIEnv *env, jclass thiz, jbyteArray yuv_, jint imgwidth, jint imgHeight, jintArray ddims, jfloatArray meanValues) { + std::lock_guard lock(shared_mutex); + ANDROIDLOGI("predictYuv invoked"); + jfloatArray result = NULL; + +#ifdef ENABLE_EXCEPTION + try { + jsize ddim_size = env->GetArrayLength(ddims); + if (ddim_size != 4) { + ANDROIDLOGE("ddims size not equal to 4"); + } + jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL); + framework::DDim ddim = framework::make_ddim( + {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); + int length = framework::product(ddim); + float matrix[length]; + jbyte *yuv = env->GetByteArrayElements(yuv_, NULL); + float *meansPointer = nullptr; + if (nullptr != meanValues) { + meansPointer = env->GetFloatArrayElements(meanValues, NULL); + } + convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], + ddim[2], meansPointer); + int count = 0; + framework::Tensor input; + input.Resize(ddim); + auto input_ptr = input.mutable_data(); + for (int i = 0; i < length; i++) { + input_ptr[i] = matrix[i]; + } + auto output = getPaddleMobileInstance()->Predict(input); + count = output->numel(); + result = env->NewFloatArray(count); + env->SetFloatArrayRegion(result, 0, count, output->data()); + env->ReleaseByteArrayElements(yuv_, yuv, 0); + env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); + env->ReleaseFloatArrayElements(meanValues, meansPointer, 0); + ANDROIDLOGI("predictYuv finished"); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + } +#else jsize ddim_size = env->GetArrayLength(ddims); if (ddim_size != 4) { ANDROIDLOGE("ddims size not equal to 4"); @@ -187,7 +334,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( } convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], ddim[2], meansPointer); - jfloatArray result = NULL; int count = 0; framework::Tensor input; input.Resize(ddim); @@ -195,7 +341,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( for (int i = 0; i < length; i++) { input_ptr[i] = matrix[i]; } - auto output = shared_paddle_mobile_instance->Predict(input); + auto output = getPaddleMobileInstance()->Predict(input); count = output->numel(); result = env->NewFloatArray(count); env->SetFloatArrayRegion(result, 0, count, output->data()); @@ -203,19 +349,44 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); env->ReleaseFloatArrayElements(meanValues, meansPointer, 0); ANDROIDLOGI("predictYuv finished"); +#endif + return result; } JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, jclass thiz, jint threadCount) { + std::lock_guard lock(shared_mutex); + ANDROIDLOGI("setThreadCount %d", threadCount); +#ifdef ENABLE_EXCEPTION + try { + getPaddleMobileInstance()->SetThreadNum((int)threadCount); + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + } +#else getPaddleMobileInstance()->SetThreadNum((int)threadCount); + +#endif } JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env, jclass thiz) { + std::lock_guard lock(shared_mutex); + +#ifdef ENABLE_EXCEPTION + try { + getPaddleMobileInstance()->Clear(); + + } catch (paddle_mobile::PaddleMobileException &e) { + ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); + } +#else getPaddleMobileInstance()->Clear(); + +#endif } } // namespace jni diff --git a/src/jni/paddle_mobile_jni.h b/src/jni/paddle_mobile_jni.h index 4fd62a6d56..158d64d451 100644 --- a/src/jni/paddle_mobile_jni.h +++ b/src/jni/paddle_mobile_jni.h @@ -73,8 +73,8 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, /** * clear data of the net when destroy for android */ -JNIEXPORT void JNICALL Java_com_baidu_paddle_PMLL_clear(JNIEnv *env, - jclass thiz); +JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env, + jclass thiz); } // namespace jni } // namespace paddle_mobile #ifdef __cplusplus -- GitLab