未验证 提交 abd8b1ea 编写于 作者: Z zhupengyang 提交者: GitHub

add java api: getLongData (#3750)

上级 1c415218
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <jni.h>
#include <string>
#include <vector>
#include <string> // NOLINT
#include <vector> // NOLINT
#include "lite/api/light_api.h"
#include "lite/api/paddle_api.h"
......@@ -78,6 +78,14 @@ inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env,
return result;
}
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env,
const int8_t *buf,
int64_t len) {
jbyteArray result = env->NewByteArray(len);
env->SetByteArrayRegion(result, 0, len, buf);
return result;
}
inline jintArray cpp_array_to_jintarray(JNIEnv *env,
const int *buf,
int64_t len) {
......@@ -86,11 +94,11 @@ inline jintArray cpp_array_to_jintarray(JNIEnv *env,
return result;
}
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env,
const int8_t *buf,
inline jlongArray cpp_array_to_jlongarray(JNIEnv *env,
const int64_t *buf,
int64_t len) {
jbyteArray result = env->NewByteArray(len);
env->SetByteArrayRegion(result, 0, len, buf);
jlongArray result = env->NewLongArray(len);
env->SetLongArrayRegion(result, 0, len, buf);
return result;
}
......
......@@ -136,6 +136,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
return JNI_TRUE;
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3L(
JNIEnv *env, jobject jtensor, jlongArray buf) {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
if (tensor == nullptr || (*tensor == nullptr)) {
return JNI_FALSE;
}
int64_t buf_size = (int64_t)env->GetArrayLength(buf);
if (buf_size != product((*tensor)->shape())) {
return JNI_FALSE;
}
int64_t *input = (*tensor)->mutable_data<int64_t>();
env->GetLongArrayRegion(buf, 0, buf_size, input);
return JNI_TRUE;
}
JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
......@@ -178,6 +194,20 @@ Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) {
}
}
JNIEXPORT jlongArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getLongData(JNIEnv *env, jobject jtensor) {
if (is_const_tensor(env, jtensor)) {
std::unique_ptr<const Tensor> *tensor =
get_read_only_tensor_pointer(env, jtensor);
return cpp_array_to_jlongarray(
env, (*tensor)->data<int64_t>(), product((*tensor)->shape()));
} else {
std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
return cpp_array_to_jlongarray(
env, (*tensor)->data<int64_t>(), product((*tensor)->shape()));
}
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
JNIEnv *env, jobject jtensor, jlong java_pointer) {
if (java_pointer == 0) {
......
......@@ -57,6 +57,14 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
JNIEXPORT jintArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: getLongData
* Signature: ()[L
*/
JNIEXPORT jlongArray JNICALL
Java_com_baidu_paddle_lite_Tensor_getLongData(JNIEnv *, jobject);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeResize
......@@ -89,6 +97,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
JNIEnv *, jobject, jintArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: nativeSetData
* Signature: ([L)Z
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3L(
JNIEnv *, jobject, jlongArray);
/*
* Class: com_baidu_paddle_lite_Tensor
* Method: deleteCppTensor
......
......@@ -141,6 +141,11 @@ public class Tensor {
*/
public native int[] getIntData();
/**
* @return the tensor data as long array.
*/
public native long[] getLongData();
private native boolean nativeResize(long[] dims);
private native boolean nativeSetData(float[] buf);
......@@ -149,6 +154,8 @@ public class Tensor {
private native boolean nativeSetData(int[] buf);
private native boolean nativeSetData(long[] buf);
/**
* Delete C++ Tenor object pointed by the input pointer, which is presented by a
* long value.
......
......@@ -18,6 +18,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "lite/utils/cp_logging.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册