未验证 提交 71aa1b49 编写于 作者: Y yiicy 提交者: GitHub

[JAVA API]java tensor api setData and getData support Int type, test=develop (#2583)

上级 a2f981a4
......@@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
return JNI_TRUE;
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *env, jobject jtensor, jintArray buf) {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
if (tensor == nullptr || (*tensor == nullptr)) {
return JNI_FALSE;
}
int64_t buf_size = (int64_t)env->GetArrayLength(buf);
if (buf_size != product((*tensor)->shape())) {
return JNI_FALSE;
}
int32_t *input = (*tensor)->mutable_data<int32_t>();
env->GetIntArrayRegion(buf, 0, buf_size, input);
return JNI_TRUE;
}
JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
......@@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) {
}
}
JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
std::unique_ptr<const Tensor> *tensor =
get_read_only_tensor_pointer(env, jtensor);
return cpp_array_to_jintarray(
env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
} else {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
return cpp_array_to_jintarray(
env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
}
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
JNIEnv *env, jobject jtensor, jlong java_pointer) {
if (java_pointer == 0) {
......
......@@ -16,8 +16,8 @@
#include <jni.h>
/* Header for class com_baidu_paddle_lite_Tensor */
#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#ifndef LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#define LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#ifdef __cplusplus
extern "C" {
#endif
......@@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject);
JNIEXPORT jbyteArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: getIntData
* Signature: ()[I
*/
JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeResize
......@@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
JNIEnv *, jobject, jbyteArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeSetData
* Signature: ([I)Z
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *, jobject, jintArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: deleteCppTensor
......@@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong);
#ifdef __cplusplus
}
#endif
#endif // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
#endif // LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
......@@ -108,6 +108,19 @@ public class Tensor {
return nativeSetData(buf);
}
/**
* Set the tensor int data.
*
* @param buf the int array buffer which will be copied into tensor.
* @return true if set data successfully.
*/
public boolean setData(int[] buf) {
if (readOnly) {
return false;
}
return nativeSetData(buf);
}
/**
* @return shape of the tensor as long array.
*/
......@@ -123,12 +136,19 @@ public class Tensor {
*/
public native byte[] getByteData();
/**
* @return the tensor data as int array.
*/
public native int[] getIntData();
private native boolean nativeResize(long[] dims);
private native boolean nativeSetData(float[] buf);
private native boolean nativeSetData(byte[] buf);
private native boolean nativeSetData(int[] buf);
/**
* Delete C++ Tenor object pointed by the input pointer, which is presented by a
* long value.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册