提交 10b4bfae 编写于 作者: xiebaiyuan's avatar xiebaiyuan

update jni with lock and exception logics && add java interface close #807

上级 3fba2b98
package com.baidu.paddle;
public class PML {
/**
* load seperated model
*
* @param modelDir model dir
* @return isloadsuccess
*/
public static native boolean load(String modelDir);
/**
* load combined model
*
* @param modelPath model file path
* @param paramPath param file path
* @return isloadsuccess
*/
public static native boolean loadCombined(String modelPath, String paramPath);
/**
* load model and qualified params
*
* @param modelDir qualified model dir
* @return isloadsuccess
*/
public static native boolean loadQualified(String modelDir);
/**
* load model and qualified combined params
*
* @param modelPath model file path
* @param paramPath qualified param path
* @return isloadsuccess
*/
public static native boolean loadCombinedQualified(String modelPath, String paramPath);
/**
* predict image
*
* @param buf of pretreated image (as your model like)
* @param ddims format of your input
* @return result
*/
public static native float[] predictImage(float[] buf, int[] ddims);
public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues);
/**
* clear model data
*/
public static native void clear();
/**
* setThread num when u enable openmp
*
* @param threadCount threadCount
*/
public static native void setThread(int threadCount);
}
...@@ -20,6 +20,12 @@ limitations under the License. */ ...@@ -20,6 +20,12 @@ limitations under the License. */
#include "framework/tensor.h" #include "framework/tensor.h"
#include "io/paddle_mobile.h" #include "io/paddle_mobile.h"
#ifdef ENABLE_EXCEPTION
#include "common/enforce.h"
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -33,17 +39,10 @@ using std::string; ...@@ -33,17 +39,10 @@ using std::string;
extern const char *ANDROID_LOG_TAG = extern const char *ANDROID_LOG_TAG =
"paddle_mobile LOG built on " __DATE__ " " __TIME__; "paddle_mobile LOG built on " __DATE__ " " __TIME__;
static PaddleMobile<CPU> *shared_paddle_mobile_instance = nullptr; paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
static std::mutex shared_mutex;
// toDo mutex lock PaddleMobile<CPU> *getPaddleMobileInstance() { return &paddle_mobile; }
// static std::mutex shared_mutex;
PaddleMobile<CPU> *getPaddleMobileInstance() {
if (nullptr == shared_paddle_mobile_instance) {
shared_paddle_mobile_instance = new PaddleMobile<CPU>();
}
return shared_paddle_mobile_instance;
}
string jstring2cppstring(JNIEnv *env, jstring jstr) { string jstring2cppstring(JNIEnv *env, jstring jstr) {
const char *cstr = env->GetStringUTFChars(jstr, 0); const char *cstr = env->GetStringUTFChars(jstr, 0);
...@@ -55,43 +54,111 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) { ...@@ -55,43 +54,111 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) {
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
jclass thiz, jclass thiz,
jstring modelPath) { jstring modelPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("load invoked"); ANDROIDLOGI("load invoked");
bool optimize = true; bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), optimize);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize); optimize);
#endif
return static_cast<jboolean>(isLoadOk);
} }
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified(
JNIEnv *env, jclass thiz, jstring modelPath) { JNIEnv *env, jclass thiz, jstring modelPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadQualified invoked"); ANDROIDLOGI("loadQualified invoked");
bool optimize = true; bool optimize = true;
bool qualified = true; bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), optimize, qualified);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize, qualified); optimize, qualified);
#endif
return static_cast<jboolean>(isLoadOk);
} }
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) { JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadCombined invoked"); ANDROIDLOGI("loadCombined invoked");
bool optimize = true; bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath),
optimize);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath), jstring2cppstring(env, paramPath),
optimize); optimize);
#endif
return static_cast<jboolean>(isLoadOk);
} }
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) { JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadCombinedQualified invoked"); ANDROIDLOGI("loadCombinedQualified invoked");
bool optimize = true; bool optimize = true;
bool qualified = true; bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath),
optimize, qualified);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath), jstring2cppstring(env, paramPath),
optimize, qualified); optimize, qualified);
#endif
return static_cast<jboolean>(isLoadOk);
} }
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) { JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("predictImage invoked"); ANDROIDLOGI("predictImage invoked");
jfloatArray result = NULL;
#ifdef ENABLE_EXCEPTION
ANDROIDLOGE("ENABLE_EXCEPTION!");
try {
jsize ddim_size = env->GetArrayLength(ddims); jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) { if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4"); ANDROIDLOGE("ddims size not equal to 4");
...@@ -100,7 +167,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( ...@@ -100,7 +167,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
framework::DDim ddim = framework::make_ddim( framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim); int length = framework::product(ddim);
jfloatArray result = NULL;
int count = 0; int count = 0;
float *dataPointer = nullptr; float *dataPointer = nullptr;
if (nullptr != buf) { if (nullptr != buf) {
...@@ -112,12 +178,52 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( ...@@ -112,12 +178,52 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i]; input_ptr[i] = dataPointer[i];
} }
auto output = shared_paddle_mobile_instance->Predict(input); auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel(); count = output->numel();
result = env->NewFloatArray(count); result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>()); env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
int count = 0;
float *dataPointer = nullptr;
if (nullptr != buf) {
dataPointer = env->GetFloatArrayElements(buf, NULL);
}
framework::Tensor input;
input.Resize(ddim);
auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf);
env->DeleteLocalRef(dataPointer);
#endif
ANDROIDLOGI("predictImage finished"); ANDROIDLOGI("predictImage finished");
return result; return result;
} }
...@@ -170,7 +276,48 @@ void convert_nv21_to_matrix(uint8_t *nv21, float *matrix, int width, int height, ...@@ -170,7 +276,48 @@ void convert_nv21_to_matrix(uint8_t *nv21, float *matrix, int width, int height,
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
JNIEnv *env, jclass thiz, jbyteArray yuv_, jint imgwidth, jint imgHeight, JNIEnv *env, jclass thiz, jbyteArray yuv_, jint imgwidth, jint imgHeight,
jintArray ddims, jfloatArray meanValues) { jintArray ddims, jfloatArray meanValues) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("predictYuv invoked"); ANDROIDLOGI("predictYuv invoked");
jfloatArray result = NULL;
#ifdef ENABLE_EXCEPTION
try {
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
float matrix[length];
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr;
if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL);
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
int count = 0;
framework::Tensor input;
input.Resize(ddim);
auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseByteArrayElements(yuv_, yuv, 0);
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->ReleaseFloatArrayElements(meanValues, meansPointer, 0);
ANDROIDLOGI("predictYuv finished");
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
jsize ddim_size = env->GetArrayLength(ddims); jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) { if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4"); ANDROIDLOGE("ddims size not equal to 4");
...@@ -187,7 +334,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( ...@@ -187,7 +334,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
} }
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer); ddim[2], meansPointer);
jfloatArray result = NULL;
int count = 0; int count = 0;
framework::Tensor input; framework::Tensor input;
input.Resize(ddim); input.Resize(ddim);
...@@ -195,7 +341,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( ...@@ -195,7 +341,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i]; input_ptr[i] = matrix[i];
} }
auto output = shared_paddle_mobile_instance->Predict(input); auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel(); count = output->numel();
result = env->NewFloatArray(count); result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>()); env->SetFloatArrayRegion(result, 0, count, output->data<float>());
...@@ -203,19 +349,44 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( ...@@ -203,19 +349,44 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0); env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->ReleaseFloatArrayElements(meanValues, meansPointer, 0); env->ReleaseFloatArrayElements(meanValues, meansPointer, 0);
ANDROIDLOGI("predictYuv finished"); ANDROIDLOGI("predictYuv finished");
#endif
return result; return result;
} }
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
jclass thiz, jclass thiz,
jint threadCount) { jint threadCount) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("setThreadCount %d", threadCount); ANDROIDLOGI("setThreadCount %d", threadCount);
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->SetThreadNum((int)threadCount); getPaddleMobileInstance()->SetThreadNum((int)threadCount);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->SetThreadNum((int)threadCount);
#endif
} }
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env, JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
jclass thiz) { jclass thiz) {
std::lock_guard<std::mutex> lock(shared_mutex);
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->Clear();
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->Clear(); getPaddleMobileInstance()->Clear();
#endif
} }
} // namespace jni } // namespace jni
......
...@@ -73,7 +73,7 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, ...@@ -73,7 +73,7 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
/** /**
* clear data of the net when destroy for android * clear data of the net when destroy for android
*/ */
JNIEXPORT void JNICALL Java_com_baidu_paddle_PMLL_clear(JNIEnv *env, JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
jclass thiz); jclass thiz);
} // namespace jni } // namespace jni
} // namespace paddle_mobile } // namespace paddle_mobile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册