提交 551cdfe2 编写于 作者: H hangq

return unorderd_map rather than vector for LiteSession::GetOutputs

上级 fcdc9c40
......@@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "include/ms_tensor.h"
#include "include/model.h"
#include "include/context.h"
......@@ -85,8 +86,8 @@ class MS_API LiteSession {
/// \brief Get output MindSpore Lite MSTensors of model.
///
/// \return A vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetOutputs() const = 0;
/// \return A map of output node name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
......
......@@ -177,17 +177,8 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
}
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputs() const {
std::vector<mindspore::tensor::MSTensor *> ret;
for (auto &iter : this->output_map) {
auto &node_output_tensors = iter.second;
for (auto tensor : node_output_tensors) {
if (!IsContain(ret, tensor)) {
ret.emplace_back(tensor);
}
}
}
return ret;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputs() const {
return this->output_map;
}
int LiteSession::Init(Context *context) {
......
......@@ -49,7 +49,7 @@ class LiteSession : public session::LiteSession {
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;
std::vector<mindspore::tensor::MSTensor *> GetOutputs() const override;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const override;
std::vector<mindspore::tensor::MSTensor *> GetOutputsByName(const std::string &name) const override;
......
......@@ -130,7 +130,8 @@ TEST_F(InferTest, TestConvNode) {
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto outTensor = outputs.front();
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum());
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
......@@ -220,7 +221,8 @@ TEST_F(InferTest, TestAddNode) {
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto outTensor = outputs.front();
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册