未验证 提交 71aa1b49 编写于 作者: Y yiicy 提交者: GitHub

[JAVA API]java tensor api setData and getData support Int type, test=develop (#2583)

上级 a2f981a4
...@@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( ...@@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
return JNI_TRUE; return JNI_TRUE;
} }
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *env, jobject jtensor, jintArray buf) {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
if (tensor == nullptr || (*tensor == nullptr)) {
return JNI_FALSE;
}
int64_t buf_size = (int64_t)env->GetArrayLength(buf);
if (buf_size != product((*tensor)->shape())) {
return JNI_FALSE;
}
int32_t *input = (*tensor)->mutable_data<int32_t>();
env->GetIntArrayRegion(buf, 0, buf_size, input);
return JNI_TRUE;
}
JNIEXPORT jfloatArray JNICALL JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) { Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) { if (is_const_tensor(env, jtensor)) {
...@@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) { ...@@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) {
} }
} }
JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
std::unique_ptr<const Tensor> *tensor =
get_read_only_tensor_pointer(env, jtensor);
return cpp_array_to_jintarray(
env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
} else {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
return cpp_array_to_jintarray(
env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
}
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
JNIEnv *env, jobject jtensor, jlong java_pointer) { JNIEnv *env, jobject jtensor, jlong java_pointer) {
if (java_pointer == 0) { if (java_pointer == 0) {
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <jni.h> #include <jni.h>
/* Header for class com_baidu_paddle_lite_Tensor */ /* Header for class com_baidu_paddle_lite_Tensor */
#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ #ifndef LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ #define LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject); ...@@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject);
JNIEXPORT jbyteArray JNICALL JNIEXPORT jbyteArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject); Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: getIntData
* Signature: ()[I
*/
JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject);
/* /*
* Class: com_baidu_paddle_lite_Tensor * Class: com_baidu_paddle_lite_Tensor
* Method: nativeResize * Method: nativeResize
...@@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F( ...@@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
JNIEnv *, jobject, jbyteArray); JNIEnv *, jobject, jbyteArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeSetData
* Signature: ([I)Z
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *, jobject, jintArray);
/* /*
* Class: com_baidu_paddle_lite_Tensor * Class: com_baidu_paddle_lite_Tensor
* Method: deleteCppTensor * Method: deleteCppTensor
...@@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong); ...@@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ #endif // LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
...@@ -108,6 +108,19 @@ public class Tensor { ...@@ -108,6 +108,19 @@ public class Tensor {
return nativeSetData(buf); return nativeSetData(buf);
} }
/**
* Set the tensor int data.
*
* @param buf the int array buffer which will be copied into tensor.
* @return true if set data successfully.
*/
public boolean setData(int[] buf) {
if (readOnly) {
return false;
}
return nativeSetData(buf);
}
/** /**
* @return shape of the tensor as long array. * @return shape of the tensor as long array.
*/ */
...@@ -123,12 +136,19 @@ public class Tensor { ...@@ -123,12 +136,19 @@ public class Tensor {
*/ */
public native byte[] getByteData(); public native byte[] getByteData();
/**
* @return the tensor data as int array.
*/
public native int[] getIntData();
private native boolean nativeResize(long[] dims); private native boolean nativeResize(long[] dims);
private native boolean nativeSetData(float[] buf); private native boolean nativeSetData(float[] buf);
private native boolean nativeSetData(byte[] buf); private native boolean nativeSetData(byte[] buf);
private native boolean nativeSetData(int[] buf);
/** /**
* Delete C++ Tenor object pointed by the input pointer, which is presented by a * Delete C++ Tenor object pointed by the input pointer, which is presented by a
* long value. * long value.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册