Commit 7787dfde authored by 李寅

Merge branch 'feature_wuch' into 'master'

add multi-inputs & multi-outputs support

See merge request !122
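For context, here is a minimal caller-side sketch of the multi-input/multi-output path this merge introduces. Only `SetupGraph`, `ExecuteGraphNew`, and `TeardownGraph` come from the diff below; the `RunMultiIO` helper, the two-input/two-output graph, and the already-constructed tensors are hypothetical assumptions for illustration.

```cpp
// Hypothetical caller sketch (not part of this commit). Assumes the caller
// has already built and filled inputs[0..1] and default-initialized
// outputs[0..1]; only the wrapper methods shown in the diff are real.
bool RunMultiIO(mace::HexagonControlWrapper *wrapper,
                const mace::NetDef &net_def,
                const mace::Tensor inputs[2],
                mace::Tensor outputs[2]) {
  if (!wrapper->SetupGraph(net_def)) return false;  // now records every input/output's shape and dtype
  // New multi-I/O entry point: one hexagon_nn_tensordef per tensor is built inside.
  bool ok = wrapper->ExecuteGraphNew(inputs, 2, outputs, 2);
  wrapper->TeardownGraph();
  return ok;
}
```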
@@ -111,22 +111,32 @@ bool HexagonControlWrapper::SetupGraph(const NetDef& net_def) {
   }
 
   // input info
-  const InputInfo& input_info = net_def.input_info()[0];
-  input_shape_.insert(input_shape_.begin(),
-                      input_info.dims().begin(), input_info.dims().end());
-  while (input_shape_.size() < 4) {
-    input_shape_.insert(input_shape_.begin(), 1);
+  num_inputs_ = 0;
+  for (const InputInfo &input_info: net_def.input_info()) {
+    vector<index_t> input_shape;
+    input_shape.insert(input_shape.begin(),
+                       input_info.dims().begin(), input_info.dims().end());
+    while (input_shape.size() < 4) {
+      input_shape.insert(input_shape.begin(), 1);
+    }
+    input_shapes_.push_back(input_shape);
+    input_data_types_.push_back(input_info.data_type());
+    num_inputs_ += 1;
   }
-  input_data_type_ = input_info.data_type();
 
   // output info
-  const OutputInfo& output_info = net_def.output_info()[0];
-  output_shape_.insert(output_shape_.begin(),
-                       output_info.dims().begin(), output_info.dims().end());
-  while (output_shape_.size() < 4) {
-    output_shape_.insert(output_shape_.begin(), 1);
+  num_outputs_ = 0;
+  for (const OutputInfo &output_info: net_def.output_info()) {
+    vector<index_t> output_shape;
+    output_shape.insert(output_shape.begin(),
+                        output_info.dims().begin(), output_info.dims().end());
+    while (output_shape.size() < 4) {
+      output_shape.insert(output_shape.begin(), 1);
+    }
+    output_shapes_.push_back(output_shape);
+    output_data_types_.push_back(output_info.data_type());
+    num_outputs_ += 1;
   }
-  output_data_type_ = output_info.data_type();
 
   bool res = hexagon_nn_prepare(nn_id_) == 0;
   return res;
......
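The per-input and per-output loops above pad every declared shape to rank 4 by prepending 1s before storing it. A standalone sketch of just that padding rule follows; the `index_t` alias is an assumption standing in for MACE's typedef.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using index_t = int64_t;  // assumption: stands in for MACE's index_t

int main() {
  std::vector<index_t> shape = {224, 224, 3};  // e.g. an HWC image input
  // Same rule as in SetupGraph: prepend 1s until the shape has 4 dims (NHWC).
  while (shape.size() < 4) {
    shape.insert(shape.begin(), 1);
  }
  for (index_t d : shape) {
    std::cout << d << ' ';  // prints: 1 224 224 3
  }
  std::cout << '\n';
  return 0;
}
```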
@@ -25,8 +25,11 @@ class HexagonControlWrapper {
   bool SetupGraph(const std::string &model_file);
   bool ExecuteGraph(const Tensor &input_tensor, Tensor *output_tensor) {
     LOG(INFO) << "Execute graph: " << nn_id_;
-    output_tensor->SetDtype(output_data_type_);
-    output_tensor->Resize(output_shape_);
+    // single input and single output
+    MACE_ASSERT(num_inputs_ == 1, "Wrong inputs num");
+    MACE_ASSERT(num_outputs_ == 1, "Wrong outputs num");
+    output_tensor->SetDtype(output_data_types_[0]);
+    output_tensor->Resize(output_shapes_[0]);
     vector<uint32_t> output_shape(4);
     uint32_t output_bytes;
     int res = hexagon_nn_execute(nn_id_,
@@ -46,13 +49,58 @@ class HexagonControlWrapper {
                                  output_tensor->raw_size(),
                                  &output_bytes);
-    MACE_ASSERT(output_shape == output_shape_,
+    MACE_ASSERT(output_shape == output_shapes_[0],
                 "wrong output shape inferred");
     MACE_ASSERT(output_bytes == output_tensor->raw_size(),
                 "wrong output bytes inferred.");
     return res == 0;
   };
+
+  bool ExecuteGraphNew(const Tensor *input_tensors, int num_inputs,
+                       Tensor *output_tensors, int num_outputs) {
+    LOG(INFO) << "Execute graph new: " << nn_id_;
+    MACE_ASSERT(num_inputs_ == num_inputs, "Wrong inputs num");
+    MACE_ASSERT(num_outputs_ == num_outputs, "Wrong outputs num");
+
+    hexagon_nn_tensordef *inputs = new hexagon_nn_tensordef[num_inputs];
+    hexagon_nn_tensordef *outputs = new hexagon_nn_tensordef[num_outputs];
+
+    for (int i = 0; i < num_inputs; ++i) {
+      vector<index_t> input_shape = input_tensors[i].shape();
+      inputs[i].batches = input_shape[0];
+      inputs[i].height = input_shape[1];
+      inputs[i].width = input_shape[2];
+      inputs[i].depth = input_shape[3];
+      inputs[i].data = const_cast<unsigned char *>(
+          reinterpret_cast<const unsigned char *>(input_tensors[i].raw_data()));
+      inputs[i].dataLen = input_tensors[i].raw_size();
+      inputs[i].data_valid_len = input_tensors[i].raw_size();
+      inputs[i].unused = 0;
+    }
+
+    for (int i = 0; i < num_outputs; ++i) {
+      output_tensors[i].SetDtype(output_data_types_[i]);
+      output_tensors[i].Resize(output_shapes_[i]);
+      vector<index_t> output_shape = output_tensors[i].shape();
+      outputs[i].batches = output_shape[0];
+      outputs[i].height = output_shape[1];
+      outputs[i].width = output_shape[2];
+      outputs[i].depth = output_shape[3];
+      outputs[i].data = reinterpret_cast<unsigned char *>(
+          output_tensors[i].raw_mutable_data());
+      outputs[i].dataLen = output_tensors[i].raw_size();
+      outputs[i].data_valid_len = output_tensors[i].raw_size();
+      outputs[i].unused = 0;
+    }
+
+    int res = hexagon_nn_execute_new(nn_id_, inputs, num_inputs,
+                                     outputs, num_outputs);
+
+    delete [] inputs;
+    delete [] outputs;
+    return res == 0;
+  };
+
   bool TeardownGraph();
   void PrintLog();
   void PrintGraph();
@@ -71,10 +119,12 @@ class HexagonControlWrapper {
   int nn_id_;
   Serializer serializer_;
 
-  vector<index_t> input_shape_;
-  vector<index_t> output_shape_;
-  DataType input_data_type_;
-  DataType output_data_type_;
+  vector<vector<index_t>> input_shapes_;
+  vector<vector<index_t>> output_shapes_;
+  vector<DataType> input_data_types_;
+  vector<DataType> output_data_types_;
+  uint32_t num_inputs_;
+  uint32_t num_outputs_;
 
   DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
 };
......
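`ExecuteGraphNew` above builds the `hexagon_nn_tensordef` arrays with `new[]` and frees them manually. A possible alternative, sketched below with a hypothetical helper name and the per-tensor fill step elided, lets `std::vector` own that storage so no explicit `delete[]` is needed.

```cpp
// Sketch only, not the commit's code: same marshalling as ExecuteGraphNew,
// but std::vector owns the tensordef storage and frees it automatically.
bool ExecuteGraphNewVec(const Tensor *input_tensors, int num_inputs,
                        Tensor *output_tensors, int num_outputs) {
  std::vector<hexagon_nn_tensordef> inputs(num_inputs);
  std::vector<hexagon_nn_tensordef> outputs(num_outputs);
  // ... fill inputs[i] / outputs[i] from the tensors exactly as above ...
  int res = hexagon_nn_execute_new(nn_id_, inputs.data(), num_inputs,
                                   outputs.data(), num_outputs);
  return res == 0;
}
```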
@@ -224,7 +224,7 @@ TEST(SupernodeTest, Supernode) {
       input_data[h * 4 + w] = (uint8_t)((h == 0) ? 0 : h * 64 - 1);
   }
 
-  VLOG(0) << wrapper.ExecuteGraph(input_tensor, &output_tensor);
+  VLOG(0) << wrapper.ExecuteGraphNew(&input_tensor, 1, &output_tensor, 1);
   wrapper.PrintLog();
   // expect out: [[49.2095, 49.2095], [50.7905, 50.7905]]
......