提交 10b4bfae 编写于 作者: xiebaiyuan's avatar xiebaiyuan

update jni with lock and exception logics && add java interface close #807

上级 3fba2b98
package com.baidu.paddle;
public class PML {
/**
* load seperated model
*
* @param modelDir model dir
* @return isloadsuccess
*/
public static native boolean load(String modelDir);
/**
* load combined model
*
* @param modelPath model file path
* @param paramPath param file path
* @return isloadsuccess
*/
public static native boolean loadCombined(String modelPath, String paramPath);
/**
* load model and qualified params
*
* @param modelDir qualified model dir
* @return isloadsuccess
*/
public static native boolean loadQualified(String modelDir);
/**
* load model and qualified combined params
*
* @param modelPath model file path
* @param paramPath qualified param path
* @return isloadsuccess
*/
public static native boolean loadCombinedQualified(String modelPath, String paramPath);
/**
* predict image
*
* @param buf of pretreated image (as your model like)
* @param ddims format of your input
* @return result
*/
public static native float[] predictImage(float[] buf, int[] ddims);
public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues);
/**
* clear model data
*/
public static native void clear();
/**
* setThread num when u enable openmp
*
* @param threadCount threadCount
*/
public static native void setThread(int threadCount);
}
......@@ -20,6 +20,12 @@ limitations under the License. */
#include "framework/tensor.h"
#include "io/paddle_mobile.h"
#ifdef ENABLE_EXCEPTION
#include "common/enforce.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
......@@ -33,17 +39,10 @@ using std::string;
extern const char *ANDROID_LOG_TAG =
"paddle_mobile LOG built on " __DATE__ " " __TIME__;
static PaddleMobile<CPU> *shared_paddle_mobile_instance = nullptr;
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
static std::mutex shared_mutex;
// toDo mutex lock
// static std::mutex shared_mutex;
PaddleMobile<CPU> *getPaddleMobileInstance() {
if (nullptr == shared_paddle_mobile_instance) {
shared_paddle_mobile_instance = new PaddleMobile<CPU>();
}
return shared_paddle_mobile_instance;
}
PaddleMobile<CPU> *getPaddleMobileInstance() { return &paddle_mobile; }
string jstring2cppstring(JNIEnv *env, jstring jstr) {
const char *cstr = env->GetStringUTFChars(jstr, 0);
......@@ -55,43 +54,111 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) {
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
jclass thiz,
jstring modelPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("load invoked");
bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), optimize);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize);
#endif
return static_cast<jboolean>(isLoadOk);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified(
JNIEnv *env, jclass thiz, jstring modelPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadQualified invoked");
bool optimize = true;
bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), optimize, qualified);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize, qualified);
#endif
return static_cast<jboolean>(isLoadOk);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadCombined invoked");
bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath),
optimize);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath),
optimize);
#endif
return static_cast<jboolean>(isLoadOk);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("loadCombinedQualified invoked");
bool optimize = true;
bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
bool isLoadOk = false;
#ifdef ENABLE_EXCEPTION
try {
isLoadOk = getPaddleMobileInstance()->Load(
jstring2cppstring(env, modelPath), jstring2cppstring(env, paramPath),
optimize, qualified);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
isLoadOk = false;
}
#else
isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath),
optimize, qualified);
#endif
return static_cast<jboolean>(isLoadOk);
}
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("predictImage invoked");
jfloatArray result = NULL;
#ifdef ENABLE_EXCEPTION
ANDROIDLOGE("ENABLE_EXCEPTION!");
try {
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
......@@ -100,7 +167,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
jfloatArray result = NULL;
int count = 0;
float *dataPointer = nullptr;
if (nullptr != buf) {
......@@ -112,12 +178,52 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i];
}
auto output = shared_paddle_mobile_instance->Predict(input);
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
int count = 0;
float *dataPointer = nullptr;
if (nullptr != buf) {
dataPointer = env->GetFloatArrayElements(buf, NULL);
}
framework::Tensor input;
input.Resize(ddim);
auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf);
env->DeleteLocalRef(dataPointer);
#endif
ANDROIDLOGI("predictImage finished");
return result;
}
......@@ -170,7 +276,48 @@ void convert_nv21_to_matrix(uint8_t *nv21, float *matrix, int width, int height,
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
JNIEnv *env, jclass thiz, jbyteArray yuv_, jint imgwidth, jint imgHeight,
jintArray ddims, jfloatArray meanValues) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("predictYuv invoked");
jfloatArray result = NULL;
#ifdef ENABLE_EXCEPTION
try {
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
float matrix[length];
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr;
if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL);
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
int count = 0;
framework::Tensor input;
input.Resize(ddim);
auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseByteArrayElements(yuv_, yuv, 0);
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->ReleaseFloatArrayElements(meanValues, meansPointer, 0);
ANDROIDLOGI("predictYuv finished");
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
......@@ -187,7 +334,6 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
jfloatArray result = NULL;
int count = 0;
framework::Tensor input;
input.Resize(ddim);
......@@ -195,7 +341,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = shared_paddle_mobile_instance->Predict(input);
auto output = getPaddleMobileInstance()->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
......@@ -203,19 +349,44 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->ReleaseFloatArrayElements(meanValues, meansPointer, 0);
ANDROIDLOGI("predictYuv finished");
#endif
return result;
}
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
jclass thiz,
jint threadCount) {
std::lock_guard<std::mutex> lock(shared_mutex);
ANDROIDLOGI("setThreadCount %d", threadCount);
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->SetThreadNum((int)threadCount);
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->SetThreadNum((int)threadCount);
#endif
}
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
jclass thiz) {
std::lock_guard<std::mutex> lock(shared_mutex);
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->Clear();
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->Clear();
#endif
}
} // namespace jni
......
......@@ -73,7 +73,7 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
/**
* clear data of the net when destroy for android
*/
JNIEXPORT void JNICALL Java_com_baidu_paddle_PMLL_clear(JNIEnv *env,
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
jclass thiz);
} // namespace jni
} // namespace paddle_mobile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册