From 5798e5c9af759097053ab1b0d9790a8d0ec46cf2 Mon Sep 17 00:00:00 2001
From: xiebaiyuan
Date: Fri, 28 Sep 2018 13:43:26 +0800
Subject: [PATCH] update nlp support

---
 src/jni/PML.java              |  8 ++++++++
 src/jni/paddle_mobile_jni.cpp | 22 ++++++++++++++++++++++
 test/net/test_nlp.cpp         | 30 ++++++++++++++++++++++++++++--
 3 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/src/jni/PML.java b/src/jni/PML.java
index 717d9ebb97..9cbea253ff 100644
--- a/src/jni/PML.java
+++ b/src/jni/PML.java
@@ -9,6 +9,14 @@ public class PML {
      */
     public static native boolean load(String modelDir);
 
+    /**
+     * load separated model
+     *
+     * @param modelDir model dir
+     * @return isloadsuccess
+     */
+    public static native boolean loadnlp(String modelDir);
+
     /**
      * load combined model
      *
diff --git a/src/jni/paddle_mobile_jni.cpp b/src/jni/paddle_mobile_jni.cpp
index 111ec35def..56d522b156 100644
--- a/src/jni/paddle_mobile_jni.cpp
+++ b/src/jni/paddle_mobile_jni.cpp
@@ -74,6 +74,28 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
   return static_cast<jboolean>(isLoadOk);
 }
 
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_PML_loadnlp(JNIEnv *env, jclass thiz, jstring modelPath) {
+  std::lock_guard<std::mutex> lock(shared_mutex);
+  ANDROIDLOGI("load invoked");
+  bool optimize = true;
+  bool isLoadOk = false;
+
+#ifdef ENABLE_EXCEPTION
+  try {
+    isLoadOk = getPaddleMobileInstance()->Load(
+        jstring2cppstring(env, modelPath), optimize, false, 1, true);
+  } catch (paddle_mobile::PaddleMobileException &e) {
+    ANDROIDLOGE("jni got an PaddleMobileException! %s", e.what());
+    isLoadOk = false;
+  }
+#else
+  isLoadOk = getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
+                                             optimize, false, 1, true);
+#endif
+  return static_cast<jboolean>(isLoadOk);
+}
+
 JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified(
     JNIEnv *env, jclass thiz, jstring modelPath) {
   std::lock_guard<std::mutex> lock(shared_mutex);
diff --git a/test/net/test_nlp.cpp b/test/net/test_nlp.cpp
index ca5f6571c8..edf5cd623a 100644
--- a/test/net/test_nlp.cpp
+++ b/test/net/test_nlp.cpp
@@ -32,8 +32,7 @@ int main() {
   std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
 
   // 1064 1603 644 699 2878 1219 867 1352 8 1 13 312 479
-  std::vector<int64_t> ids{1064, 1603, 644, 699, 2878, 1219, 867,
-                           1352, 8, 1, 13, 312, 479};
+  std::vector<int64_t> ids{1918, 117, 55, 97, 1352, 4272, 1656, 903};
 
   paddle_mobile::framework::LoDTensor words;
   auto size = static_cast<int>(ids.size());
@@ -56,5 +55,32 @@ int main() {
     std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms"
               << std::endl;
   }
+
+  auto time2 = time();
+  std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
+  // 1064 1603 644 699 2878 1219 867 1352 8 1 13 312 479
+
+  std::vector<int64_t> ids{1791, 656, 1549, 281, 96};
+
+  paddle_mobile::framework::LoDTensor words;
+  auto size = static_cast<int>(ids.size());
+  paddle_mobile::framework::LoD lod{{0, ids.size()}};
+  DDim dims{size, 1};
+  words.Resize(dims);
+  words.set_lod(lod);
+  DLOG << "words lod : " << words.lod();
+  auto *pdata = words.mutable_data<int64_t>();
+  size_t n = words.numel() * sizeof(int64_t);
+  DLOG << "n :" << n;
+  memcpy(pdata, ids.data(), n);
+  DLOG << "words lod 22: " << words.lod();
+  auto time3 = time();
+  for (int i = 0; i < 1; ++i) {
+    auto vec_result = paddle_mobile.PredictLod(words);
+    DLOG << *vec_result;
+  }
+  auto time4 = time();
+  std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms"
+            << std::endl;
   return 0;
 }
-- 
GitLab