提交 adace180 编写于 作者: 卢旭辉

Merge branch 'update' into 'master'

Update for HTA

See merge request !1230
...@@ -308,19 +308,22 @@ Transform models after conversion ...@@ -308,19 +308,22 @@ Transform models after conversion
// Report error or fallback // Report error or fallback
} }
std::vector<unsigned char> transformed_model_graph_data;
std::vector<unsigned char> transformed_model_weights_data;
// Add transformations here. // Add transformations here.
... ...
// Release original model data after transformations
model_graph_data.reset();
model_weights_data.reset();
// Create the MACE engine from the model buffer // Create the MACE engine from the model buffer
std::shared_ptr<mace::MaceEngine> engine; std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status; MaceStatus create_engine_status;
create_engine_status = create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>( CreateMaceEngineFromProto(transformed_model_graph_data.data(),
model_graph_data->data()), transformed_model_graph_data.size(),
model_graph_data->length(), transformed_model_weights_data.data(),
reinterpret_cast<const unsigned char *>( transformed_model_weights_data.size(),
model_weights_data->data()),
model_weights_data->length(),
input_names, input_names,
output_names, output_names,
config, config,
......
...@@ -260,19 +260,22 @@ and ``model_data_format`` are set as `file`. ...@@ -260,19 +260,22 @@ and ``model_data_format`` are set as `file`.
// Report error or fallback // Report error or fallback
} }
std::vector<unsigned char> transformed_model_graph_data;
std::vector<unsigned char> transformed_model_weights_data;
// Add transformations here. // Add transformations here.
... ...
// Release original model data after transformations
model_graph_data.reset();
model_weights_data.reset();
// Create the MACE engine from the model buffer // Create the MACE engine from the model buffer
std::shared_ptr<mace::MaceEngine> engine; std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status; MaceStatus create_engine_status;
create_engine_status = create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>( CreateMaceEngineFromProto(transformed_model_graph_data.data(),
model_graph_data->data()), transformed_model_graph_data.size(),
model_graph_data->length(), transformed_model_weights_data.data(),
reinterpret_cast<const unsigned char *>( transformed_model_weights_data.size(),
model_weights_data->data()),
model_weights_data->length(),
input_names, input_names,
output_names, output_names,
config, config,
......
...@@ -26,11 +26,11 @@ namespace mace { ...@@ -26,11 +26,11 @@ namespace mace {
class OpMap { class OpMap {
public: public:
void Init() { void Init() {
#define HTA_DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME; #define DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME;
#include "third_party/hta/hta_ops.h" #include "third_party/hta/hta_ops.h"
#undef HTA_DEF_OP #undef DEF_OP
} }
hta_op_type GetOpId(const std::string &op_type) { hta_op_type GetOpId(const std::string &op_type) {
......
...@@ -73,32 +73,24 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def, ...@@ -73,32 +73,24 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
tensor_shape.insert(tensor_shape.begin(), 1); tensor_shape.insert(tensor_shape.begin(), 1);
} }
hexagon_nn_const_node const_node; unsigned char *const_node_data = nullptr;
const_node.node_id = node_id(const_tensor.node_id()); int const_node_data_len = 0;
const_node.tensor.batches = tensor_shape[0]; if (!(const_tensor.data_type() == DataType::DT_INT32 &&
const_node.tensor.height = tensor_shape[1]; const_tensor.data_size() == 0)) {
const_node.tensor.width = tensor_shape[2]; const_node_data =
const_node.tensor.depth = tensor_shape[3];
if (const_tensor.data_type() == DataType::DT_INT32 &&
const_tensor.data_size() == 0) {
const_node.tensor.data = NULL;
const_node.tensor.dataLen = 0;
} else {
const_node.tensor.data =
const_cast<unsigned char *>(model_data + const_tensor.offset()); const_cast<unsigned char *>(model_data + const_tensor.offset());
const_node.tensor.dataLen = const_tensor.data_size() * const_node_data_len = const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type()); GetEnumTypeSize(const_tensor.data_type());
} }
hexagon_hta_nn_append_const_node(nn_id_, hexagon_hta_nn_append_const_node(nn_id_,
const_node.node_id, node_id(const_tensor.node_id()),
const_node.tensor.batches, tensor_shape[0],
const_node.tensor.height, tensor_shape[1],
const_node.tensor.width, tensor_shape[2],
const_node.tensor.depth, tensor_shape[3],
const_node.tensor.data, const_node_data,
const_node.tensor.dataLen); const_node_data_len);
} }
// op node // op node
...@@ -137,23 +129,14 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def, ...@@ -137,23 +129,14 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
auto padding_type = static_cast<hta_padding_type>(op.padding()); auto padding_type = static_cast<hta_padding_type>(op.padding());
hexagon_nn_op_node op_node;
op_node.node_id = node_id(op.node_id());
op_node.operation = op_id;
op_node.padding = padding_type;
op_node.inputs = cached_inputs.back().data();
op_node.inputsLen = inputs.size();
op_node.outputs = cached_outputs.back().data();
op_node.outputsLen = outputs.size();
hexagon_hta_nn_append_node(nn_id_, hexagon_hta_nn_append_node(nn_id_,
op_node.node_id, node_id(op.node_id()),
op_node.operation, op_id,
op_node.padding, padding_type,
op_node.inputs, cached_inputs.back().data(),
op_node.inputsLen, inputs.size(),
op_node.outputs, cached_outputs.back().data(),
op_node.outputsLen); outputs.size());
} }
// input info // input info
......
...@@ -600,6 +600,9 @@ MaceStatus MaceEngine::Impl::Init( ...@@ -600,6 +600,9 @@ MaceStatus MaceEngine::Impl::Init(
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
if (device_type_ == HEXAGON || device_type_ == HTA) { if (device_type_ == HEXAGON || device_type_ == HTA) {
hexagon_controller_ = CreateHexagonControlWrapper(device_.get()); hexagon_controller_ = CreateHexagonControlWrapper(device_.get());
LOG(INFO) << "Hexagon " << (device_type_ == HEXAGON ? "DSP" : "HTA")
<< " version: 0x" << std::hex
<< hexagon_controller_->GetVersion();
MACE_CHECK(hexagon_controller_->Config(), "hexagon config error"); MACE_CHECK(hexagon_controller_->Config(), "hexagon config error");
MACE_CHECK(hexagon_controller_->Init(), "hexagon init error"); MACE_CHECK(hexagon_controller_->Init(), "hexagon init error");
hexagon_controller_->SetDebugLevel( hexagon_controller_->SetDebugLevel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册