提交 2809ee63 编写于 作者: Y yeyunpeng

add api : get output by tensor name in c++ and java

上级 9ab4f123
......@@ -38,7 +38,7 @@ enum CpuBindMode {
typedef enum {
DT_CPU, /**< CPU device type */
DT_GPU, /**< GPU device type */
DT_NPU /**< NPU device type */
DT_NPU /**< NPU device type, not supported yet */
} DeviceType;
/// \brief DeviceContext defined for holding DeviceType.
......
......@@ -86,17 +86,34 @@ class MS_API LiteSession {
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
/// \brief Get output MindSpore Lite MSTensors of model.
/// \brief Get output MindSpore Lite MSTensors of model mapped by node name.
///
/// \return The map of output node name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputMapByNode() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const = 0;
virtual std::vector<tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const = 0;
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputMapByTensor() const = 0;
/// \brief Get name of output tensors of model compiled by this session.
///
/// \return The vector of string as output tensor names in order.
virtual std::vector<std::string> GetOutputTensorNames() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] tensor_name Define tensor name.
///
/// \return Pointer of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const = 0;
/// \brief Resize inputs shape.
///
......
......@@ -18,7 +18,7 @@ cd ${TOP_PATH}/output/
rm -rf mindspore-lite-0.7.0
tar -zxvf mindspore-lite-0.7.0-runtime-arm64-cpu.tar.gz
mkdir -p ${BASE_PATH}/lib/
cp ${TOP_PATH}/output/mindspore-lite-0.7.0/lib/libmindspore-lite.so ${BASE_PATH}/lib/
cp ${TOP_PATH}/output/mindspore-lite-0.7.0-runtime-arm64-cpu/lib/libmindspore-lite.so ${BASE_PATH}/lib/
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/
# build jni so
......
......@@ -76,8 +76,8 @@ public class LiteSession {
return tensors;
}
public Map<String, List<MSTensor>> getOutputs() {
Map<String, List<Long>> ret = this.getOutputs(this.sessionPtr);
public Map<String, List<MSTensor>> getOutputMapByNode() {
Map<String, List<Long>> ret = this.getOutputMapByNode(this.sessionPtr);
Map<String, List<MSTensor>> tensorMap = new HashMap<>();
Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet();
for (Map.Entry<String, List<Long>> entry : entrySet) {
......@@ -93,8 +93,8 @@ public class LiteSession {
return tensorMap;
}
public List<MSTensor> getOutputsByName(String nodeName) {
List<Long> ret = this.getOutputsByName(this.sessionPtr, nodeName);
public List<MSTensor> getOutputsByNodeName(String nodeName) {
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
......@@ -103,6 +103,27 @@ public class LiteSession {
return tensors;
}
public Map<String, MSTensor> getOutputMapByTensor() {
Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
Map<String, MSTensor> tensorMap = new HashMap<>();
Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
for (Map.Entry<String, Long> entry : entrySet) {
String name = entry.getKey();
Long msTensorAddr = entry.getValue();
tensorMap.put(name, new MSTensor(msTensorAddr));
}
return tensorMap;
}
public List<String> getOutputTensorNames() {
return getOutputTensorNames(this.sessionPtr);
}
public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
return new MSTensor(tensor_addr);
}
public void free() {
this.free(this.sessionPtr);
this.sessionPtr = 0;
......@@ -120,9 +141,15 @@ public class LiteSession {
private native List<Long> getInputsByName(long sessionPtr, String nodeName);
private native Map<String, List<Long>> getOutputs(long sessionPtr);
private native Map<String, List<Long>> getOutputMapByNode(long sessionPtr);
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
private native Map<String, Long> getOutputMapByTensor(long sessionPtr);
private native List<String> getOutputTensorNames(long sessionPtr);
private native List<Long> getOutputsByName(long sessionPtr, String nodeName);
private native Long getOutputByTensorName(long sessionPtr, String tensorName);
private native void free(long sessionPtr);
}
......@@ -80,6 +80,11 @@ public class Model {
return ret;
}
public boolean loadModel(String modelPath) {
this.modelPtr = loadModelByPath(modelPath);
return this.modelPtr != 0;
}
public void free() {
this.free(this.modelPtr);
this.modelPtr = 0;
......@@ -87,5 +92,7 @@ public class Model {
private native long loadModel(MappedByteBuffer buffer);
private native long loadModelByPath(String modelPath);
private native void free(long modelPtr);
}
......@@ -14,12 +14,11 @@
* limitations under the License.
*/
#include "common/jni_utils.h"
#include <cstring>
char *JstringToChar(JNIEnv *env, jstring jstr) {
char *rtn = NULL;
char *rtn = nullptr;
jclass clsstring = env->FindClass("java/lang/String");
jstring strencode = env->NewStringUTF("GB2312");
jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
......
......@@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "common/jni_utils.h"
......@@ -22,7 +21,7 @@
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSession(JNIEnv *env, jobject thiz,
jlong context_ptr) {
jlong context_ptr) {
auto *pointer = reinterpret_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
......@@ -38,8 +37,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compileGraph(JNIEnv *env, jobject thiz,
jlong session_ptr,
jlong model_ptr) {
jlong session_ptr,
jlong model_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
......@@ -58,7 +57,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compil
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread(JNIEnv *env, jobject thiz,
jlong session_ptr, jboolean if_bind) {
jlong session_ptr, jboolean if_bind) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
......@@ -69,7 +68,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGraph(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
......@@ -81,7 +80,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGra
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
......@@ -104,8 +103,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
......@@ -127,8 +126,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByNode(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
......@@ -140,7 +139,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputs();
auto outputs = lite_session_ptr->GetOutputMapByNode();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
jclass array_list = env->FindClass("java/util/ArrayList");
......@@ -159,9 +158,9 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return hash_map;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
......@@ -175,7 +174,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByName(JstringToChar(env, node_name));
auto inputs = lite_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
......@@ -183,8 +182,66 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByTensor(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
jmethodID hash_map_put =
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputMapByTensor();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
for (auto output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensor = output_iter.second;
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
}
return hash_map;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputTensorNames(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output_names = lite_session_ptr->GetOutputTensorNames();
for (auto output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
}
return ret;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutputByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output = lite_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name));
return jlong(output);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
......
......@@ -14,9 +14,10 @@
* limitations under the License.
*/
#include <jni.h>
#include <fstream>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/model.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) {
......@@ -38,6 +39,46 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEn
return reinterpret_cast<jlong>(model);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz,
jstring model_path) {
auto model_path_char = JstringToChar(env, model_path);
if (nullptr == model_path_char) {
MS_LOGE("model_path_char is nullptr");
return reinterpret_cast<jlong>(nullptr);
}
std::ifstream ifs(model_path_char);
if (!ifs.good()) {
MS_LOGE("file: %s is not exist", model_path_char);
return reinterpret_cast<jlong>(nullptr);
}
if (!ifs.is_open()) {
MS_LOGE("file: %s open failed", model_path_char);
return reinterpret_cast<jlong>(nullptr);
}
ifs.seekg(0, std::ios::end);
auto size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[size]);
if (buf == nullptr) {
MS_LOGE("malloc buf failed, file: %s", model_path_char);
ifs.close();
return reinterpret_cast<jlong>(nullptr);
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), size);
ifs.close();
delete[](model_path_char);
MS_LOGD("Start Loading model");
auto model = mindspore::lite::Model::Import(buf.get(), size);
if (model == nullptr) {
MS_LOGE("Import model failed");
return reinterpret_cast<jlong>(nullptr);
}
return reinterpret_cast<jlong>(model);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
......
......@@ -14,15 +14,14 @@
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/ms_tensor.h"
#include "ir/dtype/type_id.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTensor(JNIEnv *env, jobject thiz,
jint data_type, jintArray shape,
jint shape_len) {
jint data_type, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
std::vector<int> local_shape(shape_len);
......@@ -39,7 +38,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTens
}
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -59,8 +58,8 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
......@@ -78,7 +77,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -89,7 +88,7 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(J
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jint data_type) {
jlong tensor_ptr, jint data_type) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -101,7 +100,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy
}
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -113,6 +112,11 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData
MS_LOGD("Tensor has no data");
return env->NewByteArray(0);
}
auto *float_local_data = reinterpret_cast<float *>(ms_tensor_ptr->MutableData());
for (size_t i = 0; i < ms_tensor_ptr->ElementsNum() && i < 5; i++) {
MS_LOGE("data[%zu] = %f", i, float_local_data[i]);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
......@@ -120,8 +124,8 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......@@ -150,7 +154,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_elementsNum(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
......
......@@ -49,8 +49,6 @@ if (BUILD_MINDDATA)
target_link_libraries(mindspore-lite minddata-eager minddata-lite)
endif ()
add_subdirectory(ops)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32))
add_custom_command(TARGET mindspore-lite POST_BUILD
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
......
......@@ -27,18 +27,15 @@ namespace lite {
std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph) {
MS_ASSERT(nullptr != meta_graph);
std::vector<size_t> ret;
for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) {
auto input_index = meta_graph->inputIndex()->GetAs<uint32_t>(i);
for (auto graph_in_index : *(meta_graph->inputIndex())) {
for (size_t j = 0; j < meta_graph->nodes()->size(); j++) {
auto *cNode = meta_graph->nodes()->GetAs<schema::CNode>(j);
MS_ASSERT(nullptr != cNode);
MS_ASSERT(nullptr != cNode->inputIndex());
for (size_t k = 0; k < cNode->inputIndex()->size(); k++) {
if (cNode->inputIndex()->GetAs<uint32_t>(k) == input_index) {
if (!IsContain<size_t>(ret, j)) {
ret.emplace_back(j);
}
break;
if (std::any_of(cNode->inputIndex()->begin(), cNode->inputIndex()->end(),
[&](const uint32_t &node_in_index) { return node_in_index == graph_in_index; })) {
if (!IsContain<size_t>(ret, j)) {
ret.emplace_back(j);
}
}
}
......@@ -49,33 +46,20 @@ std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph) {
std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph) {
MS_ASSERT(nullptr != meta_graph);
std::vector<size_t> ret;
for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) {
auto output_index = meta_graph->outputIndex()->GetAs<uint32_t>(i);
for (auto graph_out_index : *(meta_graph->outputIndex())) {
for (size_t j = 0; j < meta_graph->nodes()->size(); j++) {
auto *cNode = meta_graph->nodes()->GetAs<schema::CNode>(j);
MS_ASSERT(nullptr != cNode);
for (size_t k = 0; k < cNode->outputIndex()->size(); k++) {
if (cNode->outputIndex()->GetAs<uint32_t>(k) == output_index) {
if (!IsContain<size_t>(ret, j)) {
ret.emplace_back(j);
}
break;
MS_ASSERT(nullptr != cNode->outputIndex());
if (std::any_of(cNode->outputIndex()->begin(), cNode->outputIndex()->end(),
[&](const uint32_t &node_out_index) { return node_out_index == graph_out_index; })) {
if (!IsContain<size_t>(ret, j)) {
ret.emplace_back(j);
}
}
}
}
return ret;
}
// NODE_ID OpNode::ID() { return id; }
//
// void OpNode::AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); }
//
// void OpNode::AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); }
//
// std::unordered_set<NODE_ID> OpNode::GetAllInEdges() { return inEdges; }
//
// std::unordered_set<NODE_ID> OpNode::GetAllOutEdges() { return outEdges; }
} // namespace lite
} // namespace mindspore
......@@ -16,6 +16,7 @@
#include "src/lite_session.h"
#include <vector>
#include <utility>
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/scheduler.h"
......@@ -81,6 +82,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
}
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->inputs_.empty());
MS_ASSERT(meta_graph != nullptr);
......@@ -93,7 +95,7 @@ void LiteSession::InitGraphInputTensors(const lite::Model *model) {
}
}
void LiteSession::InitGraphInputMSTensors(const lite::Model *model) {
void LiteSession::InitGraphInputMSTensors() {
MS_ASSERT(this->input_vec_.empty());
for (auto &input_tensor : this->inputs_) {
MS_ASSERT(input_tensor != nullptr);
......@@ -102,6 +104,7 @@ void LiteSession::InitGraphInputMSTensors(const lite::Model *model) {
}
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->outputs_.empty());
MS_ASSERT(meta_graph != nullptr);
......@@ -115,6 +118,7 @@ void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
}
void LiteSession::InitGraphInputMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->input_map_.empty());
MS_ASSERT(meta_graph != nullptr);
......@@ -145,9 +149,10 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
}
}
void LiteSession::InitGraphOutputMap(const lite::Model *model) {
void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->output_map_.empty());
MS_ASSERT(this->output_node_map_.empty());
MS_ASSERT(meta_graph != nullptr);
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
for (auto out_node_index : graph_output_node_indexes) {
......@@ -171,17 +176,44 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) {
MS_ASSERT(out_tensor != nullptr);
auto *ms_tensor = new tensor::LiteTensor(out_tensor);
MS_ASSERT(nullptr != ms_tensor);
this->output_map_[out_node->name()->str()].emplace_back(ms_tensor);
this->output_node_map_[out_node->name()->str()].emplace_back(ms_tensor);
}
}
}
void LiteSession::InitGraphOutputTensorNames(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->output_tensor_names_.empty());
MS_ASSERT(meta_graph != nullptr);
for (auto output_index : *meta_graph->outputIndex()) {
this->output_tensor_names_.emplace_back(std::to_string(output_index));
}
}
void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->output_tensor_map_.empty());
MS_ASSERT(meta_graph != nullptr);
for (auto graph_out_index : *(meta_graph->outputIndex())) {
MS_ASSERT(graph_out_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(graph_out_index);
MS_ASSERT(out_tensor != nullptr);
auto *ms_tensor = new tensor::LiteTensor(out_tensor);
MS_ASSERT(nullptr != ms_tensor);
this->output_tensor_map_.insert(std::make_pair(std::to_string(graph_out_index), ms_tensor));
}
}
void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphInputTensors(model);
InitGraphInputMSTensors(model);
InitGraphInputMSTensors();
InitGraphOutputTensors(model);
InitGraphInputMap(model);
InitGraphOutputMap(model);
InitGraphOutputNodeMap(model);
InitGraphOutputTensorNames(model);
InitGraphOutputTensorMap(model);
}
int LiteSession::CompileGraph(Model *model) {
......@@ -223,10 +255,6 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
}
}
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputs() const {
return this->output_map_;
}
int LiteSession::Init(Context *context) {
MS_EXCEPTION_IF_NULL(context);
this->context_ = new (std::nothrow) Context(context->thread_num_, context->allocator, context->device_ctx_);
......@@ -276,14 +304,19 @@ LiteSession::~LiteSession() {
iter.second.clear();
}
input_map_.clear();
for (auto iter : this->output_map_) {
for (auto iter : this->output_node_map_) {
for (auto *ms_tensor : iter.second) {
((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr);
delete ms_tensor;
}
iter.second.clear();
}
output_map_.clear();
output_node_map_.clear();
for (auto iter : this->output_tensor_map_) {
((tensor::LiteTensor *)(iter.second))->SetTensorImpl(nullptr);
delete (iter.second);
}
output_tensor_map_.clear();
for (auto *kernel : kernels_) {
delete kernel;
}
......@@ -309,16 +342,35 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(const st
return ret->second;
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputsByName(const std::string &name) const {
auto ret = output_map_.find(name);
if (ret == output_map_.end()) {
MS_LOG(WARNING) << "Node " << name << " is not an output node";
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputMapByNode() const {
return this->output_node_map_;
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputsByNodeName(const std::string &node_name) const {
auto ret = output_node_map_.find(node_name);
if (ret == output_node_map_.end()) {
MS_LOG(WARNING) << "Node " << node_name << " is not an output node";
std::vector<mindspore::tensor::MSTensor *> empty_ret;
return empty_ret;
}
return ret->second;
}
std::vector<std::string> LiteSession::GetOutputTensorNames() const { return this->output_tensor_names_; }
mindspore::tensor::MSTensor *LiteSession::GetOutputByTensorName(const std::string &tensor_name) const {
auto ret = output_tensor_map_.find(tensor_name);
if (ret == output_tensor_map_.end()) {
MS_LOG(WARNING) << "Tensor " << tensor_name << " is not an output node";
return nullptr;
}
return ret->second;
}
std::unordered_map<std::string, mindspore::tensor::MSTensor *> LiteSession::GetOutputMapByTensor() const {
return this->output_tensor_map_;
}
int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs) {
if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "Inputs size " << inputs.size() << " is not equal to " << inputs_.size();
......
......@@ -50,9 +50,15 @@ class LiteSession : public session::LiteSession {
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const override;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputMapByNode() const override;
std::vector<mindspore::tensor::MSTensor *> GetOutputsByName(const std::string &name) const override;
std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override;
std::vector<std::string> GetOutputTensorNames() const override;
mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputMapByTensor() const override;
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs) override;
......@@ -63,13 +69,17 @@ class LiteSession : public session::LiteSession {
void InitGraphInputTensors(const lite::Model *model);
void InitGraphInputMSTensors(const lite::Model *model);
void InitGraphInputMSTensors();
void InitGraphOutputTensors(const lite::Model *model);
void InitGraphInputMap(const lite::Model *model);
void InitGraphOutputMap(const lite::Model *model);
void InitGraphOutputNodeMap(const lite::Model *model);
void InitGraphOutputTensorNames(const lite::Model *model);
void InitGraphOutputTensorMap(const lite::Model *model);
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs);
......@@ -86,7 +96,11 @@ class LiteSession : public session::LiteSession {
// graph input node name -- input tensors
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_;
// graph output node name -- output tensors
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> output_map_;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> output_node_map_;
std::vector<std::string> output_tensor_names_;
// graph output tensor name -- output tensor
std::unordered_map<std::string, mindspore::tensor::MSTensor *> output_tensor_map_;
Executor *executor = nullptr;
};
} // namespace lite
......
......@@ -16,98 +16,6 @@
#include "src/ops/primitive_c.h"
#include "include/model.h"
#include "src/ops/unique.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/conv2d.h"
#include "src/ops/roi_pooling.h"
#include "src/ops/topk.h"
#include "src/ops/broadcast_to.h"
#include "src/ops/unsqueeze.h"
#include "src/ops/unstack.h"
#include "src/ops/depth_to_space.h"
#include "src/ops/batch_to_space.h"
#include "src/ops/prior_box.h"
#include "src/ops/lstm.h"
#include "src/ops/softmax.h"
#include "src/ops/activation.h"
#include "src/ops/deconv2d.h"
#include "src/ops/reduce.h"
#include "src/ops/pooling.h"
#include "src/ops/fused_batchnorm.h"
#include "src/ops/batch_norm.h"
#include "src/ops/power.h"
#include "src/ops/range.h"
#include "src/ops/add.h"
#include "src/ops/sub.h"
#include "src/ops/div.h"
#include "src/ops/bias_add.h"
#include "src/ops/expand_dims.h"
#include "src/ops/full_connection.h"
#include "src/ops/shape.h"
#include "src/ops/elu.h"
#include "src/ops/embedding_lookup.h"
#include "src/ops/quant_dtype_cast.h"
#include "src/ops/matmul.h"
#include "src/ops/resize.h"
#include "src/ops/tile.h"
#include "src/ops/one_hot.h"
#include "src/ops/space_to_depth.h"
#include "src/ops/split.h"
#include "src/ops/argmax.h"
#include "src/ops/argmin.h"
#include "src/ops/cast.h"
#include "src/ops/reshape.h"
#include "src/ops/scale.h"
#include "src/ops/concat.h"
#include "src/ops/nchw2nhwc.h"
#include "src/ops/slice.h"
#include "src/ops/squeeze.h"
#include "src/ops/flatten.h"
#include "src/ops/mean.h"
#include "src/ops/nhwc2nchw.h"
#include "src/ops/stack.h"
#include "src/ops/crop.h"
#include "src/ops/addn.h"
#include "src/ops/gather.h"
#include "src/ops/gather_nd.h"
#include "src/ops/local_response_normalization.h"
#include "src/ops/pad.h"
#include "src/ops/prelu.h"
#include "src/ops/caffe_p_relu.h"
#include "src/ops/reverse_sequence.h"
#include "src/ops/dedepthwise_conv2d.h"
#include "src/ops/depthwise_conv2d.h"
#include "src/ops/mul.h"
#include "src/ops/eltwise.h"
#include "src/ops/fill.h"
#include "src/ops/transpose.h"
#include "src/ops/log.h"
#include "src/ops/abs.h"
#include "src/ops/sin.h"
#include "src/ops/cos.h"
#include "src/ops/sqrt.h"
#include "src/ops/square.h"
#include "src/ops/exp.h"
#include "src/ops/rsqrt.h"
#include "src/ops/maximum.h"
#include "src/ops/minimum.h"
#include "src/ops/strided_slice.h"
#include "src/ops/reverse.h"
#include "src/ops/logical_and.h"
#include "src/ops/logical_or.h"
#include "src/ops/logical_not.h"
#include "src/ops/floor_div.h"
#include "src/ops/floor_mod.h"
#include "src/ops/equal.h"
#include "src/ops/not_equal.h"
#include "src/ops/less.h"
#include "src/ops/less_equal.h"
#include "src/ops/greater_equal.h"
#include "src/ops/greater.h"
#include "src/ops/floor.h"
#include "src/ops/squared_difference.h"
#include "src/ops/ceil.h"
#include "src/ops/round.h"
#include "utils/log_adapter.h"
namespace mindspore::lite {
......
......@@ -109,6 +109,7 @@
#include "src/ops/round.h"
#include "src/ops/unique.h"
#include "src/ops/zeros_like.h"
#include "src/ops/return.h"
#include "src/ops/where.h"
#include "src/ops/scatter_nd.h"
#include "src/ops/constant_of_shape.h"
......@@ -122,7 +123,7 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; }
void PrimitiveC::SetPrimitiveT(schema::PrimitiveT *prim) { this->primitive_ = prim; }
void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; }
void PrimitiveC::SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
this->input_quant_param_ = input_quant_param;
......@@ -155,21 +156,21 @@ std::shared_ptr<PrimitiveC> GetReturnPrim() {
auto return_primitiveT = new schema::PrimitiveT;
return_primitiveT->value.type = schema::PrimitiveType_Return;
return_primitiveT->value.value = new schema::ReturnT;
return std::make_shared<PrimitiveC>(return_primitiveT);
return std::make_shared<Return>(return_primitiveT);
}
std::shared_ptr<PrimitiveC> GetMakeTuplePrim() {
auto make_tuple_primitiveT = new schema::PrimitiveT;
make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple;
make_tuple_primitiveT->value.value = new schema::MakeTupleT;
return std::make_shared<PrimitiveC>(make_tuple_primitiveT);
return std::make_shared<MakeTuple>(make_tuple_primitiveT);
}
std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
auto tuple_get_item_primitiveT = new schema::PrimitiveT();
tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem;
tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT;
return std::make_shared<PrimitiveC>(tuple_get_item_primitiveT);
return std::make_shared<TupleGetItem>(tuple_get_item_primitiveT);
}
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
......@@ -439,7 +440,7 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
}
#else
PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) {
MS_EXCEPTION_IF_NULL(primitive);
MS_ASSERT(primitive);
auto op_type = primitive->value_type();
switch (op_type) {
case schema::PrimitiveType_SoftMax:
......@@ -646,6 +647,9 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
#endif
int PrimitiveC::Type() const {
if (this->primitive_ == nullptr) {
return schema::PrimitiveType_NONE;
}
#ifdef PRIMITIVE_WRITEABLE
return this->primitive_->value.type;
#else
......
......@@ -61,15 +61,13 @@ class PrimitiveC : public mindspore::Primitive {
MS_DECLARE_PARENT(PrimitiveC, Primitive);
~PrimitiveC() override {
// delete this->primitive_;
}
~PrimitiveC() override { delete this->primitive_; }
int Type() const;
schema::PrimitiveT *GetPrimitiveT() const;
void SetPrimitiveT(schema::PrimitiveT *prim);
void ClearPrimitiveT();
bool operator==(const Value &rhs) const {
if (rhs.isa<PrimitiveC>()) {
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/return.h"
#include <memory>
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Return::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Return;
}
if (this->primitive_->value.type != schema::PrimitiveType_Return) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ReturnT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
} // namespace
int Return::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
return RET_ERROR;
}
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if (this->primitive_ == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->set_shape(input->shape());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_RETURN_H_
#define LITE_MINDSPORE_LITE_C_OPS_RETURN_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Return : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Return, PrimitiveC);
Return() = default;
explicit Return(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else
explicit Return(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_RETURN_H_
......@@ -72,10 +72,10 @@ int FullconnectionCPUKernel::ReSize() {
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float));
fc_param_->a_const_ = false;
fc_param_->b_const_ = false;
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
fc_param_->a_const_ = (in_tensors_[0]->Data() != nullptr);
fc_param_->b_const_ = (in_tensors_[1]->Data() != nullptr);
if (fc_param_->a_const_) InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
if (fc_param_->b_const_) InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
return RET_OK;
}
......@@ -87,27 +87,11 @@ int FullconnectionCPUKernel::Init() {
}
void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
if (fc_param_->a_const_ == true) {
return;
}
if (src_ptr == nullptr) {
return;
}
fc_param_->a_const_ = true;
RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
return;
}
void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
if (fc_param_->b_const_ == true) {
return;
}
if (src_ptr == nullptr) {
return;
}
fc_param_->b_const_ = true;
RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
return;
}
int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
......@@ -142,8 +126,8 @@ int FullconnectionCPUKernel::Run() {
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
c_r_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
InitMatrixA(a_ptr, a_c12_ptr_);
InitMatrixB(b_ptr, b_r8_ptr_);
if (!fc_param_->a_const_) InitMatrixA(a_ptr, a_c12_ptr_);
if (!fc_param_->b_const_) InitMatrixB(b_ptr, b_r8_ptr_);
LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
......
......@@ -130,7 +130,7 @@ TEST_F(InferTest, TestConvNode) {
memcpy(data, input_data, input_size);
ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
auto outputs = session->GetOutputMapByNode();
ASSERT_EQ(outputs.size(), 1);
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
......@@ -222,7 +222,7 @@ TEST_F(InferTest, TestAddNode) {
(void)inTensor1->MutableData();
ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
auto outputs = session->GetOutputMapByNode();
ASSERT_EQ(outputs.size(), 1);
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
......@@ -325,7 +325,7 @@ TEST_F(InferTest, TestParallelExecutor) {
(void)inTensor1->MutableData();
ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
auto outputs = session->GetOutputMapByNode();
ASSERT_EQ(outputs.size(), 1);
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
......@@ -362,7 +362,7 @@ TEST_F(InferTest, TestModel) {
(void)inTensor->MutableData();
ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
auto outputs = session->GetOutputMapByNode();
MS_LOG(INFO) << "Passed";
}
......
......@@ -168,7 +168,7 @@ void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_p
}
}
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
for (const auto &cnode : cnodes) {
......@@ -177,17 +177,17 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return nullptr;
}
auto primT = primitive_c->GetPrimitiveT();
if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
primitive_c->Type() == schema::PrimitiveType_MakeTuple) {
continue;
}
RemoveIfMakeTuple(cnode);
auto primT = primitive_c->GetPrimitiveT();
auto node = std::make_unique<schema::CNodeT>();
if (node == nullptr) {
MS_LOG(ERROR) << "object failed to be constructed";
return nullptr;
MS_LOG(ERROR) << "object failed to be constructed";
return nullptr;
}
if (primT->value.type == schema::PrimitiveType_Return) {
node->name = "return_node";
......@@ -208,7 +208,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "ConvertQuantParam failed";
return nullptr;
}
if (!keep_graph) {
primitive_c->ClearPrimitiveT();
}
meta_graphT->nodes.emplace_back(std::move(node));
}
// set graph input tensors
......@@ -414,15 +416,15 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType
return false;
}
const auto &prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (prim == nullptr) {
return false;
}
return (schema::PrimitiveType)prim->Type() == type;
return (schema::PrimitiveType)(prim->Type()) == type;
}
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph) {
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph) {
AnfExporter anf_exporter;
return anf_exporter.Export(func_graph);
return anf_exporter.Export(func_graph, keep_graph);
}
} // namespace mindspore::lite
......@@ -30,7 +30,7 @@ class AnfExporter {
public:
AnfExporter() = default;
virtual ~AnfExporter() = default;
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph);
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false);
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
......@@ -55,6 +55,6 @@ class AnfExporter {
std::vector<schema::CNodeT *> graph_input_nodes_;
};
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph);
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
......@@ -239,7 +239,7 @@ int Benchmark::CompareOutput() {
bool hasError = false;
for (const auto &calibTensor : calibData) {
std::string nodeName = calibTensor.first;
auto tensors = session->GetOutputsByName(nodeName);
auto tensors = session->GetOutputsByNodeName(nodeName);
if (tensors.empty()) {
MS_LOG(ERROR) << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail.";
std::cerr << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail." << std::endl;
......
......@@ -58,27 +58,7 @@ class MindsporeImporter : public Converter {
~MindsporeImporter() override = default;
};
void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return;
}
auto primT = primitive_c->GetPrimitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "PrimitiveT is nullptr";
return;
}
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
delete primT;
primitive_c->SetPrimitiveT(nullptr);
}
}
}
MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// parse the model and weight file to generate inference data structure
FuncGraphPtr graph = nullptr;
......@@ -137,7 +117,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return nullptr;
}
FreeFuncGraph(graph);
return meta_graph;
}
......
......@@ -57,7 +57,7 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
break;
}
default:
// MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str());
MS_LOG(INFO) << "will support quantizer type " << flags->quantTypeIn << " in the future";
break;
}
}
......
......@@ -47,8 +47,8 @@ class ModelParser {
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
MS_EXCEPTION_IF_NULL(meta_graph);
auto func_graph = std::make_shared<FuncGraph>();
auto importer = new AnfImporterFromMetaGraphT(meta_graph, func_graph);
auto ret = importer->Import();
AnfImporterFromMetaGraphT importer(meta_graph, func_graph);
auto ret = importer.Import();
if (RET_OK != ret) {
MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << ret;
return nullptr;
......
......@@ -69,7 +69,11 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto,
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr(new (std::nothrow) schema::Conv2DT());
auto attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW;
......@@ -135,9 +139,9 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto,
op->name = proto.name();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.get();
op->primitive->value.value = attr.release();
status = ParseGroupConvolution(op, attr.release());
status = ParseGroupConvolution(op, static_cast<schema::Conv2DT *>(op->primitive->value.value));
if (status != RET_OK) {
MS_LOG(ERROR) << "Parse group convolution failed";
return RET_ERROR;
......
......@@ -310,6 +310,10 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file,
// load graph
std::unique_ptr<tflite::ModelT> tflite_model = ReadTfliteModel(model_file.c_str());
if (tflite_model == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
return nullptr;
}
if (tflite_model->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
......
......@@ -490,7 +490,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
this->target_type_ = target_type;
if (target_type == kNumberTypeInt8) {
quant_max = (1 << (this->bit_num - 1)) - 1; // 127
quant_min = -quant_max; // -127
quant_min = -quant_max; // -127
} else if (target_type == kNumberTypeUInt8) {
quant_max = (1 << this->bit_num) - 1; // 255
quant_min = 0;
......@@ -538,8 +538,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
return RET_OK;
}
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
bool perchanel, bool depthwise) {
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchanel,
bool depthwise) {
// const vector<int> dims = filter->dims;
// perlayer
if (!weight->isa<Parameter>()) {
......@@ -556,8 +556,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_ERROR;
}
auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
perchanel, depthwise);
auto status =
QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
......@@ -954,7 +954,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
}
// anf -- fb
auto meta_graph = Export(funcGraph);
auto meta_graph = Export(funcGraph, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
return RET_ERROR;
......
......@@ -17,14 +17,11 @@
#include <memory>
#include <set>
#include <vector>
#include <algorithm>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "src/kernel_registry.h"
#include "src/scheduler.h"
#include "include/context.h"
#include "src/lite_session.h"
#include "src/populate_parameter.h"
#include "src/ops/primitive_c.h"
......@@ -135,26 +132,7 @@ void FreeInputTensor(std::vector<Tensor *> *input_tensor) {
}
return;
}
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return nullptr;
}
auto *lite_primitive = primitive_c->GetPrimitiveT();
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "Primitive in primitive_c is nullptr";
return nullptr;
}
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::Primitive::Pack(builder, lite_primitive);
builder.Finish(offset);
auto buf = builder.GetBufferPointer();
auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
return const_cast<schema::Primitive *>(primitive);
}
const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
CheckIfFuncGraphIsNull(func_graph);
......@@ -176,15 +154,13 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
auto output_nums = GetOutputTensorNum(input_cnode);
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
auto scheam_primitive = PackPrimitiveT(input_cnode);
auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
FreeInputTensor(&input_tensors);
MS_LOG(ERROR) << "lite_primitive is nullptr";
return nullptr;
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get());
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeInputTensor(&input_tensors);
......
......@@ -114,14 +114,6 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
}
pre_node->set_abstract(abstr);
const auto &prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transform_node->input(0));
if (prim != nullptr) {
auto *prim_t = prim->GetPrimitiveT();
if (prim_t != nullptr) {
delete prim_t;
prim->SetPrimitiveT(nullptr);
}
}
return pre_node;
}
......
......@@ -360,7 +360,7 @@ int TimeProfile::RunTimeProfile() {
delete model;
return RET_ERROR;
}
auto outputs = session_->GetOutputs();
auto outputs = session_->GetOutputMapByNode();
uint64_t run_end = GetTimeUs();
uint64_t time = run_end - run_begin;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部