diff --git a/mace/dsp/hexagon_control_wrapper.cc b/mace/dsp/hexagon_control_wrapper.cc index 7c65e7e5212f0797e847b22fd67640bb58854f41..08ad17b52eab45b905079ce8dd7f647617d33d6a 100644 --- a/mace/dsp/hexagon_control_wrapper.cc +++ b/mace/dsp/hexagon_control_wrapper.cc @@ -111,22 +111,32 @@ bool HexagonControlWrapper::SetupGraph(const NetDef& net_def) { } // input info - const InputInfo& input_info = net_def.input_info()[0]; - input_shape_.insert(input_shape_.begin(), - input_info.dims().begin(), input_info.dims().end()); - while (input_shape_.size() < 4) { - input_shape_.insert(input_shape_.begin(), 1); + num_inputs_ = 0; + for (const InputInfo &input_info: net_def.input_info()) { + vector input_shape; + input_shape.insert(input_shape.begin(), + input_info.dims().begin(), input_info.dims().end()); + while (input_shape.size() < 4) { + input_shape.insert(input_shape.begin(), 1); + } + input_shapes_.push_back(input_shape); + input_data_types_.push_back(input_info.data_type()); + num_inputs_ += 1; } - input_data_type_ = input_info.data_type(); // output info - const OutputInfo& output_info = net_def.output_info()[0]; - output_shape_.insert(output_shape_.begin(), - output_info.dims().begin(), output_info.dims().end()); - while (output_shape_.size() < 4) { - output_shape_.insert(output_shape_.begin(), 1); + num_outputs_ = 0; + for (const OutputInfo &output_info: net_def.output_info()) { + vector output_shape; + output_shape.insert(output_shape.begin(), + output_info.dims().begin(), output_info.dims().end()); + while (output_shape.size() < 4) { + output_shape.insert(output_shape.begin(), 1); + } + output_shapes_.push_back(output_shape); + output_data_types_.push_back(output_info.data_type()); + num_outputs_ += 1; } - output_data_type_ = output_info.data_type(); bool res = hexagon_nn_prepare(nn_id_) == 0; return res; diff --git a/mace/dsp/hexagon_control_wrapper.h b/mace/dsp/hexagon_control_wrapper.h index a67e9903b7a42f6866fe1e1a63177586bfdfd326..9ef1113f5f18569012b735929989b67d67b3cafc 100644 --- a/mace/dsp/hexagon_control_wrapper.h +++ b/mace/dsp/hexagon_control_wrapper.h @@ -25,8 +25,11 @@ class HexagonControlWrapper { bool SetupGraph(const std::string &model_file); bool ExecuteGraph(const Tensor &input_tensor, Tensor *output_tensor) { LOG(INFO) << "Execute graph: " << nn_id_; - output_tensor->SetDtype(output_data_type_); - output_tensor->Resize(output_shape_); + // single input and single output + MACE_ASSERT(num_inputs_ == 1, "Wrong inputs num"); + MACE_ASSERT(num_outputs_ == 1, "Wrong outputs num"); + output_tensor->SetDtype(output_data_types_[0]); + output_tensor->Resize(output_shapes_[0]); vector output_shape(4); uint32_t output_bytes; int res = hexagon_nn_execute(nn_id_, @@ -46,13 +49,58 @@ class HexagonControlWrapper { output_tensor->raw_size(), &output_bytes); - MACE_ASSERT(output_shape == output_shape_, + MACE_ASSERT(output_shape == output_shapes_[0], "wrong output shape inferred"); MACE_ASSERT(output_bytes == output_tensor->raw_size(), "wrong output bytes inferred."); return res == 0; }; + bool ExecuteGraphNew(const Tensor *input_tensors, int num_inputs, + Tensor *output_tensors, int num_outputs) { + LOG(INFO) << "Execute graph new: " << nn_id_; + MACE_ASSERT(num_inputs_ == num_inputs, "Wrong inputs num"); + MACE_ASSERT(num_outputs_ == num_outputs, "Wrong outputs num"); + + hexagon_nn_tensordef *inputs = new hexagon_nn_tensordef[num_inputs]; + hexagon_nn_tensordef *outputs = new hexagon_nn_tensordef[num_outputs]; + + for (int i = 0; i < num_inputs; ++i) { + vector input_shape = input_tensors[i].shape(); + inputs[i].batches = input_shape[0]; + inputs[i].height = input_shape[1]; + inputs[i].width = input_shape[2]; + inputs[i].depth = input_shape[3]; + inputs[i].data = const_cast( + reinterpret_cast(input_tensors[i].raw_data())); + inputs[i].dataLen = input_tensors[i].raw_size(); + inputs[i].data_valid_len = input_tensors[i].raw_size(); + inputs[i].unused = 0; + } + + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i].SetDtype(output_data_types_[i]); + output_tensors[i].Resize(output_shapes_[i]); + vector output_shape = output_tensors[0].shape(); + outputs[i].batches = output_shape[0]; + outputs[i].height = output_shape[1]; + outputs[i].width = output_shape[2]; + outputs[i].depth = output_shape[3]; + outputs[i].data = reinterpret_cast( + output_tensors[i].raw_mutable_data()); + outputs[i].dataLen = output_tensors[i].raw_size(); + outputs[i].data_valid_len = output_tensors[i].raw_size(); + outputs[i].unused = 0; + } + + int res = hexagon_nn_execute_new(nn_id_, inputs, num_inputs, + outputs, num_outputs); + + delete(inputs); + delete(outputs); + return res == 0; + }; + bool TeardownGraph(); void PrintLog(); void PrintGraph(); @@ -71,10 +119,12 @@ class HexagonControlWrapper { int nn_id_; Serializer serializer_; - vector input_shape_; - vector output_shape_; - DataType input_data_type_; - DataType output_data_type_; + vector> input_shapes_; + vector> output_shapes_; + vector input_data_types_; + vector output_data_types_; + uint32_t num_inputs_; + uint32_t num_outputs_; DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper); }; diff --git a/mace/dsp/test/supernode_test.cc b/mace/dsp/test/supernode_test.cc index 634795ecf5beb6adbb3f255666537897be960436..5178d5d647d381553661e7c6ac5c6cd364b7d504 100644 --- a/mace/dsp/test/supernode_test.cc +++ b/mace/dsp/test/supernode_test.cc @@ -224,7 +224,7 @@ TEST(SupernodeTest, Supernode) { input_data[h * 4 + w] = (uint8_t)((h == 0) ? 0 : h * 64 - 1); } - VLOG(0) << wrapper.ExecuteGraph(input_tensor, &output_tensor); + VLOG(0) << wrapper.ExecuteGraphNew(&input_tensor, 1, &output_tensor, 1); wrapper.PrintLog(); // expect out: [[49.2095, 49.2095], [50.7905, 50.7905]]