From fc1dd4f642cdc0db18d1f3cc281a9a6ea34ace3f Mon Sep 17 00:00:00 2001 From: Bin Li Date: Thu, 21 Nov 2019 16:08:18 +0800 Subject: [PATCH] Update for HTA --- docs/user_guide/advanced_usage.rst | 15 +++-- docs/user_guide/advanced_usage_cmake.rst | 15 +++-- mace/core/runtime/hexagon/hexagon_hta_ops.h | 4 +- .../runtime/hexagon/hexagon_hta_wrapper.cc | 57 +++++++------------ mace/libmace/mace.cc | 3 + 5 files changed, 43 insertions(+), 51 deletions(-) diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst index f2d96b89..091fcb98 100644 --- a/docs/user_guide/advanced_usage.rst +++ b/docs/user_guide/advanced_usage.rst @@ -308,19 +308,22 @@ Transform models after conversion // Report error or fallback } + std::vector transformed_model_graph_data; + std::vector transformed_model_weights_data; // Add transformations here. ... + // Release original model data after transformations + model_graph_data.reset(); + model_weights_data.reset(); // Create the MACE engine from the model buffer std::shared_ptr engine; MaceStatus create_engine_status; create_engine_status = - CreateMaceEngineFromProto(reinterpret_cast( - model_graph_data->data()), - model_graph_data->length(), - reinterpret_cast( - model_weights_data->data()), - model_weights_data->length(), + CreateMaceEngineFromProto(transformed_model_graph_data.data(), + transformed_model_graph_data.size(), + transformed_model_weights_data.data(), + transformed_model_weights_data.size(), input_names, output_names, config, diff --git a/docs/user_guide/advanced_usage_cmake.rst b/docs/user_guide/advanced_usage_cmake.rst index 842fab3e..87d17fe4 100644 --- a/docs/user_guide/advanced_usage_cmake.rst +++ b/docs/user_guide/advanced_usage_cmake.rst @@ -260,19 +260,22 @@ and ``model_data_format`` are set as `file`. // Report error or fallback } + std::vector transformed_model_graph_data; + std::vector transformed_model_weights_data; // Add transformations here. ... + // Release original model data after transformations + model_graph_data.reset(); + model_weights_data.reset(); // Create the MACE engine from the model buffer std::shared_ptr engine; MaceStatus create_engine_status; create_engine_status = - CreateMaceEngineFromProto(reinterpret_cast( - model_graph_data->data()), - model_graph_data->length(), - reinterpret_cast( - model_weights_data->data()), - model_weights_data->length(), + CreateMaceEngineFromProto(transformed_model_graph_data.data(), + transformed_model_graph_data.size(), + transformed_model_weights_data.data(), + transformed_model_weights_data.size(), input_names, output_names, config, diff --git a/mace/core/runtime/hexagon/hexagon_hta_ops.h b/mace/core/runtime/hexagon/hexagon_hta_ops.h index 39a10860..936fd526 100644 --- a/mace/core/runtime/hexagon/hexagon_hta_ops.h +++ b/mace/core/runtime/hexagon/hexagon_hta_ops.h @@ -26,11 +26,11 @@ namespace mace { class OpMap { public: void Init() { -#define HTA_DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME; +#define DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME; #include "third_party/hta/hta_ops.h" -#undef HTA_DEF_OP +#undef DEF_OP } hta_op_type GetOpId(const std::string &op_type) { diff --git a/mace/core/runtime/hexagon/hexagon_hta_wrapper.cc b/mace/core/runtime/hexagon/hexagon_hta_wrapper.cc index a568cd53..07a7a5e9 100644 --- a/mace/core/runtime/hexagon/hexagon_hta_wrapper.cc +++ b/mace/core/runtime/hexagon/hexagon_hta_wrapper.cc @@ -73,32 +73,24 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def, tensor_shape.insert(tensor_shape.begin(), 1); } - hexagon_nn_const_node const_node; - const_node.node_id = node_id(const_tensor.node_id()); - const_node.tensor.batches = tensor_shape[0]; - const_node.tensor.height = tensor_shape[1]; - const_node.tensor.width = tensor_shape[2]; - const_node.tensor.depth = tensor_shape[3]; - - if (const_tensor.data_type() == DataType::DT_INT32 && - const_tensor.data_size() == 0) { - const_node.tensor.data = NULL; - const_node.tensor.dataLen = 0; - } else { - const_node.tensor.data = + unsigned char *const_node_data = nullptr; + int const_node_data_len = 0; + if (!(const_tensor.data_type() == DataType::DT_INT32 && + const_tensor.data_size() == 0)) { + const_node_data = const_cast(model_data + const_tensor.offset()); - const_node.tensor.dataLen = const_tensor.data_size() * + const_node_data_len = const_tensor.data_size() * GetEnumTypeSize(const_tensor.data_type()); } hexagon_hta_nn_append_const_node(nn_id_, - const_node.node_id, - const_node.tensor.batches, - const_node.tensor.height, - const_node.tensor.width, - const_node.tensor.depth, - const_node.tensor.data, - const_node.tensor.dataLen); + node_id(const_tensor.node_id()), + tensor_shape[0], + tensor_shape[1], + tensor_shape[2], + tensor_shape[3], + const_node_data, + const_node_data_len); } // op node @@ -137,23 +129,14 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def, auto padding_type = static_cast(op.padding()); - hexagon_nn_op_node op_node; - op_node.node_id = node_id(op.node_id()); - op_node.operation = op_id; - op_node.padding = padding_type; - op_node.inputs = cached_inputs.back().data(); - op_node.inputsLen = inputs.size(); - op_node.outputs = cached_outputs.back().data(); - op_node.outputsLen = outputs.size(); - hexagon_hta_nn_append_node(nn_id_, - op_node.node_id, - op_node.operation, - op_node.padding, - op_node.inputs, - op_node.inputsLen, - op_node.outputs, - op_node.outputsLen); + node_id(op.node_id()), + op_id, + padding_type, + cached_inputs.back().data(), + inputs.size(), + cached_outputs.back().data(), + outputs.size()); } // input info diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index eb79da6f..98ddf484 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -600,6 +600,9 @@ MaceStatus MaceEngine::Impl::Init( #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) if (device_type_ == HEXAGON || device_type_ == HTA) { hexagon_controller_ = CreateHexagonControlWrapper(device_.get()); + LOG(INFO) << "Hexagon " << (device_type_ == HEXAGON ? "DSP" : "HTA") + << " version: 0x" << std::hex + << hexagon_controller_->GetVersion(); MACE_CHECK(hexagon_controller_->Config(), "hexagon config error"); MACE_CHECK(hexagon_controller_->Init(), "hexagon init error"); hexagon_controller_->SetDebugLevel( -- GitLab