提交 adace180 编写于 作者: 卢旭辉

Merge branch 'update' into 'master'

Update for HTA

See merge request !1230
@@ -308,19 +308,22 @@ Transform models after conversion
// Report error or fallback
}
std::vector<unsigned char> transformed_model_graph_data;
std::vector<unsigned char> transformed_model_weights_data;
// Add transformations here.
...
// Release original model data after transformations
model_graph_data.reset();
model_weights_data.reset();
// Create the MACE engine from the model buffer
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>(
model_graph_data->data()),
model_graph_data->length(),
reinterpret_cast<const unsigned char *>(
model_weights_data->data()),
model_weights_data->length(),
CreateMaceEngineFromProto(transformed_model_graph_data.data(),
transformed_model_graph_data.size(),
transformed_model_weights_data.data(),
transformed_model_weights_data.size(),
input_names,
output_names,
config,
......
@@ -260,19 +260,22 @@ and ``model_data_format`` are set as `file`.
// Report error or fallback
}
std::vector<unsigned char> transformed_model_graph_data;
std::vector<unsigned char> transformed_model_weights_data;
// Add transformations here.
...
// Release original model data after transformations
model_graph_data.reset();
model_weights_data.reset();
// Create the MACE engine from the model buffer
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>(
model_graph_data->data()),
model_graph_data->length(),
reinterpret_cast<const unsigned char *>(
model_weights_data->data()),
model_weights_data->length(),
CreateMaceEngineFromProto(transformed_model_graph_data.data(),
transformed_model_graph_data.size(),
transformed_model_weights_data.data(),
transformed_model_weights_data.size(),
input_names,
output_names,
config,
......
@@ -26,11 +26,11 @@ namespace mace {
class OpMap {
public:
void Init() {
#define HTA_DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME;
#define DEF_OP(NAME) op_map_[#NAME] = HTA_OP_##NAME;
#include "third_party/hta/hta_ops.h"
#undef HTA_DEF_OP
#undef DEF_OP
}
hta_op_type GetOpId(const std::string &op_type) {
......
@@ -73,32 +73,24 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
tensor_shape.insert(tensor_shape.begin(), 1);
}
hexagon_nn_const_node const_node;
const_node.node_id = node_id(const_tensor.node_id());
const_node.tensor.batches = tensor_shape[0];
const_node.tensor.height = tensor_shape[1];
const_node.tensor.width = tensor_shape[2];
const_node.tensor.depth = tensor_shape[3];
if (const_tensor.data_type() == DataType::DT_INT32 &&
const_tensor.data_size() == 0) {
const_node.tensor.data = NULL;
const_node.tensor.dataLen = 0;
} else {
const_node.tensor.data =
unsigned char *const_node_data = nullptr;
int const_node_data_len = 0;
if (!(const_tensor.data_type() == DataType::DT_INT32 &&
const_tensor.data_size() == 0)) {
const_node_data =
const_cast<unsigned char *>(model_data + const_tensor.offset());
const_node.tensor.dataLen = const_tensor.data_size() *
const_node_data_len = const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type());
}
hexagon_hta_nn_append_const_node(nn_id_,
const_node.node_id,
const_node.tensor.batches,
const_node.tensor.height,
const_node.tensor.width,
const_node.tensor.depth,
const_node.tensor.data,
const_node.tensor.dataLen);
node_id(const_tensor.node_id()),
tensor_shape[0],
tensor_shape[1],
tensor_shape[2],
tensor_shape[3],
const_node_data,
const_node_data_len);
}
// op node
@@ -137,23 +129,14 @@ bool HexagonHTAWrapper::SetupGraph(const NetDef &net_def,
auto padding_type = static_cast<hta_padding_type>(op.padding());
hexagon_nn_op_node op_node;
op_node.node_id = node_id(op.node_id());
op_node.operation = op_id;
op_node.padding = padding_type;
op_node.inputs = cached_inputs.back().data();
op_node.inputsLen = inputs.size();
op_node.outputs = cached_outputs.back().data();
op_node.outputsLen = outputs.size();
hexagon_hta_nn_append_node(nn_id_,
op_node.node_id,
op_node.operation,
op_node.padding,
op_node.inputs,
op_node.inputsLen,
op_node.outputs,
op_node.outputsLen);
node_id(op.node_id()),
op_id,
padding_type,
cached_inputs.back().data(),
inputs.size(),
cached_outputs.back().data(),
outputs.size());
}
// input info
......
@@ -600,6 +600,9 @@ MaceStatus MaceEngine::Impl::Init(
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
if (device_type_ == HEXAGON || device_type_ == HTA) {
hexagon_controller_ = CreateHexagonControlWrapper(device_.get());
LOG(INFO) << "Hexagon " << (device_type_ == HEXAGON ? "DSP" : "HTA")
<< " version: 0x" << std::hex
<< hexagon_controller_->GetVersion();
MACE_CHECK(hexagon_controller_->Config(), "hexagon config error");
MACE_CHECK(hexagon_controller_->Init(), "hexagon init error");
hexagon_controller_->SetDebugLevel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册