未验证 提交 8505700f 编写于 作者: H Houjiang Chen 提交者: GitHub

Add fetch API for java, refine android log (#1558)

* Clear no persistable tensor array before predicting, fix crash when predicting with gpu debugging mode

* Fix code style

* Add fetch API for java, refine android log
上级 563f0cc5
...@@ -36,16 +36,20 @@ static const char *ANDROID_LOG_TAG = ...@@ -36,16 +36,20 @@ static const char *ANDROID_LOG_TAG =
#define ANDROIDLOGI(...) \ #define ANDROIDLOGI(...) \
__android_log_print(ANDROID_LOG_INFO, ANDROID_LOG_TAG, __VA_ARGS__); \ __android_log_print(ANDROID_LOG_INFO, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__); fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGW(...) \ #define ANDROIDLOGW(...) \
__android_log_print(ANDROID_LOG_WARNING, ANDROID_LOG_TAG, __VA_ARGS__); \ __android_log_print(ANDROID_LOG_WARNING, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__); fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGD(...) \ #define ANDROIDLOGD(...) \
__android_log_print(ANDROID_LOG_DEBUG, ANDROID_LOG_TAG, __VA_ARGS__); \ __android_log_print(ANDROID_LOG_DEBUG, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__) fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGE(...) \ #define ANDROIDLOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, ANDROID_LOG_TAG, __VA_ARGS__); \ __android_log_print(ANDROID_LOG_ERROR, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__) fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#else #else
#define ANDROIDLOGI(...) #define ANDROIDLOGI(...)
#define ANDROIDLOGW(...) #define ANDROIDLOGW(...)
......
...@@ -44,6 +44,8 @@ public class PML { ...@@ -44,6 +44,8 @@ public class PML {
*/ */
public static native float[] predictImage(float[] buf, int[] ddims); public static native float[] predictImage(float[] buf, int[] ddims);
public static native float[] fetch(String varName);
public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues); public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues);
// predict with variable length input // predict with variable length input
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef ANDROID #ifdef ANDROID
#include "paddle_mobile_jni.h" #include "io/jni/paddle_mobile_jni.h"
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -193,11 +193,9 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( ...@@ -193,11 +193,9 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
env->DeleteLocalRef(ddims); env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0); env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf); env->DeleteLocalRef(buf);
} catch (paddle_mobile::PaddleMobileException &e) { } catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
} }
#else #else
jsize ddim_size = env->GetArrayLength(ddims); jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) { if (ddim_size != 4) {
...@@ -231,18 +229,43 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( ...@@ -231,18 +229,43 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
#endif #endif
ANDROIDLOGI("predictImage finished"); ANDROIDLOGI("predictImage finished");
return result;
}
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env,
jclass thiz,
jstring varName) {
jfloatArray result = NULL;
#ifdef ENABLE_EXCEPTION
try {
auto output =
getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName));
int count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
auto output =
getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName));
int count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
#endif
return result; return result;
} }
inline int yuv_to_rgb(int y, int u, int v, float *r, float *g, float *b) { inline int yuv_to_rgb(int y, int u, int v, float *r, float *g, float *b) {
int r1 = (int)(y + 1.370705 * (v - 128)); int r1 = (int)(y + 1.370705 * (v - 128)); // NOLINT
int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128)); int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128)); // NOLINT
int b1 = (int)(y + 1.732446 * (u - 128)); int b1 = (int)(y + 1.732446 * (u - 128)); // NOLINT
r1 = (int)fminf(255, fmaxf(0, r1)); r1 = (int)fminf(255, fmaxf(0, r1)); // NOLINT
g1 = (int)fminf(255, fmaxf(0, g1)); g1 = (int)fminf(255, fmaxf(0, g1)); // NOLINT
b1 = (int)fminf(255, fmaxf(0, b1)); b1 = (int)fminf(255, fmaxf(0, b1)); // NOLINT
*r = r1; *r = r1;
*g = g1; *g = g1;
*b = b1; *b = b1;
...@@ -299,14 +322,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( ...@@ -299,14 +322,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
framework::DDim ddim = framework::make_ddim( framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim); int length = framework::product(ddim);
float matrix[length]; float matrix[length]; // NOLINT
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL); jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr; float *meansPointer = nullptr;
if (nullptr != meanValues) { if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL); meansPointer = env->GetFloatArrayElements(meanValues, NULL);
} }
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], convert_nv21_to_matrix(reinterpret_cast<uint8_t *>(yuv), matrix, imgwidth,
ddim[2], meansPointer); imgHeight, ddim[3], ddim[2], meansPointer);
int count = 0; int count = 0;
framework::Tensor input; framework::Tensor input;
input.Resize(ddim); input.Resize(ddim);
...@@ -335,14 +358,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( ...@@ -335,14 +358,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
framework::DDim ddim = framework::make_ddim( framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]}); {ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim); int length = framework::product(ddim);
float matrix[length]; float matrix[length]; // NOLINT
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL); jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr; float *meansPointer = nullptr;
if (nullptr != meanValues) { if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL); meansPointer = env->GetFloatArrayElements(meanValues, NULL);
} }
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3], convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, // NOLINT
ddim[2], meansPointer); imgHeight, ddim[3], ddim[2], meansPointer);
int count = 0; int count = 0;
framework::Tensor input; framework::Tensor input;
input.Resize(ddim); input.Resize(ddim);
...@@ -408,13 +431,12 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, ...@@ -408,13 +431,12 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
ANDROIDLOGI("setThreadCount %d", threadCount); ANDROIDLOGI("setThreadCount %d", threadCount);
#ifdef ENABLE_EXCEPTION #ifdef ENABLE_EXCEPTION
try { try {
getPaddleMobileInstance()->SetThreadNum((int)threadCount); getPaddleMobileInstance()->SetThreadNum(static_cast<int>(threadCount));
} catch (paddle_mobile::PaddleMobileException &e) { } catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
} }
#else #else
getPaddleMobileInstance()->SetThreadNum((int)threadCount); getPaddleMobileInstance()->SetThreadNum(static_cast<int>(threadCount));
#endif #endif
} }
...@@ -425,13 +447,11 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env, ...@@ -425,13 +447,11 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
#ifdef ENABLE_EXCEPTION #ifdef ENABLE_EXCEPTION
try { try {
getPaddleMobileInstance()->Clear(); getPaddleMobileInstance()->Clear();
} catch (paddle_mobile::PaddleMobileException &e) { } catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what()); ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
} }
#else #else
getPaddleMobileInstance()->Clear(); getPaddleMobileInstance()->Clear();
#endif #endif
} }
......
...@@ -54,6 +54,10 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified( ...@@ -54,6 +54,10 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims); JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims);
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env,
jclass thiz,
jstring varName);
/** /**
* object detection for anroid * object detection for anroid
*/ */
......
...@@ -28,6 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) { ...@@ -28,6 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1]; param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == type_id<int8_t>().hash_code()) { if (param->Filter()->type() == type_id<int8_t>().hash_code()) {
#ifndef __aarch64__ #ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 && if (depth3x3 && param->Strides()[0] < 3 &&
......
...@@ -444,7 +444,7 @@ endif() ...@@ -444,7 +444,7 @@ endif()
# Generic flags. # Generic flags.
list(APPEND ANDROID_COMPILER_FLAGS list(APPEND ANDROID_COMPILER_FLAGS
-g # -g
-DANDROID -DANDROID
-ffunction-sections -ffunction-sections
-funwind-tables -funwind-tables
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册