提交 5fc18b98 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5244 Get different types of data in tensor

Merge pull request !5244 from yeyunpeng2020/master_tensor
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
*
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -18,6 +18,8 @@ package com.mindspore.lite;
import android.util.Log;
import java.nio.ByteBuffer;
public class MSTensor {
private long tensorPtr;
......@@ -29,7 +31,7 @@ public class MSTensor {
this.tensorPtr = tensorPtr;
}
public boolean init (int dataType, int[] shape) {
public boolean init(int dataType, int[] shape) {
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
return this.tensorPtr != 0;
}
......@@ -50,18 +52,30 @@ public class MSTensor {
this.setDataType(this.tensorPtr, dataType);
}
public byte[] getData() {
return this.getData(this.tensorPtr);
public byte[] getBtyeData() {
return this.getByteData(this.tensorPtr);
}
public float[] getFloatData() {
return decodeBytes(this.getData(this.tensorPtr));
return this.getFloatData(this.tensorPtr);
}
public int[] getIntData() {
return this.getIntData(this.tensorPtr);
}
public long[] getLongData() {
return this.getLongData(this.tensorPtr);
}
public void setData(byte[] data) {
this.setData(this.tensorPtr, data, data.length);
}
public void setData(ByteBuffer data) {
this.setByteBufferData(this.tensorPtr, data);
}
public long size() {
return this.size(this.tensorPtr);
}
......@@ -82,13 +96,13 @@ public class MSTensor {
}
int size = bytes.length / 4;
float[] ret = new float[size];
for (int i = 0; i < size; i=i+4) {
for (int i = 0; i < size; i = i + 4) {
int accNum = 0;
accNum = accNum | (bytes[i] & 0xff) << 0;
accNum = accNum | (bytes[i+1] & 0xff) << 8;
accNum = accNum | (bytes[i+2] & 0xff) << 16;
accNum = accNum | (bytes[i+3] & 0xff) << 24;
ret[i/4] = Float.intBitsToFloat(accNum);
accNum = accNum | (bytes[i + 1] & 0xff) << 8;
accNum = accNum | (bytes[i + 2] & 0xff) << 16;
accNum = accNum | (bytes[i + 3] & 0xff) << 24;
ret[i / 4] = Float.intBitsToFloat(accNum);
}
return ret;
}
......@@ -103,10 +117,18 @@ public class MSTensor {
private native boolean setDataType(long tensorPtr, int dataType);
private native byte[] getData(long tensorPtr);
private native byte[] getByteData(long tensorPtr);
private native long[] getLongData(long tensorPtr);
private native int[] getIntData(long tensorPtr);
private native float[] getFloatData(long tensorPtr);
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
private native long size(long tensorPtr);
private native int elementsNum(long tensorPtr);
......
......@@ -99,8 +99,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy
return ret == data_type;
}
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -113,16 +113,95 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData
return env->NewByteArray(0);
}
auto *float_local_data = reinterpret_cast<float *>(ms_tensor_ptr->MutableData());
for (size_t i = 0; i < ms_tensor_ptr->ElementsNum() && i < 5; i++) {
MS_LOGE("data[%zu] = %f", i, float_local_data[i]);
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeUInt8) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewByteArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
return ret;
}
extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLongData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewLongArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jlong *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewLongArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt64) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewLongArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewLongArray(local_data_size);
env->SetLongArrayRegion(ret, 0, local_data_size, local_data);
return ret;
}
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewIntArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jint *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewIntArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt32) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewIntArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewIntArray(local_data_size);
env->SetIntArrayRegion(ret, 0, local_data_size, local_data);
return ret;
}
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFloatData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewFloatArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jfloat *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewFloatArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeFloat32) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewFloatArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewFloatArray(local_data_size);
env->SetFloatArrayRegion(ret, 0, local_data_size, local_data);
return ret;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
......@@ -143,6 +222,36 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(J
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setByteBufferData(JNIEnv *env, jobject thiz,
jlong tensor_ptr,
jobject buffer) {
jbyte *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer)); // get buffer poiter
jlong data_len = env->GetDirectBufferCapacity(buffer); // get buffer capacity
if (!p_data) {
MS_LOGE("GetDirectBufferAddress return null");
return NULL;
}
jbyteArray data = env->NewByteArray(data_len); // create byte[]
env->SetByteArrayRegion(data, 0, data_len, p_data); // copy data to byte[]
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
if (data_len != ms_tensor_ptr->Size()) {
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->Size());
return static_cast<jboolean>(false);
}
jboolean is_copy = false;
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData();
memcpy(local_data, data_arr, data_len);
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册