未验证 提交 abd8b1ea 编写于 作者: Z zhupengyang 提交者: GitHub

add java api: getLongData (#3750)

上级 1c415218
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <jni.h> #include <jni.h>
#include <string> #include <string> // NOLINT
#include <vector> #include <vector> // NOLINT
#include "lite/api/light_api.h" #include "lite/api/light_api.h"
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
...@@ -78,6 +78,14 @@ inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, ...@@ -78,6 +78,14 @@ inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env,
return result; return result;
} }
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env,
const int8_t *buf,
int64_t len) {
jbyteArray result = env->NewByteArray(len);
env->SetByteArrayRegion(result, 0, len, buf);
return result;
}
inline jintArray cpp_array_to_jintarray(JNIEnv *env, inline jintArray cpp_array_to_jintarray(JNIEnv *env,
const int *buf, const int *buf,
int64_t len) { int64_t len) {
...@@ -86,11 +94,11 @@ inline jintArray cpp_array_to_jintarray(JNIEnv *env, ...@@ -86,11 +94,11 @@ inline jintArray cpp_array_to_jintarray(JNIEnv *env,
return result; return result;
} }
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, inline jlongArray cpp_array_to_jlongarray(JNIEnv *env,
const int8_t *buf, const int64_t *buf,
int64_t len) { int64_t len) {
jbyteArray result = env->NewByteArray(len); jlongArray result = env->NewLongArray(len);
env->SetByteArrayRegion(result, 0, len, buf); env->SetLongArrayRegion(result, 0, len, buf);
return result; return result;
} }
......
...@@ -136,6 +136,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( ...@@ -136,6 +136,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
return JNI_TRUE; return JNI_TRUE;
} }
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3L(
JNIEnv *env, jobject jtensor, jlongArray buf) {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
if (tensor == nullptr || (*tensor == nullptr)) {
return JNI_FALSE;
}
int64_t buf_size = (int64_t)env->GetArrayLength(buf);
if (buf_size != product((*tensor)->shape())) {
return JNI_FALSE;
}
int64_t *input = (*tensor)->mutable_data<int64_t>();
env->GetLongArrayRegion(buf, 0, buf_size, input);
return JNI_TRUE;
}
JNIEXPORT jfloatArray JNICALL JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) { Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) { if (is_const_tensor(env, jtensor)) {
...@@ -178,6 +194,20 @@ Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) { ...@@ -178,6 +194,20 @@ Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) {
} }
} }
JNIEXPORT jlongArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getLongData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
std::unique_ptr<const Tensor> *tensor =
get_read_only_tensor_pointer(env, jtensor);
return cpp_array_to_jlongarray(
env, (*tensor)->data<int64_t>(), product((*tensor)->shape()));
} else {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
return cpp_array_to_jlongarray(
env, (*tensor)->data<int64_t>(), product((*tensor)->shape()));
}
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
JNIEnv *env, jobject jtensor, jlong java_pointer) { JNIEnv *env, jobject jtensor, jlong java_pointer) {
if (java_pointer == 0) { if (java_pointer == 0) {
......
...@@ -57,6 +57,14 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject); ...@@ -57,6 +57,14 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
JNIEXPORT jintArray JNICALL JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject); Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: getLongData
* Signature: ()[L
*/
JNIEXPORT jlongArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getLongData(JNIEnv *, jobject);
/* /*
* Class: com_baidu_paddle_lite_Tensor * Class: com_baidu_paddle_lite_Tensor
* Method: nativeResize * Method: nativeResize
...@@ -89,6 +97,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( ...@@ -89,6 +97,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *, jobject, jintArray); JNIEnv *, jobject, jintArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeSetData
* Signature: ([L)Z
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3L(
JNIEnv *, jobject, jlongArray);
/* /*
* Class: com_baidu_paddle_lite_Tensor * Class: com_baidu_paddle_lite_Tensor
* Method: deleteCppTensor * Method: deleteCppTensor
......
...@@ -141,6 +141,11 @@ public class Tensor { ...@@ -141,6 +141,11 @@ public class Tensor {
*/ */
public native int[] getIntData(); public native int[] getIntData();
/**
* @return the tensor data as long array.
*/
public native long[] getLongData();
private native boolean nativeResize(long[] dims); private native boolean nativeResize(long[] dims);
private native boolean nativeSetData(float[] buf); private native boolean nativeSetData(float[] buf);
...@@ -149,6 +154,8 @@ public class Tensor { ...@@ -149,6 +154,8 @@ public class Tensor {
private native boolean nativeSetData(int[] buf); private native boolean nativeSetData(int[] buf);
private native boolean nativeSetData(long[] buf);
/** /**
* Delete C++ Tenor object pointed by the input pointer, which is presented by a * Delete C++ Tenor object pointed by the input pointer, which is presented by a
* long value. * long value.
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
#include <fstream> #include <fstream>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册