提交 7787dfde 编写于 作者: 李寅

Merge branch 'feature_wuch' into 'master'

add multi-inputs & multi-outputs support

See merge request !122
...@@ -111,22 +111,32 @@ bool HexagonControlWrapper::SetupGraph(const NetDef& net_def) { ...@@ -111,22 +111,32 @@ bool HexagonControlWrapper::SetupGraph(const NetDef& net_def) {
} }
// input info // input info
const InputInfo& input_info = net_def.input_info()[0]; num_inputs_ = 0;
input_shape_.insert(input_shape_.begin(), for (const InputInfo &input_info: net_def.input_info()) {
input_info.dims().begin(), input_info.dims().end()); vector<index_t> input_shape;
while (input_shape_.size() < 4) { input_shape.insert(input_shape.begin(),
input_shape_.insert(input_shape_.begin(), 1); input_info.dims().begin(), input_info.dims().end());
while (input_shape.size() < 4) {
input_shape.insert(input_shape.begin(), 1);
}
input_shapes_.push_back(input_shape);
input_data_types_.push_back(input_info.data_type());
num_inputs_ += 1;
} }
input_data_type_ = input_info.data_type();
// output info // output info
const OutputInfo& output_info = net_def.output_info()[0]; num_outputs_ = 0;
output_shape_.insert(output_shape_.begin(), for (const OutputInfo &output_info: net_def.output_info()) {
output_info.dims().begin(), output_info.dims().end()); vector<index_t> output_shape;
while (output_shape_.size() < 4) { output_shape.insert(output_shape.begin(),
output_shape_.insert(output_shape_.begin(), 1); output_info.dims().begin(), output_info.dims().end());
while (output_shape.size() < 4) {
output_shape.insert(output_shape.begin(), 1);
}
output_shapes_.push_back(output_shape);
output_data_types_.push_back(output_info.data_type());
num_outputs_ += 1;
} }
output_data_type_ = output_info.data_type();
bool res = hexagon_nn_prepare(nn_id_) == 0; bool res = hexagon_nn_prepare(nn_id_) == 0;
return res; return res;
......
...@@ -25,8 +25,11 @@ class HexagonControlWrapper { ...@@ -25,8 +25,11 @@ class HexagonControlWrapper {
bool SetupGraph(const std::string &model_file); bool SetupGraph(const std::string &model_file);
bool ExecuteGraph(const Tensor &input_tensor, Tensor *output_tensor) { bool ExecuteGraph(const Tensor &input_tensor, Tensor *output_tensor) {
LOG(INFO) << "Execute graph: " << nn_id_; LOG(INFO) << "Execute graph: " << nn_id_;
output_tensor->SetDtype(output_data_type_); // single input and single output
output_tensor->Resize(output_shape_); MACE_ASSERT(num_inputs_ == 1, "Wrong inputs num");
MACE_ASSERT(num_outputs_ == 1, "Wrong outputs num");
output_tensor->SetDtype(output_data_types_[0]);
output_tensor->Resize(output_shapes_[0]);
vector<uint32_t> output_shape(4); vector<uint32_t> output_shape(4);
uint32_t output_bytes; uint32_t output_bytes;
int res = hexagon_nn_execute(nn_id_, int res = hexagon_nn_execute(nn_id_,
...@@ -46,13 +49,58 @@ class HexagonControlWrapper { ...@@ -46,13 +49,58 @@ class HexagonControlWrapper {
output_tensor->raw_size(), output_tensor->raw_size(),
&output_bytes); &output_bytes);
MACE_ASSERT(output_shape == output_shape_, MACE_ASSERT(output_shape == output_shapes_[0],
"wrong output shape inferred"); "wrong output shape inferred");
MACE_ASSERT(output_bytes == output_tensor->raw_size(), MACE_ASSERT(output_bytes == output_tensor->raw_size(),
"wrong output bytes inferred."); "wrong output bytes inferred.");
return res == 0; return res == 0;
}; };
bool ExecuteGraphNew(const Tensor *input_tensors, int num_inputs,
Tensor *output_tensors, int num_outputs) {
LOG(INFO) << "Execute graph new: " << nn_id_;
MACE_ASSERT(num_inputs_ == num_inputs, "Wrong inputs num");
MACE_ASSERT(num_outputs_ == num_outputs, "Wrong outputs num");
hexagon_nn_tensordef *inputs = new hexagon_nn_tensordef[num_inputs];
hexagon_nn_tensordef *outputs = new hexagon_nn_tensordef[num_outputs];
for (int i = 0; i < num_inputs; ++i) {
vector<index_t> input_shape = input_tensors[i].shape();
inputs[i].batches = input_shape[0];
inputs[i].height = input_shape[1];
inputs[i].width = input_shape[2];
inputs[i].depth = input_shape[3];
inputs[i].data = const_cast<unsigned char *>(
reinterpret_cast<const unsigned char *>(input_tensors[i].raw_data()));
inputs[i].dataLen = input_tensors[i].raw_size();
inputs[i].data_valid_len = input_tensors[i].raw_size();
inputs[i].unused = 0;
}
for (int i = 0; i < num_outputs; ++i) {
output_tensors[i].SetDtype(output_data_types_[i]);
output_tensors[i].Resize(output_shapes_[i]);
vector<index_t> output_shape = output_tensors[0].shape();
outputs[i].batches = output_shape[0];
outputs[i].height = output_shape[1];
outputs[i].width = output_shape[2];
outputs[i].depth = output_shape[3];
outputs[i].data = reinterpret_cast<unsigned char *>(
output_tensors[i].raw_mutable_data());
outputs[i].dataLen = output_tensors[i].raw_size();
outputs[i].data_valid_len = output_tensors[i].raw_size();
outputs[i].unused = 0;
}
int res = hexagon_nn_execute_new(nn_id_, inputs, num_inputs,
outputs, num_outputs);
delete(inputs);
delete(outputs);
return res == 0;
};
bool TeardownGraph(); bool TeardownGraph();
void PrintLog(); void PrintLog();
void PrintGraph(); void PrintGraph();
...@@ -71,10 +119,12 @@ class HexagonControlWrapper { ...@@ -71,10 +119,12 @@ class HexagonControlWrapper {
int nn_id_; int nn_id_;
Serializer serializer_; Serializer serializer_;
vector<index_t> input_shape_; vector<vector<index_t>> input_shapes_;
vector<index_t> output_shape_; vector<vector<index_t>> output_shapes_;
DataType input_data_type_; vector<DataType> input_data_types_;
DataType output_data_type_; vector<DataType> output_data_types_;
uint32_t num_inputs_;
uint32_t num_outputs_;
DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper); DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
}; };
......
...@@ -224,7 +224,7 @@ TEST(SupernodeTest, Supernode) { ...@@ -224,7 +224,7 @@ TEST(SupernodeTest, Supernode) {
input_data[h * 4 + w] = (uint8_t)((h == 0) ? 0 : h * 64 - 1); input_data[h * 4 + w] = (uint8_t)((h == 0) ? 0 : h * 64 - 1);
} }
VLOG(0) << wrapper.ExecuteGraph(input_tensor, &output_tensor); VLOG(0) << wrapper.ExecuteGraphNew(&input_tensor, 1, &output_tensor, 1);
wrapper.PrintLog(); wrapper.PrintLog();
// expect out: [[49.2095, 49.2095], [50.7905, 50.7905]] // expect out: [[49.2095, 49.2095], [50.7905, 50.7905]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册