未验证 提交 6562c696 编写于 作者: W WangLiu 提交者: GitHub

Merge pull request #562 from cocodark/develop

add jni interface to accept yuv420 as input
...@@ -58,6 +58,7 @@ public class MainActivity extends Activity { ...@@ -58,6 +58,7 @@ public class MainActivity extends Activity {
private Context mContext = null; private Context mContext = null;
private int inputSize = 224; private int inputSize = 224;
private int[] ddims = {1, 3, 224, 224};
enum TYPE { enum TYPE {
googlenet googlenet
...@@ -121,14 +122,14 @@ public class MainActivity extends Activity { ...@@ -121,14 +122,14 @@ public class MainActivity extends Activity {
String assetPath = "pml_demo"; String assetPath = "pml_demo";
String sdcardPath = Environment.getExternalStorageDirectory() String sdcardPath = Environment.getExternalStorageDirectory()
+ File.separator + assetPath + File.separator + type; + File.separator + assetPath + File.separator + type;
//PML.load(sdcardPath); PML.load(sdcardPath);
String modelPath = Environment.getExternalStorageDirectory() String modelPath = Environment.getExternalStorageDirectory()
+ File.separator + assetPath + + File.separator + assetPath +
File.separator + "googlenet_combine" + File.separator + "model"; File.separator + "googlenet_combine" + File.separator + "model";
String paramPath = Environment.getExternalStorageDirectory() String paramPath = Environment.getExternalStorageDirectory()
+ File.separator + assetPath + + File.separator + assetPath +
File.separator + "googlenet_combine" + File.separator + "params"; File.separator + "googlenet_combine" + File.separator + "params";
PML.loadCombined(modelPath, paramPath); // PML.loadCombined(modelPath, paramPath);
} }
}); });
...@@ -351,8 +352,8 @@ public class MainActivity extends Activity { ...@@ -351,8 +352,8 @@ public class MainActivity extends Activity {
@Override @Override
public void onBackPressed() { public void onBackPressed() {
super.onBackPressed(); super.onBackPressed();
Log.d("mdl", "mdl clear"); Log.d("pml", "pml clear");
// clear mdl // clear pml
PML.clear(); PML.clear();
} }
...@@ -402,7 +403,7 @@ public class MainActivity extends Activity { ...@@ -402,7 +403,7 @@ public class MainActivity extends Activity {
float[] result = null; float[] result = null;
try { try {
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
result = PML.predict(inputData); result = PML.predictImage(inputData, ddims);
long end = System.currentTimeMillis(); long end = System.currentTimeMillis();
time = end - start; time = end - start;
......
...@@ -2,14 +2,14 @@ package com.baidu.paddle; ...@@ -2,14 +2,14 @@ package com.baidu.paddle;
public class PML { public class PML {
/** /**
* Load * Load seperated parameters
* @param modelPath * @param modelDir
* @return * @return
*/ */
public static native boolean load(String modelPath); public static native boolean load(String modelDir);
/** /**
* Load * Load combined parameters
* @param modelPath * @param modelPath
* @param paramPath * @param paramPath
* @return * @return
...@@ -23,7 +23,20 @@ public class PML { ...@@ -23,7 +23,20 @@ public class PML {
* @param buf * @param buf
* @return * @return
*/ */
public static native float[] predict(float[] buf); public static native float[] predictImage(float[] buf, int[]ddims);
/**
*
* @param buf yuv420格式的字节数组
* @param imgWidth yuv数据的宽
* @param imgHeight yuv数据的高
* @param ddims 输入数据的形状
* @param meanValues 模型训练时各通道的均值
* @return
*/
public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[]meanValues);
public static native void clear(); public static native void clear();
......
...@@ -192,27 +192,51 @@ which to test : ...@@ -192,27 +192,51 @@ which to test :
##部署 ##部署
Android应用可通过JNI接口调用底层C/C++,paddle-mobile对外提供的JNI接口如下: Android应用可通过JNI接口调用底层C/C++,paddle-mobile对外提供的JNI接口如下:
##### 1 load接口 加载模型参数 ##### 1 load接口 加载模型参数
- 用于加载参数文件分散的模型
``` ```
/* /**
*@param modelPath 模型文件路径 * Load seperated parameters
*@return jboolean * @param modelDir
*/ * @return
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, */
jclass thiz, public static native boolean load(String modelDir);
jstring modelPath); ```
- 用于加载参数文件合并的模型文件
``` ```
/**
* Load combined parameters
* @param modelPath
* @param paramPath
* @return
*/
public static native boolean loadCombined(String modelPath,String paramPath);
```
##### 2 predict接口 执行预测 ##### 2 predict接口 执行预测
- 接受预处理过的RGB数组的predict接口
``` ```
/** /**
*@param buf 输入数据 *@param buf 输入数据
*@return 输出数据 *@return 输出数据
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predict( JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf); JNIEnv *env, jclass thiz, jfloatArray buf);
``` ```
- 接受原始yuv数据的predict接口
```
/**
*
* @param buf yuv420格式的字节数组
* @param imgWidth yuv数据的宽
* @param imgHeight yuv数据的高
* @param ddims 输入数据的形状
* @param meanValues 模型训练时各通道的均值
* @return
*/
public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[]meanValues);
```
##### 3 clear接口 销毁实例、清理内存操作 ##### 3 clear接口 销毁实例、清理内存操作
``` ```
......
...@@ -62,15 +62,24 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, ...@@ -62,15 +62,24 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) { JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
ANDROIDLOGI("load invoked"); ANDROIDLOGI("loadCombined invoked");
bool optimize = true; bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath), jstring2cppstring(env, paramPath),
optimize); optimize);
} }
JNIEXPORT jfloatArray JNICALL JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf) { JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) {
ANDROIDLOGI("predictImage invoked");
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
jfloatArray result = NULL; jfloatArray result = NULL;
int count = 0; int count = 0;
float *dataPointer = nullptr; float *dataPointer = nullptr;
...@@ -78,17 +87,102 @@ Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf) { ...@@ -78,17 +87,102 @@ Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf) {
dataPointer = env->GetFloatArrayElements(buf, NULL); dataPointer = env->GetFloatArrayElements(buf, NULL);
} }
framework::Tensor input; framework::Tensor input;
framework::DDim ddim = framework::make_ddim({1, 3, 224, 224});
input.Resize(ddim); input.Resize(ddim);
auto input_ptr = input.mutable_data<float>(); auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < framework::product(ddim); i++) { for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i]; input_ptr[i] = dataPointer[i];
} }
auto output = shared_paddle_mobile_instance->Predict(input); auto output = shared_paddle_mobile_instance->Predict(input);
count = output->numel(); count = output->numel();
result = env->NewFloatArray(count); result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>()); env->SetFloatArrayRegion(result, 0, count, output->data<float>());
ANDROIDLOGI("predict finished"); env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
ANDROIDLOGI("predictImage finished");
return result;
}
inline int yuv_to_rgb(int y, int u, int v, float *r, float *g, float *b) {
int r1 = (int)(y + 1.370705 * (v - 128));
int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128));
int b1 = (int)(y + 1.732446 * (u - 128));
r1 = (int)fminf(255, fmaxf(0, r1));
g1 = (int)fminf(255, fmaxf(0, g1));
b1 = (int)fminf(255, fmaxf(0, b1));
*r = r1;
*g = g1;
*b = b1;
return 0;
}
void convert_nv21_to_matrix(uint8_t *nv21, float *matrix, int width, int height,
int targetWidth, int targetHeight, float *means) {
const uint8_t *yData = nv21;
const uint8_t *vuData = nv21 + width * height;
const int yRowStride = width;
const int vuRowStride = width;
float scale_x = width * 1.0 / targetWidth;
float scale_y = height * 1.0 / targetHeight;
for (int j = 0; j < targetHeight; ++j) {
int y = j * scale_y;
const uint8_t *pY = yData + y * yRowStride;
const uint8_t *pVU = vuData + (y >> 1) * vuRowStride;
for (int i = 0; i < targetWidth; ++i) {
int x = i * scale_x;
const int offset = ((x >> 1) << 1);
float r = 0;
float g = 0;
float b = 0;
yuv_to_rgb(pY[x], pVU[offset + 1], pVU[offset], &r, &g, &b);
int r_index = j * targetWidth + i;
int g_index = r_index + targetWidth * targetHeight;
int b_index = g_index + targetWidth * targetHeight;
matrix[r_index] = r - means[0];
matrix[g_index] = g - means[1];
matrix[b_index] = b - means[2];
}
}
}
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
JNIEnv *env, jclass thiz, jbyteArray yuv_, jint imgwidth, jint imgHeight,
jintArray ddims, jfloatArray meanValues) {
ANDROIDLOGI("predictYuv invoked");
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
ANDROIDLOGE("ddims size not equal to 4");
}
jint *ddim_ptr = env->GetIntArrayElements(ddims, NULL);
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
float matrix[length];
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr;
if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL);
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
jfloatArray result = NULL;
int count = 0;
framework::Tensor input;
input.Resize(ddim);
auto input_ptr = input.mutable_data<float>();
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = shared_paddle_mobile_instance->Predict(input);
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
env->ReleaseByteArrayElements(yuv_, yuv, 0);
env->ReleaseIntArrayElements(ddims, ddim_ptr, 0);
env->ReleaseFloatArrayElements(meanValues, meansPointer, 0);
ANDROIDLOGI("predictYuv finished");
return result; return result;
} }
......
...@@ -33,6 +33,19 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, ...@@ -33,6 +33,19 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath); JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath);
/**
* object detection for anroid
*/
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims);
/**
* object detection for anroid
*/
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
JNIEnv *env, jclass thiz, jbyteArray yuv, jint imgwidth, jint imgHeight,
jintArray ddims, jfloatArray meanValues);
/** /**
* object detection for anroid * object detection for anroid
*/ */
......
...@@ -60,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> { ...@@ -60,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
T *output_data = output->mutable_data<T>(); T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for
for (int ph = 0; ph < output_height; ++ph) { for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height; int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height); int hend = std::min(hstart + ksize_height, input_height);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册