Commit 07ea317d authored by Bin Li

Support multiple inputs and outputs for Hexagon DSP

Parent 5967c7ab
@@ -15,7 +15,9 @@
 #ifndef MACE_CORE_RUNTIME_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
 #define MACE_CORE_RUNTIME_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
 
+#include <map>
 #include <memory>
+#include <string>
 #include <utility>
 #include <vector>
@@ -25,17 +27,20 @@
 namespace mace {
 
 struct InOutInfo {
-  InOutInfo(const std::vector<index_t> &shape,
+  InOutInfo(const std::string &name,
+            const std::vector<index_t> &shape,
             const DataType data_type,
             const float scale,
             const int32_t zero_point,
             std::unique_ptr<Tensor> tensor_u8)
-      : shape(shape),
+      : name(name),
+        shape(shape),
         data_type(data_type),
         scale(scale),
         zero_point(zero_point),
         tensor_u8(std::move(tensor_u8)) {}
 
+  std::string name;
   std::vector<index_t> shape;
   DataType data_type;
   float scale;
@@ -56,8 +61,9 @@ class HexagonControlWrapper {
                           const unsigned char *model_data) = 0;
   virtual bool ExecuteGraph(const Tensor &input_tensor,
                             Tensor *output_tensor) = 0;
-  virtual bool ExecuteGraphNew(const std::vector<Tensor *> &input_tensors,
-                               std::vector<Tensor *> *output_tensors) = 0;
+  virtual bool ExecuteGraphNew(
+      const std::map<std::string, Tensor*> &input_tensors,
+      std::map<std::string, Tensor*> *output_tensors) = 0;
   virtual bool TeardownGraph() = 0;
   virtual void PrintLog() = 0;
   virtual void PrintGraph() = 0;
...
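The interface change above keys tensors by name instead of by position. A minimal caller-side sketch of the new signature follows; the wrapper instance, helper name, and tensor names are hypothetical — real names come from the model's NetDef:

```cpp
#include <map>
#include <string>

#include "mace/core/runtime/hexagon/hexagon_control_wrapper.h"

namespace mace {

// Sketch only: `wrapper` is any concrete HexagonControlWrapper, and the
// tensors are assumed to be already allocated and filled by the caller.
bool RunTwoInputGraph(HexagonControlWrapper *wrapper,
                      Tensor *image, Tensor *mask, Tensor *result) {
  std::map<std::string, Tensor*> inputs{{"input0", image}, {"input1", mask}};
  std::map<std::string, Tensor*> outputs{{"output0", result}};
  // Tensors are matched by name against input_info_/output_info_, so the
  // map's iteration order no longer has to match the graph's I/O order.
  return wrapper->ExecuteGraphNew(inputs, &outputs);
}

}  // namespace mace
```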
@@ -14,6 +14,7 @@
 #include <algorithm>
 #include <iomanip>
+#include <map>
 #include <memory>
 #include <thread>  // NOLINT(build/c++11)
 #include <vector>
@@ -239,7 +240,8 @@ bool HexagonDSPWrapper::SetupGraph(const NetDef &net_def,
     while (input_shape.size() < 4) {
       input_shape.insert(input_shape.begin(), 1);
     }
-    input_info_.emplace_back(input_shape,
+    input_info_.emplace_back(input_info.name(),
+                             input_shape,
                              input_info.data_type(),
                              input_info.scale(),
                              input_info.zero_point(),
@@ -255,7 +257,8 @@ bool HexagonDSPWrapper::SetupGraph(const NetDef &net_def,
     while (output_shape.size() < 4) {
       output_shape.insert(output_shape.begin(), 1);
     }
-    output_info_.emplace_back(output_shape,
+    output_info_.emplace_back(output_info.name(),
+                              output_shape,
                               output_info.data_type(),
                               output_info.scale(),
                               output_info.zero_point(),
@@ -396,8 +399,8 @@ bool HexagonDSPWrapper::ExecuteGraph(const Tensor &input_tensor,
                                      Tensor *output_tensor) {
   VLOG(2) << "Execute graph: " << nn_id_;
   // single input and single output
-  MACE_ASSERT(num_inputs_ == 1, "Wrong inputs num");
-  MACE_ASSERT(num_outputs_ == 1, "Wrong outputs num");
+  MACE_CHECK(num_inputs_ == 1, "Wrong inputs num");
+  MACE_CHECK(num_outputs_ == 1, "Wrong outputs num");
   output_tensor->SetDtype(output_info_[0].data_type);
   output_tensor->Resize(output_info_[0].shape);
   std::vector<uint32_t> output_shape(4);
@@ -419,26 +422,27 @@ bool HexagonDSPWrapper::ExecuteGraph(const Tensor &input_tensor,
                          &output_bytes);
   MACE_CHECK(res == 0, "execute error");
 
-  MACE_ASSERT(output_shape.size() == output_info_[0].shape.size(),
-              "wrong output shape inferred");
+  MACE_CHECK(output_shape.size() == output_info_[0].shape.size(),
+             "wrong output shape inferred");
   for (size_t i = 0; i < output_shape.size(); ++i) {
-    MACE_ASSERT(static_cast<index_t>(output_shape[i])
-                == output_info_[0].shape[i],
-                "wrong output shape inferred");
+    MACE_CHECK(static_cast<index_t>(output_shape[i])
+               == output_info_[0].shape[i],
+               "wrong output shape inferred");
   }
-  MACE_ASSERT(output_bytes == output_tensor->raw_size(),
-              "wrong output bytes inferred.");
+  MACE_CHECK(output_bytes == output_tensor->raw_size(),
+             "wrong output bytes inferred.");
   return res == 0;
 }
 bool HexagonDSPWrapper::ExecuteGraphNew(
-    const std::vector<Tensor *> &input_tensors,
-    std::vector<Tensor *> *output_tensors) {
+    const std::map<std::string, Tensor*> &input_tensors,
+    std::map<std::string, Tensor*> *output_tensors) {
   VLOG(2) << "Execute graph new: " << nn_id_;
   uint32_t num_inputs = static_cast<uint32_t>(input_tensors.size());
   uint32_t num_outputs = static_cast<uint32_t>(output_tensors->size());
-  MACE_ASSERT(num_inputs_ == num_inputs, "Wrong inputs num");
-  MACE_ASSERT(num_outputs_ == num_outputs, "Wrong outputs num");
+  MACE_CHECK(num_inputs_ == static_cast<int>(num_inputs), "Wrong inputs num");
+  MACE_CHECK(num_outputs_ == static_cast<int>(num_outputs),
+             "Wrong outputs num");
 
   std::vector<hexagon_nn_tensordef> inputs(num_inputs * kNumMetaData);
   std::vector<hexagon_nn_tensordef> outputs(num_outputs * kNumMetaData);
@@ -447,17 +451,18 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
 
   // transform mace input to hexagon input
   for (size_t i = 0; i < num_inputs; ++i) {
-    std::vector<index_t> input_shape = input_tensors[i]->shape();
+    const auto input_tensor = input_tensors.at(input_info_[i].name);
+    const auto &input_shape = input_tensor->shape();
     size_t index = i * kNumMetaData;
     inputs[index].batches = static_cast<uint32_t>(input_shape[0]);
     inputs[index].height = static_cast<uint32_t>(input_shape[1]);
     inputs[index].width = static_cast<uint32_t>(input_shape[2]);
     inputs[index].depth = static_cast<uint32_t>(input_shape[3]);
     inputs[index].data = const_cast<unsigned char *>(
-        reinterpret_cast<const unsigned char *>(input_tensors[i]->raw_data()));
-    inputs[index].dataLen = static_cast<int>(input_tensors[i]->raw_size());
+        reinterpret_cast<const unsigned char *>(input_tensor->raw_data()));
+    inputs[index].dataLen = static_cast<int>(input_tensor->raw_size());
     inputs[index].data_valid_len =
-        static_cast<uint32_t>(input_tensors[i]->raw_size());
+        static_cast<uint32_t>(input_tensor->raw_size());
     inputs[index].unused = 0;
     input_metadata[i].Init(.0f, .0f, 1);
     AddInputMetadata(input_metadata[i].min_val, &inputs[index + 1]);
@@ -467,13 +472,14 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
 
   // transform mace output to hexagon output
   for (size_t i = 0; i < num_outputs; ++i) {
+    auto output_tensor = output_tensors->at(output_info_[i].name);
     size_t index = i * kNumMetaData;
-    (*output_tensors)[i]->SetDtype(output_info_[i].data_type);
-    (*output_tensors)[i]->Resize(output_info_[i].shape);
+    output_tensor->SetDtype(output_info_[i].data_type);
+    output_tensor->Resize(output_info_[i].shape);
     outputs[index].data = reinterpret_cast<unsigned char *>(
-        (*output_tensors)[i]->raw_mutable_data());
-    outputs[index].dataLen = static_cast<int>((*output_tensors)[i]->raw_size());
+        output_tensor->raw_mutable_data());
+    outputs[index].dataLen = static_cast<int>(output_tensor->raw_size());
     output_metadata[i].Init(.0f, .0f, 1);
 
     AddOutputMetadata(output_metadata[i].min_val, &outputs[index + 1]);
@@ -495,17 +501,20 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
     std::vector<uint32_t> output_shape{
         outputs[index].batches, outputs[index].height, outputs[index].width,
         outputs[index].depth};
-    MACE_ASSERT(output_shape.size() == output_info_[i].shape.size(),
-                "wrong output shape inferred");
+    MACE_CHECK(output_shape.size() == output_info_[i].shape.size(),
+               output_shape.size(), " vs ", output_info_[i].shape.size(),
+               "wrong output shape inferred");
     for (size_t j = 0; j < output_shape.size(); ++j) {
-      MACE_ASSERT(static_cast<index_t>(output_shape[j])
-                  == output_info_[i].shape[j],
-                  "wrong output shape inferred");
+      MACE_CHECK(static_cast<index_t>(output_shape[j])
+                 == output_info_[i].shape[j],
+                 output_shape[j], " vs ", output_info_[i].shape[j],
+                 "wrong output shape inferred");
     }
-    MACE_ASSERT(static_cast<index_t>(outputs[index].data_valid_len)
-                == (*output_tensors)[i]->raw_size(),
-                "wrong output bytes inferred.");
+    auto output_tensor = output_tensors->at(output_info_[i].name);
+    MACE_CHECK(static_cast<index_t>(outputs[index].data_valid_len)
+               == output_tensor->raw_size(),
+               outputs[index].data_valid_len, " vs ", output_tensor->raw_size(),
+               " wrong output bytes inferred.");
   }
 
   return res == 0;
...
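Besides the name-keyed lookups, the hunks above consistently swap MACE_ASSERT for MACE_CHECK and append the observed values (`x, " vs ", y`) to each failure message. A rough sketch of the assumed distinction, with hypothetical macro names — the real definitions live in MACE's logging utilities and may differ:

```cpp
#include <cstdlib>
#include <iostream>

// Assert-style check: typically compiled out of release builds, so a bad
// shape coming back from the DSP could go unnoticed in production.
#ifdef NDEBUG
#define SKETCH_ASSERT(cond, msg) ((void)0)
#else
#define SKETCH_ASSERT(cond, msg) \
  do { if (!(cond)) { std::cerr << msg << "\n"; std::abort(); } } while (0)
#endif

// CHECK-style check: evaluated in every build, which is why the diff uses it
// to validate shapes and byte counts reported back by the DSP at runtime.
#define SKETCH_CHECK(cond, msg) \
  do { if (!(cond)) { std::cerr << msg << "\n"; std::abort(); } } while (0)
```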
@@ -15,6 +15,8 @@
 #ifndef MACE_CORE_RUNTIME_HEXAGON_HEXAGON_DSP_WRAPPER_H_
 #define MACE_CORE_RUNTIME_HEXAGON_HEXAGON_DSP_WRAPPER_H_
 
+#include <map>
+#include <string>
 #include <vector>
 
 #include "mace/core/runtime/hexagon/hexagon_control_wrapper.h"
@@ -35,8 +37,8 @@ class HexagonDSPWrapper : public HexagonControlWrapper {
                   const unsigned char *model_data) override;
   bool ExecuteGraph(const Tensor &input_tensor,
                     Tensor *output_tensor) override;
-  bool ExecuteGraphNew(const std::vector<Tensor *> &input_tensors,
-                       std::vector<Tensor *> *output_tensors) override;
+  bool ExecuteGraphNew(const std::map<std::string, Tensor*> &input_tensors,
+                       std::map<std::string, Tensor*> *output_tensors) override;
   bool TeardownGraph() override;
   void PrintLog() override;
   void PrintGraph() override;
...
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <iomanip>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -160,7 +161,8 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
     while (input_shape.size() < 4) {
       input_shape.insert(input_shape.begin(), 1);
     }
-    input_info_.emplace_back(input_shape,
+    input_info_.emplace_back(input_info.name(),
+                             input_shape,
                              input_info.data_type(),
                              input_info.scale(),
                              input_info.zero_point(),
@@ -176,7 +178,8 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
     while (output_shape.size() < 4) {
       output_shape.insert(output_shape.begin(), 1);
     }
-    output_info_.emplace_back(output_shape,
+    output_info_.emplace_back(output_info.name(),
+                              output_shape,
                               output_info.data_type(),
                               output_info.scale(),
                               output_info.zero_point(),
@@ -234,19 +237,21 @@ bool HexagonHTAWrapper::ExecuteGraph(const Tensor &input_tensor,
 }
 
 bool HexagonHTAWrapper::ExecuteGraphNew(
-    const std::vector<Tensor *> &input_tensors,
-    std::vector<Tensor *> *output_tensors) {
+    const std::map<std::string, Tensor*> &input_tensors,
+    std::map<std::string, Tensor*> *output_tensors) {
   VLOG(2) << "Execute graph new: " << nn_id_;
   uint32_t num_inputs = static_cast<uint32_t>(input_tensors.size());
   uint32_t num_outputs = static_cast<uint32_t>(output_tensors->size());
-  MACE_ASSERT(num_inputs_ == num_inputs, "Wrong inputs num");
-  MACE_ASSERT(num_outputs_ == num_outputs, "Wrong outputs num");
+  MACE_CHECK(num_inputs_ == static_cast<int>(num_inputs), "Wrong inputs num");
+  MACE_CHECK(num_outputs_ == static_cast<int>(num_outputs),
+             "Wrong outputs num");
 
   std::vector<hexagon_hta_nn_tensordef> inputs(num_inputs);
   std::vector<hexagon_hta_nn_tensordef> outputs(num_outputs);
 
   for (size_t i = 0; i < num_inputs; ++i) {
-    std::vector<index_t> input_shape = input_tensors[i]->shape();
+    const auto input_tensor = input_tensors.at(input_info_[i].name);
+    const auto &input_shape = input_tensor->shape();
     inputs[i].batches = static_cast<uint32_t>(input_shape[0]);
     inputs[i].height = static_cast<uint32_t>(input_shape[1]);
     inputs[i].width = static_cast<uint32_t>(input_shape[2]);
@@ -254,10 +259,10 @@ bool HexagonHTAWrapper::ExecuteGraphNew(
     input_info_[i].tensor_u8->SetDtype(DT_UINT8);
     input_info_[i].tensor_u8->Resize(input_shape);
-    const float *input_data = input_tensors[i]->data<float>();
+    const float *input_data = input_tensor->data<float>();
     uint8_t *input_data_u8 = input_info_[i].tensor_u8->mutable_data<uint8_t>();
     QuantizeWithScaleAndZeropoint(input_data,
-                                  input_tensors[i]->size(),
+                                  input_tensor->size(),
                                   input_info_[i].scale,
                                   input_info_[i].zero_point,
                                   input_data_u8);
@@ -272,8 +277,9 @@ bool HexagonHTAWrapper::ExecuteGraphNew(
   }
 
   for (size_t i = 0; i < num_outputs; ++i) {
-    (*output_tensors)[i]->SetDtype(output_info_[i].data_type);
-    (*output_tensors)[i]->Resize(output_info_[i].shape);
+    auto output_tensor = output_tensors->at(output_info_[i].name);
+    output_tensor->SetDtype(output_info_[i].data_type);
+    output_tensor->Resize(output_info_[i].shape);
     output_info_[i].tensor_u8->SetDtype(DT_UINT8);
     output_info_[i].tensor_u8->Resize(output_info_[i].shape);
     outputs[i].data = reinterpret_cast<unsigned char *>(
@@ -292,19 +298,23 @@ bool HexagonHTAWrapper::ExecuteGraphNew(
     std::vector<uint32_t> output_shape{
         outputs[i].batches, outputs[i].height, outputs[i].width,
         outputs[i].depth};
-    MACE_ASSERT(output_shape.size() == output_info_[i].shape.size(),
-                "wrong output shape inferred");
+    MACE_CHECK(output_shape.size() == output_info_[i].shape.size(),
+               output_shape.size(), " vs ", output_info_[i].shape.size(),
+               "wrong output shape inferred");
     for (size_t j = 0; j < output_shape.size(); ++j) {
-      MACE_ASSERT(static_cast<index_t>(output_shape[j])
-                  == output_info_[i].shape[j],
-                  "wrong output shape inferred");
+      MACE_CHECK(static_cast<index_t>(output_shape[j])
+                 == output_info_[i].shape[j],
+                 output_shape[j], " vs ", output_info_[i].shape[j],
+                 "wrong output shape inferred");
     }
-    MACE_ASSERT(static_cast<index_t>(outputs[i].data_valid_len)
-                == (*output_tensors)[i]->raw_size(),
-                "wrong output bytes inferred.");
+    auto output_tensor = output_tensors->at(output_info_[i].name);
+    MACE_CHECK(static_cast<index_t>(outputs[i].data_valid_len)
+               == output_tensor->raw_size(),
+               outputs[i].data_valid_len, " vs ", output_tensor->raw_size(),
+               " wrong output bytes inferred.");
 
     const uint8_t *output_data_u8 = output_info_[i].tensor_u8->data<uint8_t>();
-    float *output_data = (*output_tensors)[i]->mutable_data<float>();
+    float *output_data = output_tensor->mutable_data<float>();
     Dequantize(output_data_u8,
                output_info_[i].tensor_u8->size(),
                output_info_[i].scale,
...
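The HTA path still quantizes each float input to uint8 using the per-tensor scale and zero point carried in InOutInfo, and dequantizes each output after execution. For orientation, a sketch of the usual affine mapping q = round(x / scale) + zero_point; the exact rounding inside MACE's QuantizeWithScaleAndZeropoint and Dequantize may differ:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Parameter names mirror InOutInfo::scale / InOutInfo::zero_point; these
// helpers are illustrative, not MACE's actual implementations.
void QuantizeSketch(const float *in, int64_t n, float scale,
                    int32_t zero_point, uint8_t *out) {
  for (int64_t i = 0; i < n; ++i) {
    // Map a real value into the uint8 grid, then clamp to [0, 255].
    int32_t q = zero_point + static_cast<int32_t>(std::round(in[i] / scale));
    out[i] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
  }
}

void DequantizeSketch(const uint8_t *in, int64_t n, float scale,
                      int32_t zero_point, float *out) {
  for (int64_t i = 0; i < n; ++i) {
    // Inverse mapping back to floats.
    out[i] = scale * (static_cast<int32_t>(in[i]) - zero_point);
  }
}
```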
@@ -15,6 +15,8 @@
 #ifndef MACE_CORE_RUNTIME_HEXAGON_HEXAGON_HTA_WRAPPER_H_
 #define MACE_CORE_RUNTIME_HEXAGON_HEXAGON_HTA_WRAPPER_H_
 
+#include <map>
+#include <string>
 #include <vector>
 
 #include "mace/core/runtime/hexagon/hexagon_control_wrapper.h"
@@ -35,8 +37,8 @@ class HexagonHTAWrapper : public HexagonControlWrapper {
                   const unsigned char *model_data) override;
   bool ExecuteGraph(const Tensor &input_tensor,
                     Tensor *output_tensor) override;
-  bool ExecuteGraphNew(const std::vector<Tensor *> &input_tensors,
-                       std::vector<Tensor *> *output_tensors) override;
+  bool ExecuteGraphNew(const std::map<std::string, Tensor*> &input_tensors,
+                       std::map<std::string, Tensor*> *output_tensors) override;
   bool TeardownGraph() override;
   void PrintLog() override;
   void PrintGraph() override;
...
@@ -736,8 +736,8 @@ MaceStatus MaceEngine::Impl::Run(
     std::map<std::string, MaceTensor> *outputs,
     RunMetadata *run_metadata) {
   MACE_CHECK_NOTNULL(outputs);
-  std::vector<Tensor *> input_tensors;
-  std::vector<Tensor *> output_tensors;
+  std::map<std::string, Tensor*> input_tensors;
+  std::map<std::string, Tensor*> output_tensors;
   for (auto &input : inputs) {
     if (input_info_map_.find(input.first) == input_info_map_.end()) {
       LOG(FATAL) << "'" << input.first
@@ -746,7 +746,7 @@ MaceStatus MaceEngine::Impl::Run(
     }
     Tensor *input_tensor = ws_->GetTensor(input.first);
     MACE_RETURN_IF_ERROR(TransposeInput(input, input_tensor));
-    input_tensors.push_back(input_tensor);
+    input_tensors[input.first] = input_tensor;
   }
   for (auto &output : *outputs) {
     if (output_info_map_.find(output.first) == output_info_map_.end()) {
@@ -755,12 +755,14 @@ MaceStatus MaceEngine::Impl::Run(
                  << MakeString(MapKeys(output_info_map_));
     }
     Tensor *output_tensor = ws_->GetTensor(output.first);
-    output_tensors.push_back(output_tensor);
+    output_tensors[output.first] = output_tensor;
   }
 #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
   if (device_type_ == HEXAGON || device_type_ == HTA) {
-    MACE_CHECK(input_tensors.size() == 1 && output_tensors.size() == 1,
-               "HEXAGON not support multiple inputs and outputs yet.");
+    if (device_type_ == HTA) {
+      MACE_CHECK(input_tensors.size() == 1 && output_tensors.size() == 1,
+                 "HTA does not support multiple inputs and outputs yet.");
+    }
     hexagon_controller_->ExecuteGraphNew(input_tensors, &output_tensors);
   } else {
 #endif
...
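At the public API surface, MaceEngine::Run already takes name-keyed maps of MaceTensor, so multi-input DSP models now flow straight through to ExecuteGraphNew. A caller-side sketch, with hypothetical tensor names and shapes and engine creation elided:

```cpp
#include <map>
#include <memory>
#include <string>

#include "mace/public/mace.h"

// Sketch only: assumes a DSP model with two inputs and one output whose
// names match the NetDef; buffer sizes follow the (made-up) shapes below.
mace::MaceStatus RunExample(mace::MaceEngine *engine) {
  auto in0 = std::shared_ptr<float>(new float[1 * 224 * 224 * 3],
                                    std::default_delete<float[]>());
  auto in1 = std::shared_ptr<float>(new float[1 * 224 * 224 * 1],
                                    std::default_delete<float[]>());
  auto out = std::shared_ptr<float>(new float[1 * 1001],
                                    std::default_delete<float[]>());

  std::map<std::string, mace::MaceTensor> inputs;
  inputs["input0"] = mace::MaceTensor({1, 224, 224, 3}, in0);
  inputs["input1"] = mace::MaceTensor({1, 224, 224, 1}, in1);
  std::map<std::string, mace::MaceTensor> outputs;
  outputs["output0"] = mace::MaceTensor({1, 1001}, out);

  // With this commit, HEXAGON accepts multiple entries per map; only the
  // HTA path still checks for exactly one input and one output.
  return engine->Run(inputs, &outputs);
}
```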
@@ -119,9 +119,10 @@ class HexagonConverter(base_converter.ConverterInterface):
         self._quantize_activation_info = quantize_activation_info
 
     def run(self):
-        mace_check(len(self._option.input_nodes) == 1
-                   and len(self._option.output_nodes) == 1,
-                   'dsp only support single input and output')
+        if self._option.device == DeviceType.HTA.value:
+            mace_check(len(self._option.input_nodes) == 1
+                       and len(self._option.output_nodes) == 1,
+                       'hta only supports single input and output')
         for tensor in self._model.tensors:
             self._consts[tensor.name] = tensor
 
@@ -129,13 +130,7 @@ class HexagonConverter(base_converter.ConverterInterface):
         # convert op node
         self.convert_ops()
 
-        self.add_input_output_node()
-
-        if not self._option.check_nodes:
-            output_name = list(self._option.output_nodes.values())[0].name
-        else:
-            output_name = list(self._option.check_nodes.values())[0].name
-        output_name = normalize_name(output_name)
-        self._model = graph_util.sort_mace_graph(self._model, output_name)
+        self.convert_input_output_node()
 
         self.add_node_id()
@@ -399,21 +394,42 @@ class HexagonConverter(base_converter.ConverterInterface):
                 elif op.input[i] == input_op + ':2':
                     op.input[i] = input_max
 
-    def add_input_output_node(self):
+    def convert_input_output_node(self):
+        quantize_input_op = self._model.op[0]
         mace_check(
-            self._model.op[0].type == HexagonOp.QuantizeINPUT_f_to_8.name,
+            quantize_input_op.type == HexagonOp.QuantizeINPUT_f_to_8.name,
             "Not started with Quantize op.")
-        quantize_input_op = self._model.op[0]
         del quantize_input_op.input[:]
 
-        mace_check(
-            self._model.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name,
-            "Not ended with Dequantize op.")
         dequantize_output_op = self._model.op[-1]
+        mace_check(dequantize_output_op.type
+                   == HexagonOp.DequantizeOUTPUT_8tof.name,
+                   "Not ended with Dequantize op.")
+        dequantize_input = [input for input in dequantize_output_op.input]
+        del dequantize_output_op.input[:]
         del dequantize_output_op.output_shape[:]
         del dequantize_output_op.output_type[:]
         del dequantize_output_op.out_max_byte_size[:]
 
+        index = 1
+        while index < len(self._model.op) - 1:
+            op = self._model.op[index]
+            if op.type == HexagonOp.QuantizeINPUT_f_to_8.name:
+                quantize_input_op.output.extend(op.output)
+                quantize_input_op.output_shape.extend(op.output_shape)
+                quantize_input_op.output_type.extend(op.output_type)
+                quantize_input_op.out_max_byte_size.extend(
+                    op.out_max_byte_size)
+                del self._model.op[index]
+            elif op.type == HexagonOp.DequantizeOUTPUT_8tof.name:
+                dequantize_output_op.input.extend(op.input)
+                del self._model.op[index]
+            index += 1
+
+        # input order matters
+        dequantize_output_op.input.extend(dequantize_input)
+
         if self._option.device == DeviceType.HTA.value:
             # replace QuantizeINPUT_f_to_8 with INPUT
             quantize_input_op.type = HexagonOp.INPUT.name
...
@@ -160,6 +160,7 @@ def main(unused_args):
             dequantize_op.node_input[0].node_id = op.node_id
             dequantize_op.node_input[1].node_id = op.node_id
             dequantize_op.node_input[2].node_id = op.node_id
+            del dequantize_op.node_input[3:]
             model_path = save_model_to_proto(net, normalize_op_name(op_name),
                                              FLAGS.output_dir)
...