From 71aa1b491aeabc43f85327bc8f50ccb5bd619154 Mon Sep 17 00:00:00 2001 From: yiicy Date: Mon, 9 Dec 2019 17:15:43 +0800 Subject: [PATCH] [JAVA API]java tensor api setData and getData support Int type, test=develop (#2583) --- lite/api/android/jni/native/tensor_jni.cc | 30 +++++++++++++++++++ lite/api/android/jni/native/tensor_jni.h | 22 ++++++++++++-- .../jni/src/com/baidu/paddle/lite/Tensor.java | 20 +++++++++++++ 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/lite/api/android/jni/native/tensor_jni.cc b/lite/api/android/jni/native/tensor_jni.cc index 59cafa1939..5212fe9a6e 100644 --- a/lite/api/android/jni/native/tensor_jni.cc +++ b/lite/api/android/jni/native/tensor_jni.cc @@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( return JNI_TRUE; } +JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( + JNIEnv *env, jobject jtensor, jintArray buf) { + std::unique_ptr *tensor = get_writable_tensor_pointer(env, jtensor); + if (tensor == nullptr || (*tensor == nullptr)) { + return JNI_FALSE; + } + int64_t buf_size = (int64_t)env->GetArrayLength(buf); + if (buf_size != product((*tensor)->shape())) { + return JNI_FALSE; + } + + int32_t *input = (*tensor)->mutable_data(); + env->GetIntArrayRegion(buf, 0, buf_size, input); + return JNI_TRUE; +} + JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) { if (is_const_tensor(env, jtensor)) { @@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) { } } +JNIEXPORT jintArray JNICALL +Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) { + if (is_const_tensor(env, jtensor)) { + std::unique_ptr *tensor = + get_read_only_tensor_pointer(env, jtensor); + return cpp_array_to_jintarray( + env, (*tensor)->data(), product((*tensor)->shape())); + } else { + std::unique_ptr *tensor = get_writable_tensor_pointer(env, jtensor); + return cpp_array_to_jintarray( + env, (*tensor)->data(), product((*tensor)->shape())); + } +} + JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor( JNIEnv *env, jobject jtensor, jlong java_pointer) { if (java_pointer == 0) { diff --git a/lite/api/android/jni/native/tensor_jni.h b/lite/api/android/jni/native/tensor_jni.h index 34c35b6a76..9b029dfb4c 100644 --- a/lite/api/android/jni/native/tensor_jni.h +++ b/lite/api/android/jni/native/tensor_jni.h @@ -16,8 +16,8 @@ #include /* Header for class com_baidu_paddle_lite_Tensor */ -#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ -#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#ifndef LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#define LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ #ifdef __cplusplus extern "C" { #endif @@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject); JNIEXPORT jbyteArray JNICALL Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject); +/* + * Class: com_baidu_paddle_lite_Tensor + * Method: getIntData + * Signature: ()[I + */ +JNIEXPORT jintArray JNICALL +Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject); + /* * Class: com_baidu_paddle_lite_Tensor * Method: nativeResize @@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( JNIEnv *, jobject, jbyteArray); +/* + * Class: com_baidu_paddle_lite_Tensor + * Method: nativeSetData + * Signature: ([I)Z + */ +JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( + JNIEnv *, jobject, jintArray); + /* * Class: com_baidu_paddle_lite_Tensor * Method: deleteCppTensor @@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong); #ifdef __cplusplus } #endif -#endif // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#endif // LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java index ac78800bd2..f76841dd41 100644 --- a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java +++ b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java @@ -108,6 +108,19 @@ public class Tensor { return nativeSetData(buf); } + /** + * Set the tensor int data. + * + * @param buf the int array buffer which will be copied into tensor. + * @return true if set data successfully. + */ + public boolean setData(int[] buf) { + if (readOnly) { + return false; + } + return nativeSetData(buf); + } + /** * @return shape of the tensor as long array. */ @@ -123,12 +136,19 @@ public class Tensor { */ public native byte[] getByteData(); + /** + * @return the tensor data as int array. + */ + public native int[] getIntData(); + private native boolean nativeResize(long[] dims); private native boolean nativeSetData(float[] buf); private native boolean nativeSetData(byte[] buf); + private native boolean nativeSetData(int[] buf); + /** * Delete C++ Tenor object pointed by the input pointer, which is presented by a * long value. -- GitLab