提交 629cb8a8 编写于 作者: 李寅

Add dsp wrapper and model convert tool

上级 dff4b94c
......@@ -59,9 +59,9 @@ class CPUAllocator : public Allocator {
}
void Delete(void *data) override { free(data); }
void *Map(void *buffer, size_t nbytes) { return buffer; }
void Unmap(void *buffer, void *mapper_ptr) {}
bool OnHost() { return true; }
void *Map(void *buffer, size_t nbytes) override { return buffer; }
void Unmap(void *buffer, void *mapper_ptr) override {}
bool OnHost() override { return true; }
};
std::map<int32_t, Allocator *> *gAllocatorRegistry();
......
......@@ -45,6 +45,22 @@ namespace mace {
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
namespace numerical_chars {
// Stream char-like types as their numeric value rather than as glyphs,
// so tensor debug output shows e.g. "65" instead of "A".
inline std::ostream &operator<<(std::ostream &os, char c) {
  // Plain char may be signed or unsigned depending on the platform;
  // widen through the matching signedness before printing.
  if (std::is_signed<char>::value) {
    os << static_cast<int>(c);
  } else {
    os << static_cast<unsigned int>(c);
  }
  return os;
}
inline std::ostream &operator<<(std::ostream &os, signed char c) {
  os << static_cast<int>(c);
  return os;
}
inline std::ostream &operator<<(std::ostream &os, unsigned char c) {
  os << static_cast<unsigned int>(c);
  return os;
}
}  // namespace numerical_chars
class Tensor {
public:
Tensor()
......@@ -71,6 +87,8 @@ class Tensor {
inline DataType dtype() const { return dtype_; }
inline void SetDtype(DataType dtype) { dtype_ = dtype; }
inline const vector<index_t> &shape() const { return shape_; }
inline index_t dim_size() const { return shape_.size(); }
......@@ -82,6 +100,8 @@ class Tensor {
inline index_t size() const { return size_; }
inline index_t raw_size() const { return size_ * SizeOfType(); }
inline int64_t NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<int64_t>());
......@@ -177,6 +197,7 @@ class Tensor {
}
inline void DebugPrint() const {
using namespace numerical_chars;
std::stringstream os;
for (int i : shape_) {
os << i << ", ";
......
# Description:
# Mace dsp.
#
# Only supports armeabi-v7a for now
# bazel build -c opt mace/dsp:dsp --crosstool_top=//external:android/crosstool --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --cpu=armeabi-v7a --verbose_failures
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android")
# DSP controller library: all non-test sources plus the prebuilt
# hexagon controller shared object; headers include the QAIC-generated
# interface under hexagon/.
cc_library(
    name = "dsp",
    srcs = glob([
        "*.cc",
        "hexagon/libhexagon_controller.so",
    ], exclude = [
        "*_test.cc",
    ]),
    hdrs = glob([
        "*.h",
        "hexagon/*.h",
    ]),
    copts = ["-std=c++11"],
    deps = [
        "//mace/proto:cc_proto",
        "//mace/core:core",
    ],
)

# On-device tests for the DSP wrapper (built with the android crosstool,
# see the build command in the file header).
cc_test(
    name = "dsp_test",
    testonly = 1,
    srcs = glob(["*_test.cc"]),
    copts = ["-std=c++11"],
    linkopts = if_android([
        "-ldl",
        "-lm",
    ]),
    linkstatic = 1,
    deps = [
        "@gtest//:gtest_main",
        ":dsp",
    ],
)
# Dump the build configuration variables for debugging.
$(info ------------------------------------------)
$(info --- V = $(V))
$(info --- GLUE_DIR = $(GLUE_DIR))
$(info --- HEXAGON_SDK_ROOT = $(HEXAGON_SDK_ROOT))
$(info ------------------------------------------)

# Pick up the adspmsgd ship headers/libraries from the Hexagon SDK tree.
INCDIRS += ../../../libs/common/adspmsgd/ship/android_Release
LIBDIRS += ../../../libs/common/adspmsgd/ship/android_Release

# The single artifact produced by this makefile.
BUILD_DLLS=libhexagon_controller

# QAIC IDL interfaces that stub sources are generated from.
hexagon_controller_lib_QAICIDLS += \
   interface/hexagon_nn \
   $(MAKE_D_DSPCV_INCDIR)/dspCV

# hexagon interface
# Generated stub sources corresponding to the IDLs above ($V is the
# build-variant output directory).
hexagon_controller_lib_C_SRCS += \
   $V/hexagon_nn_stub \
   $V/dspCV_stub

# Link against the cdsprpc DLL plus the rpcmem/adspmsgd helper libs,
# Android liblog, and enable verbose error printing in the stubs.
hexagon_controller_lib_DLLS += libcdsprpc
hexagon_controller_lib_LIBS += rpcmem adspmsgd
hexagon_controller_lib_LD_FLAGS += -llog
hexagon_controller_lib_DEFINES += VERIFY_PRINT_ERROR

# Forward the per-library settings onto the libhexagon_controller target.
libhexagon_controller_QAICIDLS += $(hexagon_controller_lib_QAICIDLS)
libhexagon_controller_C_SRCS += $(hexagon_controller_lib_C_SRCS)
libhexagon_controller_DLLS += $(hexagon_controller_lib_DLLS)
libhexagon_controller_LIBS += $(hexagon_controller_lib_LIBS)
libhexagon_controller_LD_FLAGS += $(hexagon_controller_lib_LD_FLAGS)
libhexagon_controller_DEFINES += $(hexagon_controller_lib_DEFINES)

# Copy the built DLLs/EXEs/LIBs into the ship directory.
BUILD_COPIES = \
   $(DLLS) \
   $(EXES) \
   $(LIBS) \
   $(SHIP_DIR)/ ;
#ifndef _HEXAGON_NN_H
#define _HEXAGON_NN_H
#ifndef __QAIC_HEADER
#define __QAIC_HEADER(ff) ff
#endif //__QAIC_HEADER
#ifndef __QAIC_HEADER_EXPORT
#define __QAIC_HEADER_EXPORT
#endif // __QAIC_HEADER_EXPORT
#ifndef __QAIC_HEADER_ATTRIBUTE
#define __QAIC_HEADER_ATTRIBUTE
#endif // __QAIC_HEADER_ATTRIBUTE
#ifndef __QAIC_IMPL
#define __QAIC_IMPL(ff) ff
#endif //__QAIC_IMPL
#ifndef __QAIC_IMPL_EXPORT
#define __QAIC_IMPL_EXPORT
#endif // __QAIC_IMPL_EXPORT
#ifndef __QAIC_IMPL_ATTRIBUTE
#define __QAIC_IMPL_ATTRIBUTE
#endif // __QAIC_IMPL_ATTRIBUTE
#ifdef __cplusplus
extern "C" {
#endif
#if !defined(__QAIC_STRING1_OBJECT_DEFINED__) && !defined(__STRING1_OBJECT__)
#define __QAIC_STRING1_OBJECT_DEFINED__
#define __STRING1_OBJECT__
typedef struct _cstring1_s {
char* data;
int dataLen;
} _cstring1_t;
#endif /* __QAIC_STRING1_OBJECT_DEFINED__ */
// One input edge of a graph node: the id of the producing node and the
// index of that node's output port to read from.
typedef struct hexagon_nn_input hexagon_nn_input;
struct hexagon_nn_input {
   unsigned int src_id;
   unsigned int output_idx;
};
// Describes one output of a node. max_size is the maximum byte size the
// output may occupy (see OperatorDef.out_max_byte_size in the converter);
// `unused` pads the wire struct.
typedef struct hexagon_nn_output hexagon_nn_output;
struct hexagon_nn_output {
   unsigned int max_size;
   unsigned int unused;
};
// Per-node performance counters returned by hexagon_nn_get_perfinfo.
// The cycle counter is split into counter_lo/counter_hi 32-bit halves.
typedef struct hexagon_nn_perfinfo hexagon_nn_perfinfo;
struct hexagon_nn_perfinfo {
   unsigned int node_id;
   unsigned int executions;
   unsigned int counter_lo;
   unsigned int counter_hi;
};
// Handle identifying one constructed graph instance (from hexagon_nn_init).
typedef int hexagon_nn_nn_id;
// Padding schemes understood by the DSP ops. The placeholder entry forces
// the enum to a fixed 32-bit width across the FastRPC boundary.
enum hexagon_nn_padding_type {
   NN_PAD_NA,
   NN_PAD_SAME,
   NN_PAD_VALID,
   NN_PAD_MIRROR_REFLECT,
   NN_PAD_MIRROR_SYMMETRIC,
   NN_PAD_SAME_CAFFE,
   _32BIT_PLACEHOLDER_hexagon_nn_padding_type = 0x7fffffff
};
typedef enum hexagon_nn_padding_type hexagon_nn_padding_type;
// Tensor descriptor used by hexagon_nn_execute_new: NHWC dimensions, the
// data buffer with its capacity (dataLen), and -- presumably -- how many
// bytes of it hold valid data (data_valid_len); `unused` pads the struct.
typedef struct hexagon_nn_tensordef hexagon_nn_tensordef;
struct hexagon_nn_tensordef {
   unsigned int batches;
   unsigned int height;
   unsigned int width;
   unsigned int depth;
   unsigned char* data;
   int dataLen;
   unsigned int data_valid_len;
   unsigned int unused;
};
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_config)(void) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_init)(void) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_set_debug_level)(hexagon_nn_nn_id id, int level) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_snpprint)(hexagon_nn_nn_id id, unsigned char* buf, int bufLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_getlog)(hexagon_nn_nn_id id, unsigned char* buf, int bufLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_append_node)(hexagon_nn_nn_id id, unsigned int node_id, unsigned int operation, hexagon_nn_padding_type padding, const hexagon_nn_input* inputs, int inputsLen, const hexagon_nn_output* outputs, int outputsLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_append_const_node)(hexagon_nn_nn_id id, unsigned int node_id, unsigned int batches, unsigned int height, unsigned int width, unsigned int depth, const unsigned char* data, int dataLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_prepare)(hexagon_nn_nn_id id) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_execute)(hexagon_nn_nn_id id, unsigned int batches_in, unsigned int height_in, unsigned int width_in, unsigned int depth_in, const unsigned char* data_in, int data_inLen, unsigned int* batches_out, unsigned int* height_out, unsigned int* width_out, unsigned int* depth_out, unsigned char* data_out, int data_outLen, unsigned int* data_len_out) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_teardown)(hexagon_nn_nn_id id) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_set_powersave_level)(unsigned int level) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_get_perfinfo)(hexagon_nn_nn_id id, hexagon_nn_perfinfo* info_out, int info_outLen, unsigned int* n_items) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_reset_perfinfo)(hexagon_nn_nn_id id, unsigned int event) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_last_execution_cycles)(hexagon_nn_nn_id id, unsigned int* cycles_lo, unsigned int* cycles_hi) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_version)(int* ver) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_op_name_to_id)(const char* name, unsigned int* node_id) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_op_id_to_name)(unsigned int node_id, char* name, int nameLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_disable_dcvs)(void) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_GetHexagonBinaryVersion)(int* ver) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_PrintLog)(const unsigned char* buf, int bufLen) __QAIC_HEADER_ATTRIBUTE;
__QAIC_HEADER_EXPORT int __QAIC_HEADER(hexagon_nn_execute_new)(hexagon_nn_nn_id id, const hexagon_nn_tensordef* inputs, int inputsLen, hexagon_nn_tensordef* outputs, int outputsLen) __QAIC_HEADER_ATTRIBUTE;
#ifdef __cplusplus
}
#endif
#endif //_HEXAGON_NN_H
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/dsp/hexagon_control_wrapper.h"

#include <fstream>
#include <string>
#include <vector>
namespace mace {
// Returns the version reported by the Hexagon NN library.
// `version` is initialized to -1 so a failed hexagon_nn_version() call
// cannot leak an uninitialized value to the caller.
int HexagonControlWrapper::GetVersion() {
  int version = -1;
  hexagon_nn_version(&version);
  return version;
}
// Configures the Hexagon NN library; returns true on success.
// hexagon_nn_config() follows the library-wide convention of returning 0
// on success (see TeardownGraph / SetupGraph), so compare against 0 --
// the original returned the raw int, which mapped success (0) to false.
bool HexagonControlWrapper::Config() {
  LOG(INFO) << "Hexagon config";
  return hexagon_nn_config() == 0;
}
// Builds the op-name lookup table and creates a Hexagon NN graph
// instance, keeping its id in nn_id_ for all subsequent calls.
bool HexagonControlWrapper::Init() {
  LOG(INFO) << "Hexagon init";
  op_map_.Init();
  // TODO(liyin): dspCV init
  // NOTE(review): hexagon_nn_init() appears to return the new graph id;
  // a failure value is not checked here -- confirm and handle.
  nn_id_ = hexagon_nn_init();
  return true;
}
// Counterpart of Init(). Currently a no-op apart from logging; the dspCV
// deinitialization is still pending.
bool HexagonControlWrapper::Finalize() {
  LOG(INFO) << "Hexagon finalize";
  // TODO(liyin): dspCV deinit
  return true;
}
// Builds the DSP graph from a NetDef: appends a const node for every
// serialized tensor, then the op nodes with their input/output wiring,
// records the (single) input/output metadata for ExecuteGraph, and asks
// the DSP to prepare the graph. Returns true iff prepare succeeds.
// NOTE(review): net_def is taken by value -- a full copy per call; the
// return values of the append calls are not checked.
bool HexagonControlWrapper::SetupGraph(NetDef net_def) {
  LOG(INFO) << "Hexagon setup graph";
  // const node
  for (const TensorProto& tensor_proto: net_def.tensors()) {
    // The DSP expects 4-D shapes; left-pad with 1s.
    vector<int> tensor_shape(tensor_proto.dims().begin(), tensor_proto.dims().end());
    while (tensor_shape.size() < 4) {
      tensor_shape.insert(tensor_shape.begin(), 1);
    }
    if (tensor_proto.data_type() == DataType::DT_INT32
        && tensor_proto.int32_data_size() == 0) {
      // Empty int32 tensor: append a const node with no payload.
      hexagon_nn_append_const_node(nn_id_, node_id(tensor_proto.node_id()),
                                   tensor_shape[0], tensor_shape[1],
                                   tensor_shape[2], tensor_shape[3],
                                   NULL,
                                   0);
    } else {
      // Deserialize to a host tensor and hand its raw bytes to the DSP.
      unique_ptr<Tensor> tensor = serializer_.Deserialize(tensor_proto, DeviceType::CPU);
      VLOG(0) << "Tensor size: " << tensor->size();
      hexagon_nn_append_const_node(nn_id_, node_id(tensor_proto.node_id()),
                                   tensor_shape[0], tensor_shape[1],
                                   tensor_shape[2], tensor_shape[3],
                                   reinterpret_cast<const unsigned char *>(
                                     tensor->raw_data()),
                                   tensor->raw_size());
    }
    VLOG(0) << "Const: " << tensor_proto.name() << ", node_id: " << node_id(tensor_proto.node_id())
            << "\n\t shape: " << tensor_shape[0] << " " << tensor_shape[1] << " " << tensor_shape[2] << " " << tensor_shape[3];
  }

  // op node
  for (const OperatorDef& op: net_def.op()) {
    // Translate the converter's op type name into the DSP op id.
    int op_id = op_map_.GetOpId(op.type());
    MACE_CHECK(op_id != OP_INVALID, "invalid op: ", op.name());
    // Wire inputs (producer node id + output port) and per-output
    // maximum byte sizes.
    vector<hexagon_nn_input> inputs(op.node_input_size());
    for (size_t i = 0; i < op.node_input_size(); ++i) {
      inputs[i].src_id = node_id(op.node_input(i).node_id());
      inputs[i].output_idx = op.node_input(i).output_port();
    }
    vector<hexagon_nn_output> outputs(op.out_max_byte_size_size());
    for (size_t i = 0; i < op.out_max_byte_size_size(); ++i) {
      outputs[i].max_size = op.out_max_byte_size(i);
    }
    hexagon_nn_padding_type padding_type = static_cast<hexagon_nn_padding_type>(
      op.padding());
    hexagon_nn_append_node(nn_id_, node_id(op.node_id()), op_id, padding_type,
                           inputs.data(), inputs.size(), outputs.data(), outputs.size());
    VLOG(0) << "Op: " << op.name() << ", type: " << op.type() << ", node_id: " << node_id(op.node_id()) << ", padding_type: " << padding_type;
    for (const auto& input: inputs) {
      VLOG(0) << "\t input: " << input.src_id << ":" << input.output_idx;
    }
    for (const auto& output: outputs) {
      VLOG(0) << "\t output: " << output.max_size;
    }
  }

  // input info
  // Only the first entry is used; its shape is left-padded to 4-D like
  // the const tensors above.
  const InputInfo& input_info = net_def.input_info()[0];
  input_shape_.insert(input_shape_.begin(),
                      input_info.dims().begin(), input_info.dims().end());
  while (input_shape_.size() < 4) {
    input_shape_.insert(input_shape_.begin(), 1);
  }
  input_data_type_ = input_info.data_type();

  // output info
  const OutputInfo& output_info = net_def.output_info()[0];
  output_shape_.insert(output_shape_.begin(),
                       output_info.dims().begin(), output_info.dims().end());
  while (output_shape_.size() < 4) {
    output_shape_.insert(output_shape_.begin(), 1);
  }
  output_data_type_ = output_info.data_type();

  bool res = hexagon_nn_prepare(nn_id_) == 0;
  return res;
}
// Loads a serialized NetDef from |model_file| and builds the DSP graph.
// Returns false when the file cannot be opened or does not parse as a
// NetDef (the original ignored both failure modes and would set up a
// graph from an empty/garbage proto).
bool HexagonControlWrapper::SetupGraph(const std::string& model_file) {
  std::ifstream file_stream(model_file, std::ios::in | std::ios::binary);
  if (!file_stream.is_open()) {
    return false;
  }
  NetDef net_def;
  const bool parsed = net_def.ParseFromIstream(&file_stream);
  file_stream.close();
  return parsed && SetupGraph(net_def);
}
// Releases the DSP graph identified by nn_id_. The library returns 0 on
// success, so convert that to a bool here.
bool HexagonControlWrapper::TeardownGraph() {
  LOG(INFO) << "Hexagon teardown graph";
  return hexagon_nn_teardown(nn_id_) == 0;
}
#define PRINT_BUFSIZE (2*1024*1024)

// Fetches the DSP-side log into a host buffer and echoes it to the log.
void HexagonControlWrapper::PrintLog() {
  LOG(INFO) << "Print Log";
  // RAII buffer replaces the old `new` + NULL-check pattern: operator new
  // throws on failure rather than returning NULL, so that check was dead
  // code, and the local `p` was never used. Zero-init keeps the later
  // string() construction safe even for a short log.
  std::vector<char> buf(PRINT_BUFSIZE, 0);
  hexagon_nn_getlog(nn_id_, reinterpret_cast<unsigned char*>(buf.data()),
                    PRINT_BUFSIZE);
  buf[PRINT_BUFSIZE - 1] = '\0';  // defensive: guarantee termination
  LOG(INFO) << string(buf.data());
}
// Fetches a printable description of the graph from the DSP and logs it.
void HexagonControlWrapper::PrintGraph() {
  LOG(INFO) << "Print Graph";
  // RAII buffer replaces the old `new` + NULL-check pattern (operator new
  // throws; the check was dead code and `p` was unused).
  std::vector<char> buf(PRINT_BUFSIZE, 0);
  hexagon_nn_snpprint(nn_id_, reinterpret_cast<unsigned char*>(buf.data()),
                      PRINT_BUFSIZE);
  buf[PRINT_BUFSIZE - 1] = '\0';  // defensive: guarantee termination
  LOG(INFO) << string(buf.data());
}
// Sets the DSP-side debug verbosity for this graph instance.
void HexagonControlWrapper::SetDebugLevel(int level) {
  LOG(INFO) << "Set debug level: " << level;
  hexagon_nn_set_debug_level(nn_id_, level);
}
// Queries per-node performance counters from the DSP and logs one line
// per node that has data.
void HexagonControlWrapper::GetPerfInfo() {
  LOG(INFO) << "Get perf info";
  // Capacity hint passed to the DSP; named once instead of the magic
  // 10000 repeated in two places.
  const int kMaxItems = 10000;
  vector<hexagon_nn_perfinfo> perf_info(kMaxItems);
  // Initialize so a failed call cannot leave n_items as garbage.
  unsigned int n_items = 0;
  hexagon_nn_get_perfinfo(nn_id_, perf_info.data(), kMaxItems, &n_items);
  // n_items is unsigned; use an unsigned index to avoid the original's
  // signed/unsigned comparison.
  for (unsigned int i = 0; i < n_items; ++i) {
    LOG(INFO) << "node id: " << perf_info[i].node_id
              << ", executions: " << perf_info[i].executions
              << ", counter_hi: " << perf_info[i].counter_hi
              << ", counter_lo: " << perf_info[i].counter_lo;
  }
}
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_DSP_HEXAGON_CONTROL_WRAPPER_H_
#define MACE_DSP_HEXAGON_CONTROL_WRAPPER_H_
#include "mace/dsp/hexagon/hexagon_nn.h"
#include "mace/dsp/hexagon_nn_ops.h"
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/serializer.h"
namespace mace {
// Thin C++ wrapper around the Hexagon NN C API. Owns one DSP graph
// instance (nn_id_) and drives its lifecycle:
//   Config/Init -> SetupGraph -> ExecuteGraph -> TeardownGraph -> Finalize.
class HexagonControlWrapper {
 public:
  HexagonControlWrapper() {};
  // Returns the Hexagon NN library version.
  int GetVersion();
  // One-time library configuration; call before Init().
  bool Config();
  // Creates the DSP graph instance and builds the op-name table.
  bool Init();
  // Counterpart of Init(); see the TODO in the .cc.
  bool Finalize();
  // Builds the DSP graph from an in-memory NetDef.
  // NOTE(review): takes NetDef by value -- a full copy per call.
  bool SetupGraph(NetDef net_def);
  // Loads a serialized NetDef from |model_file| and builds the graph.
  bool SetupGraph(const std::string &model_file);
  // Runs one inference. The output tensor is resized from the metadata
  // recorded in SetupGraph; the input must already be 4-D.
  bool ExecuteGraph(const Tensor &input_tensor, Tensor *output_tensor) {
    LOG(INFO) << "Execute graph: " << nn_id_;
    output_tensor->SetDtype(output_data_type_);
    output_tensor->Resize(output_shape_);
    // Filled by the DSP with the shape it actually produced.
    vector<uint32_t> output_shape(4);
    uint32_t output_bytes;
    int res = hexagon_nn_execute(nn_id_,
                                 input_tensor.shape()[0],
                                 input_tensor.shape()[1],
                                 input_tensor.shape()[2],
                                 input_tensor.shape()[3],
                                 reinterpret_cast<const unsigned char *>(
                                   input_tensor.raw_data()),
                                 input_tensor.raw_size(),
                                 &output_shape[0],
                                 &output_shape[1],
                                 &output_shape[2],
                                 &output_shape[3],
                                 reinterpret_cast<unsigned char *>(
                                   output_tensor->raw_mutable_data()),
                                 output_tensor->raw_size(),
                                 &output_bytes);
    // NOTE(review): compares vector<uint32_t> with vector<index_t>;
    // confirm MACE_ASSERT accepts (or compiles away) mixed element types.
    MACE_ASSERT(output_shape == output_shape_,
                "wrong output shape inferred");
    MACE_ASSERT(output_bytes == output_tensor->raw_size(),
                "wrong output bytes inferred.");
    // 0 means success, matching the rest of the hexagon_nn API.
    return res == 0;
  };
  bool TeardownGraph();
  void PrintLog();
  void PrintGraph();
  void GetPerfInfo();
  void SetDebugLevel(int level);

 private:
  // CAVEAT: Need offset as HVX library reserves some ids
  static constexpr int NODE_ID_OFFSET = 10000;

  // Maps a converter-assigned node id into the DSP's id space.
  uint32_t node_id(uint32_t nodeid) {
    return NODE_ID_OFFSET + nodeid;
  }

  int nn_id_;              // graph handle returned by hexagon_nn_init()
  OpMap op_map_;           // op type name -> DSP op id
  Serializer serializer_;  // deserializes TensorProto payloads

  // Input/output metadata captured in SetupGraph, used by ExecuteGraph.
  vector<index_t> input_shape_;
  vector<index_t> output_shape_;
  DataType input_data_type_;
  DataType output_data_type_;

  DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
};
}
#endif // MACE_DSP_HEXAGON_CONTROL_WRAPPER_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/dsp/hexagon_control_wrapper.h"
#include "mace/core/logging.h"
#include "gtest/gtest.h"
using namespace mace;
// End-to-end smoke test of the DSP wrapper: version query, init/config,
// graph setup from an on-device model file, one inference on a ramp
// input, log/perf dumps, then teardown.
// NOTE(review): requires "quantized_test_dsp.pb" in the working directory
// and a device with the Hexagon DSP runtime; it makes no assertions, so
// it only detects crashes. (Suite name typo "Controler" left as-is.)
TEST(HexagonControlerWrapper, GetVersion) {
  testing::internal::LogToStderr();
  HexagonControlWrapper wrapper;
  VLOG(0) << "version: " << wrapper.GetVersion();
  wrapper.Init();
  wrapper.SetDebugLevel(3);
  wrapper.Config();
  VLOG(0) << wrapper.SetupGraph("quantized_test_dsp.pb");
  wrapper.PrintGraph();

  Tensor input_tensor;
  Tensor output_tensor;
  input_tensor.Resize({1, 28, 28, 3});
  float *input_data = input_tensor.mutable_data<float>();
  // Fill the input with a ramp 0..N-1.
  for (int i = 0; i < input_tensor.size(); ++i) {
    input_data[i] = i;
  }

  VLOG(0) << wrapper.ExecuteGraph(input_tensor, &output_tensor);
  wrapper.PrintLog();
  wrapper.GetPerfInfo();

  const float *output_data = output_tensor.data<float>();
  VLOG(0) << output_tensor.size() << output_tensor.dtype();
  for (int i = 0; i < output_tensor.size(); ++i) {
    std::cout << output_data[i] << " ";
  }
  std::cout << std::endl;
  VLOG(0) << wrapper.TeardownGraph();
  wrapper.Finalize();
}
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_HEXAGON_NN_OPS_H_
#define MACE_HEXAGON_NN_OPS_H_
#include "mace/core/logging.h"
#include <unordered_map>
namespace mace {
// Sentinel returned by OpMap::GetOpId for unknown op type names.
#define OP_INVALID -1

// Enumerates every DSP op, in the exact order of mace/dsp/ops.h (each
// DEF_OP(NAME) expands to OP_##NAME). Order matters -- see the warning
// banner in ops.h about interface compatibility.
typedef enum op_type_enum {
#define DEF_OP(NAME, ...) OP_##NAME,
#include "mace/dsp/ops.h"
  NN_OPS_MAX
#undef DEF_OP
} op_type;
// Maps op type names (as produced by the model converter) to DSP op ids.
class OpMap {
 public:
  // Populates the name -> id table from the DEF_OP list in ops.h.
  void Init() {
#define DEF_OP(NAME) \
    op_map_[#NAME] = OP_##NAME;

#include "mace/dsp/ops.h"

#undef DEF_OP
  }

  // Returns the DSP op id for |op_type|, or OP_INVALID (with an error
  // log) when the DSP does not support the op.
  // Takes const& instead of by-value (source-compatible for callers) and
  // does a single lookup instead of find() + operator[]; also fixes the
  // "unsupoorted" typo in the log message.
  int GetOpId(const std::string &op_type) {
    auto iter = op_map_.find(op_type);
    if (iter != op_map_.end()) {
      return iter->second;
    } else {
      LOG(ERROR) << "DSP unsupported op type: " << op_type;
      return OP_INVALID;
    }
  }

 private:
  std::unordered_map<std::string, int> op_map_;
};
} // namespace mace
#endif // MACE_HEXAGON_NN_OPS_H_
/*
* You probably want to
*
* ## ##### #####
* # # # # # #
* # # # # # #
* ###### # # # #
* # # # # # #
* # # ##### #####
*
*
* # # #### ##### ###### ####
* ## # # # # # # #
* # # # # # # # ##### ####
* # # # # # # # # #
* # ## # # # # # # #
* # # #### ##### ###### ####
*
*
* ## #####
* # # #
* # # #
* ###### #
* # # #
* # # #
*
*
* ##### # # ######
* # # # #
* # ###### #####
* # # # #
* # # # #
* # # # ######
*
*
* ###### # # #####
* # ## # # #
* ##### # # # # #
* # # # # # #
* # # ## # #
* ###### # # #####
*
* otherwise the interface becomes incompatible.
*/
DEF_OP(INPUT)
DEF_OP(OUTPUT)
DEF_OP(Nop)
DEF_OP(Const)
DEF_OP(Check)
DEF_OP(Close_f)
DEF_OP(Close_quint8)
DEF_OP(Close_q_quint8)
DEF_OP(Close_int32)
DEF_OP(Close_qint32)
DEF_OP(PPrint_8)
DEF_OP(PPrint_32)
DEF_OP(PPrint_f)
DEF_OP(PreFree)
DEF_OP(Flatten)
#ifndef DEF_OP_WREF
#define DEF_OP_WREF(NAME) DEF_OP(NAME) DEF_OP(NAME##_ref)
#define __SELF_DEF_OP_WREF
#endif
DEF_OP_WREF(QuantizedConv2d_8x8to32)
DEF_OP_WREF(QuantizedMatMul_8x8to32)
DEF_OP_WREF(QuantizeDownAndShrinkRange_32to8)
DEF_OP_WREF(QuantizedRelu_8)
DEF_OP_WREF(QuantizedReluX_8)
DEF_OP_WREF(QuantizedMaxPool_8)
DEF_OP_WREF(QuantizedAvgPool_8)
DEF_OP_WREF(QuantizedConcat_8)
DEF_OP_WREF(QuantizedBiasAdd_8p8to32)
DEF_OP_WREF(Min_f)
DEF_OP_WREF(Max_f)
DEF_OP_WREF(Quantize)
DEF_OP_WREF(Dequantize)
DEF_OP_WREF(Supernode_8x8p8to8)
DEF_OP(QuantizedFlatten)
DEF_OP(Softmax_f)
DEF_OP(Conv2d_f)
DEF_OP(MatMul_f)
DEF_OP(Relu_f)
DEF_OP(ReluX_f)
DEF_OP(AvgPool_f)
DEF_OP(MaxPool_f)
DEF_OP(Concat_f)
DEF_OP(BiasAdd_f)
DEF_OP(LRN_f)
DEF_OP(Variable)
DEF_OP(Assign)
DEF_OP(Reshape)
DEF_OP(QuantizedReshape)
DEF_OP(Tanh_f)
DEF_OP(Sigmoid_f)
DEF_OP(Slice_8)
DEF_OP(Slice_f)
DEF_OP(QuantizedSlice_8)
DEF_OP(Add_f)
DEF_OP(Mul_f)
DEF_OP(Minimum_f)
DEF_OP(Maximum_f)
DEF_OP_WREF(Requantize_32to8)
DEF_OP_WREF(RequantizationRange_32)
DEF_OP(Neg_f)
DEF_OP(Sub_f)
DEF_OP(AddN_f)
DEF_OP(Range_int32)
DEF_OP(Rank_int32)
DEF_OP(Transpose_int32)
DEF_OP(Transpose_f)
DEF_OP(InstanceNorm_f)
DEF_OP_WREF(QuantizedInstanceNorm_8)
DEF_OP(Sub_int32)
DEF_OP(Add_int32)
DEF_OP(Split_f)
DEF_OP(Dequantize_qint32_f)
DEF_OP(PRelu_f)
DEF_OP_WREF(QuantizedPRelu_8)
DEF_OP(Sum_f)
DEF_OP(Prod_f)
DEF_OP(Mul_int32)
DEF_OP(LogicalAnd_int32)
DEF_OP(LogicalOr_int32)
DEF_OP(LogicalXor_int32)
DEF_OP(Shape_int32)
DEF_OP(Pack_int32)
DEF_OP(MirrorPad_f)
DEF_OP(ResizeNearestNeighbor_f)
DEF_OP(StridedSlice_int32)
DEF_OP(StridedSlice_f)
DEF_OP(ExpandDims_int32)
DEF_OP(ExpandDims_f)
DEF_OP(LogSoftmax_f)
DEF_OP(Split_int32)
DEF_OP(QuantizedSplit_8)
DEF_OP(Deconv_f)
DEF_OP_WREF(QuantizedDeconv_8x8to32)
DEF_OP_WREF(QuantizedMul_8x8to32)
DEF_OP_WREF(QuantizedAdd_8p8to32)
DEF_OP_WREF(QuantizedSigmoid_8)
DEF_OP_WREF(QuantizedTanh_8)
DEF_OP_WREF(QuantizedSoftmax_8)
DEF_OP_WREF(QuantizedLRN_8)
DEF_OP_WREF(Quantizedpad2d_frame_8p)
DEF_OP_WREF(QuantizedSub_8p8to32)
DEF_OP_WREF(QuantizedMaximum_8)
DEF_OP_WREF(QuantizedMinimum_8)
DEF_OP(Pad_f)
DEF_OP(SpaceToBatchND_f)
DEF_OP(BatchToSpaceND_f)
DEF_OP(QuantizedPad_8)
DEF_OP(ResizeBilinear_f)
DEF_OP(ConcatV2_f)
DEF_OP(ConcatV2_int32)
DEF_OP(Prod_int32)
DEF_OP(Slice_int32)
DEF_OP(QuantizedAdd_8p8to8)
DEF_OP_WREF(AutoQuantize)
DEF_OP_WREF(QuantizedDepthwiseConv2d_8x8to32)
DEF_OP(DepthwiseConv2d_f)
#ifdef __SELF_DEF_OP_WREF
#undef __SELF_DEF_OP_WREF
#undef DEF_OP_WREF
#endif
......@@ -45,6 +45,8 @@ message TensorProto {
repeated int64 int64_data = 10 [packed = true];
// Optionally, a name for the tensor.
optional string name = 7;
optional uint32 node_id = 100;
}
message Argument {
......@@ -57,12 +59,41 @@ message Argument {
repeated bytes strings = 7;
}
// for hexagon mace-nnlib
// One input edge of a DSP node: the producing node and its output port.
message NodeInput {
  optional int32 node_id = 1;
  optional int32 output_port = 2;
}

message OperatorDef {
  repeated string input = 1;
  repeated string output = 2;
  optional string name = 3;
  optional string type = 4;
  repeated Argument arg = 5;

  // for hexagon mace-nnlib
  // Fields 100+ carry the DSP graph wiring produced by the converter.
  optional uint32 node_id = 100;
  optional uint32 op_id = 101;
  optional uint32 padding = 102;
  repeated NodeInput node_input = 103;
  repeated int32 out_max_byte_size = 104; // only support 32-bit len
}

// for hexagon mace-nnlib
// Shape/type metadata for a graph input, consumed by SetupGraph.
message InputInfo {
  optional string name = 1;
  optional int32 node_id = 2;
  repeated int32 dims = 3;
  optional int32 max_byte_size = 4; // only support 32-bit len
  optional DataType data_type = 5 [default = DT_FLOAT];
}

// Shape/type metadata for a graph output, consumed by SetupGraph.
message OutputInfo {
  optional string name = 1;
  optional int32 node_id = 2;
  repeated int32 dims = 3;
  optional int32 max_byte_size = 4; // only support 32-bit len
  optional DataType data_type = 5 [default = DT_FLOAT];
}
message NetDef {
......@@ -71,4 +102,8 @@ message NetDef {
optional string version = 3;
repeated Argument arg = 4;
repeated TensorProto tensors = 5;
// for hexagon mace-nnlib
repeated InputInfo input_info = 100;
repeated OutputInfo output_info = 101;
}
py_library(
name = "tf_converter_lib",
srcs = ["tf_converter_lib.py"],
srcs = ["tf_converter_lib.py", "tf_dsp_converter_lib.py"],
srcs_version = "PY2AND3",
deps = [
"//mace/proto:mace_py",
......
class DspOps(object):
    """Maps TensorFlow op type names to their Hexagon nn-lib op names.

    The values must match the DEF_OP names in mace/dsp/ops.h.
    """

    def __init__(self):
        self.dsp_ops = {
            # Fixed: was 'INPUT"' with a stray trailing quote, which could
            # never match the 'INPUT' op name used by the DSP op table.
            'INPUT': 'INPUT',
            'OUTPUT': 'OUTPUT',
            'NoOp': 'Nop',
            'FLATTEN': 'Flatten',
            'Identity': 'Nop',
            'Placeholder': 'INPUT',
            'Const': 'Const',
            'QuantizedConv2D': 'QuantizedConv2d_8x8to32',
            'QuantizedMatMul': 'QuantizedMatMul_8x8to32',
            'QuantizeDownAndShrinkRange': 'QuantizeDownAndShrinkRange_32to8',
            'QuantizedRelu': 'QuantizedRelu_8',
            'QuantizedReluX': 'QuantizedReluX_8',
            'QuantizedMaxPool': 'QuantizedMaxPool_8',
            'QuantizedAvgPool': 'QuantizedAvgPool_8',
            'QuantizedConcat': 'QuantizedConcat_8',
            'QuantizedBiasAdd': 'QuantizedBiasAdd_8p8to32',
            'Min': 'Min_f',
            'Max': 'Max_f',
            'QuantizeV2': 'Quantize',
            'Dequantize': 'Dequantize',
            'Softmax': 'Softmax_f',
            'Reshape': 'Reshape',
            'QuantizedReshape': 'QuantizedReshape',
            'Sigmoid': 'Sigmoid_f',
            'Slice': 'Slice_f',
            'Add': 'Add_f',
            'Mul': 'Mul_f',
            'Requantize': 'Requantize_32to8',
            'RequantizationRange': 'RequantizationRange_32',
            'Sub': 'Sub_f',
            'Pack': 'Pack_int32',
            'StridedSlice': 'StridedSlice_f',
            'ExpandDims': 'ExpandDims_f',
            'QuantizedMul': 'QuantizedMul_8x8to32',
            'QuantizedAdd': 'QuantizedAdd_8p8to32',
            'Pad': 'Pad_f',
            'SpaceToBatchND': 'SpaceToBatchND_f',
            'BatchToSpaceND': 'BatchToSpaceND_f',
            'ResizeBilinear': 'ResizeBilinear_f',
            'ConcatV2': 'ConcatV2_f',
            'Conv2DBackpropInput': 'Deconv_f',
            'Tanh': 'Tanh_f',
            'Split': 'Split_f',
            'Transpose': 'Transpose_f',
            'Concat': 'Concat_f',
        }

    def has_op(self, tf_op):
        """Returns True if the TF op type has a DSP mapping."""
        return tf_op in self.dsp_ops

    def map_nn_op(self, tf_op):
        """Returns the nn-lib op name for tf_op; raises if unmapped."""
        if tf_op not in self.dsp_ops:
            raise Exception('Could not map nn op')
        return self.dsp_ops[tf_op]
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mace/proto/mace.proto
import sys
# _b(): identity on Python 2, latin-1 encode on Python 3 -- lets the
# serialized_pb bytes literal below work on both interpreter versions.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


# File descriptor for mace/proto/mace.proto (proto2 syntax). serialized_pb
# is the wire-encoded FileDescriptorProto for the whole file.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='mace/proto/mace.proto',
  package='mace',
  syntax='proto2',
  serialized_pb=_b('\n\x15mace/proto/mace.proto\x12\x04mace\"\xdf\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12+\n\tdata_type\x18\x02 \x01(\x0e\x32\x0e.mace.DataType:\x08\x44T_FLOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x17\n\x0b\x64ouble_data\x18\t \x03(\x01\x42\x02\x10\x01\x12\x16\n\nint64_data\x18\n \x03(\x03\x42\x02\x10\x01\x12\x0c\n\x04name\x18\x07 \x01(\t\"h\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\x0c\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x03\x12\x0f\n\x07strings\x18\x07 \x03(\x0c\"e\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x1b\n\x03\x61rg\x18\x05 \x03(\x0b\x32\x0e.mace.Argument\"\x87\x01\n\x06NetDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1d\n\x02op\x18\x02 \x03(\x0b\x32\x11.mace.OperatorDef\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1b\n\x03\x61rg\x18\x04 \x03(\x0b\x32\x0e.mace.Argument\x12\"\n\x07tensors\x18\x05 \x03(\x0b\x32\x11.mace.TensorProto*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04NEON\x10\x01\x12\n\n\x06OPENCL\x10\x02*\xa7\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x0c\n\x08\x44T_INT64\x10\x08\x12\r\n\tDT_UINT16\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n')
)
# Enum descriptor for mace.DeviceType (CPU=0, NEON=1, OPENCL=2), plus the
# importable wrapper `DeviceType`.
_DEVICETYPE = _descriptor.EnumDescriptor(
  name='DeviceType',
  full_name='mace.DeviceType',
  filename=None,
  file=DESCRIPTOR,
  values=[
    _descriptor.EnumValueDescriptor(
      name='CPU', index=0, number=0,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='NEON', index=1, number=1,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='OPENCL', index=2, number=2,
      options=None,
      type=None),
  ],
  containing_type=None,
  options=None,
  serialized_start=604,
  serialized_end=647,
)
_sym_db.RegisterEnumDescriptor(_DEVICETYPE)

DeviceType = enum_type_wrapper.EnumTypeWrapper(_DEVICETYPE)
# Enum descriptor for mace.DataType (DT_INVALID=0 .. DT_BOOL=10), plus the
# importable wrapper `DataType`.
_DATATYPE = _descriptor.EnumDescriptor(
  name='DataType',
  full_name='mace.DataType',
  filename=None,
  file=DESCRIPTOR,
  values=[
    _descriptor.EnumValueDescriptor(
      name='DT_INVALID', index=0, number=0,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_FLOAT', index=1, number=1,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_DOUBLE', index=2, number=2,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_INT32', index=3, number=3,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_UINT8', index=4, number=4,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_INT16', index=5, number=5,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_INT8', index=6, number=6,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_STRING', index=7, number=7,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_INT64', index=8, number=8,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_UINT16', index=9, number=9,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='DT_BOOL', index=10, number=10,
      options=None,
      type=None),
  ],
  containing_type=None,
  options=None,
  serialized_start=650,
  serialized_end=817,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)

DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
# Module-level aliases for the enum values, generated by protoc so callers
# can reference them directly (e.g. `mace_pb2.DT_FLOAT`).
CPU = 0
NEON = 1
OPENCL = 2
DT_INVALID = 0
DT_FLOAT = 1
DT_DOUBLE = 2
DT_INT32 = 3
DT_UINT8 = 4
DT_INT16 = 5
DT_INT8 = 6
DT_STRING = 7
DT_INT64 = 8
DT_UINT16 = 9
DT_BOOL = 10
# Message descriptor for mace.TensorProto: dims, data_type (default
# DT_FLOAT), and packed repeated data fields per element type.
_TENSORPROTO = _descriptor.Descriptor(
  name='TensorProto',
  full_name='mace.TensorProto',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='dims', full_name='mace.TensorProto.dims', index=0,
      number=1, type=3, cpp_type=2, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='data_type', full_name='mace.TensorProto.data_type', index=1,
      number=2, type=14, cpp_type=8, label=1,
      has_default_value=True, default_value=1,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='float_data', full_name='mace.TensorProto.float_data', index=2,
      number=3, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
    _descriptor.FieldDescriptor(
      name='int32_data', full_name='mace.TensorProto.int32_data', index=3,
      number=4, type=5, cpp_type=1, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
    _descriptor.FieldDescriptor(
      name='byte_data', full_name='mace.TensorProto.byte_data', index=4,
      number=5, type=12, cpp_type=9, label=1,
      has_default_value=False, default_value=_b(""),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='string_data', full_name='mace.TensorProto.string_data', index=5,
      number=6, type=12, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='double_data', full_name='mace.TensorProto.double_data', index=6,
      number=9, type=1, cpp_type=5, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
    _descriptor.FieldDescriptor(
      name='int64_data', full_name='mace.TensorProto.int64_data', index=7,
      number=10, type=3, cpp_type=2, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
    _descriptor.FieldDescriptor(
      name='name', full_name='mace.TensorProto.name', index=8,
      number=7, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=32,
  serialized_end=255,
)
# Message descriptor for mace.Argument: a named scalar (f/i/s) or repeated
# (floats/ints/strings) attribute value.
_ARGUMENT = _descriptor.Descriptor(
  name='Argument',
  full_name='mace.Argument',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='mace.Argument.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='f', full_name='mace.Argument.f', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='i', full_name='mace.Argument.i', index=2,
      number=3, type=3, cpp_type=2, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='s', full_name='mace.Argument.s', index=3,
      number=4, type=12, cpp_type=9, label=1,
      has_default_value=False, default_value=_b(""),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='floats', full_name='mace.Argument.floats', index=4,
      number=5, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ints', full_name='mace.Argument.ints', index=5,
      number=6, type=3, cpp_type=2, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='strings', full_name='mace.Argument.strings', index=6,
      number=7, type=12, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=257,
  serialized_end=361,
)
# Message descriptor for mace.OperatorDef: input/output tensor names, op
# name/type, and repeated Argument attributes.
_OPERATORDEF = _descriptor.Descriptor(
  name='OperatorDef',
  full_name='mace.OperatorDef',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='input', full_name='mace.OperatorDef.input', index=0,
      number=1, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='output', full_name='mace.OperatorDef.output', index=1,
      number=2, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='name', full_name='mace.OperatorDef.name', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='type', full_name='mace.OperatorDef.type', index=3,
      number=4, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='arg', full_name='mace.OperatorDef.arg', index=4,
      number=5, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=363,
  serialized_end=464,
)
# Message descriptor for mace.NetDef: the whole network -- name, version,
# repeated OperatorDef ops, Arguments, and constant TensorProtos.
_NETDEF = _descriptor.Descriptor(
  name='NetDef',
  full_name='mace.NetDef',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='mace.NetDef.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='op', full_name='mace.NetDef.op', index=1,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='version', full_name='mace.NetDef.version', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='arg', full_name='mace.NetDef.arg', index=3,
      number=4, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='tensors', full_name='mace.NetDef.tensors', index=4,
      number=5, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=467,
  serialized_end=602,
)
# Cross-link nested field types (protoc emits these after all descriptors
# exist to break circular references), then register everything.
_TENSORPROTO.fields_by_name['data_type'].enum_type = _DATATYPE
_OPERATORDEF.fields_by_name['arg'].message_type = _ARGUMENT
_NETDEF.fields_by_name['op'].message_type = _OPERATORDEF
_NETDEF.fields_by_name['arg'].message_type = _ARGUMENT
_NETDEF.fields_by_name['tensors'].message_type = _TENSORPROTO
DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO
DESCRIPTOR.message_types_by_name['Argument'] = _ARGUMENT
DESCRIPTOR.message_types_by_name['OperatorDef'] = _OPERATORDEF
DESCRIPTOR.message_types_by_name['NetDef'] = _NETDEF
DESCRIPTOR.enum_types_by_name['DeviceType'] = _DEVICETYPE
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Concrete message classes, built from the descriptors via the reflection
# metaclass.
TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict(
  DESCRIPTOR = _TENSORPROTO,
  __module__ = 'mace.proto.mace_pb2'
  # @@protoc_insertion_point(class_scope:mace.TensorProto)
  ))
_sym_db.RegisterMessage(TensorProto)

Argument = _reflection.GeneratedProtocolMessageType('Argument', (_message.Message,), dict(
  DESCRIPTOR = _ARGUMENT,
  __module__ = 'mace.proto.mace_pb2'
  # @@protoc_insertion_point(class_scope:mace.Argument)
  ))
_sym_db.RegisterMessage(Argument)

OperatorDef = _reflection.GeneratedProtocolMessageType('OperatorDef', (_message.Message,), dict(
  DESCRIPTOR = _OPERATORDEF,
  __module__ = 'mace.proto.mace_pb2'
  # @@protoc_insertion_point(class_scope:mace.OperatorDef)
  ))
_sym_db.RegisterMessage(OperatorDef)

NetDef = _reflection.GeneratedProtocolMessageType('NetDef', (_message.Message,), dict(
  DESCRIPTOR = _NETDEF,
  __module__ = 'mace.proto.mace_pb2'
  # @@protoc_insertion_point(class_scope:mace.NetDef)
  ))
_sym_db.RegisterMessage(NetDef)

# Re-apply the [packed = true] option to the repeated numeric data fields.
_TENSORPROTO.fields_by_name['float_data'].has_options = True
_TENSORPROTO.fields_by_name['float_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['int32_data'].has_options = True
_TENSORPROTO.fields_by_name['int32_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['double_data'].has_options = True
_TENSORPROTO.fields_by_name['double_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['int64_data'].has_options = True
_TENSORPROTO.fields_by_name['int64_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
# @@protoc_insertion_point(module_scope)
......@@ -3,9 +3,11 @@ import sys
import tensorflow as tf
from tensorflow import gfile
from mace.python.tools import tf_converter_lib
from mace.python.tools import tf_dsp_converter_lib
FLAGS = None
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS = None
def main(unused_args):
if not gfile.Exists(FLAGS.input):
......@@ -17,13 +19,17 @@ def main(unused_args):
data = f.read()
input_graph_def.ParseFromString(data)
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def)
if FLAGS.runtime == 'dsp':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_dim, FLAGS.output_node)
else:
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
output_graph_def.ClearField('tensors')
# output_graph_def.ClearField('tensors')
f.write(str(output_graph_def))
......@@ -41,6 +47,21 @@ def parse_args():
type=str,
default="",
help="File to save the output graph to.")
parser.add_argument(
"--runtime",
type=str,
default="cpu",
help="Runtime: cpu/gpu/dsp.")
parser.add_argument(
"--input_dim",
type=str,
default="input_node,1,28,28,3",
help="e.g., input_node,1,28,28,3")
parser.add_argument(
"--output_node",
type=str,
default="softmax",
help="e.g., softmax")
return parser.parse_known_args()
......
from mace.proto import mace_pb2
# import mace_pb2
import tensorflow as tf
import numpy as np
from operator import mul
from dsp_ops import DspOps
# Padding-mode names mapped to the integer codes written into
# OperatorDef.padding.  NOTE(review): the numeric values presumably mirror
# the DSP nn library's padding enum -- confirm against the nn_lib headers.
padding_mode = {
  'NA': 0,
  'SAME': 1,
  'VALID': 2,
  'MIRROR_REFLECT': 3,
  'MIRROR_SYMMETRIC': 4,
  'SAME_CAFFE': 5
}

# Module-level converter state: node_count hands out unique, monotonically
# increasing node ids; node_ids maps node name -> assigned id.
node_count = 0
node_ids = {}
def max_elem_size(tensor):
  """Return the maximum byte size of `tensor`.

  Computed as the product of its static dims times the dtype's byte size;
  a rank-0 (scalar) tensor is just the dtype size.

  BUG FIX: the previous implementation used the bare builtin `reduce`,
  which does not exist on Python 3 (it lives in functools there).  A plain
  accumulating loop works on both Python 2 and 3 and drops the
  `operator.mul` dependency.
  """
  num_bytes = tensor.dtype.size
  for dim in tensor.shape.as_list():
    num_bytes *= dim
  return num_bytes
def find_dtype(tensor_dtype):
  """Translate a TF dtype into the matching mace_pb2 DataType constant."""
  float_types = (tf.float32,)
  uint8_types = (tf.uint8, tf.quint8)
  int32_types = (tf.int32, tf.qint32)
  if tensor_dtype in float_types:
    return mace_pb2.DT_FLOAT
  if tensor_dtype in uint8_types:
    return mace_pb2.DT_UINT8
  if tensor_dtype in int32_types:
    return mace_pb2.DT_INT32
  raise Exception('Unsupported data type: ', tensor_dtype)
def has_padding_and_strides(op):
  """True when the op's node_def carries both 'padding' and 'strides' attrs."""
  attrs = op.node_def.attr
  return ('padding' in attrs) and ('strides' in attrs)
def is_node_flatten_reshape(op):
  """A Reshape whose output collapses to rank 1 is treated as a Flatten."""
  if op.type != 'Reshape':
    return False
  return len(op.outputs[0].shape) == 1
def get_input_tensor(op, index):
  """Return op's index-th input, looking through any chain of Reshape ops."""
  tensor = op.inputs[index]
  # Walk producer Reshapes back to the real source tensor.
  while tensor.op.type == 'Reshape':
    tensor = tensor.op.inputs[0]
  return tensor
def add_shape_const_node(net_def, op, values, name):
  """Append an int32 const tensor named '<op.name>/<name>' holding `values`.

  Assigns the tensor the next global node id, registers it in the name->id
  table, and returns its output name ('<node name>:0').
  """
  global node_count
  node_name = op.name + '/' + name
  print ('Add const node: ', node_name)
  const_tensor = net_def.tensors.add()
  const_tensor.name = node_name + ':0'
  const_tensor.node_id = node_count
  node_count += 1
  register_node_id(node_name, const_tensor.node_id)
  const_tensor.data_type = mace_pb2.DT_INT32
  const_tensor.dims.extend(values)
  return const_tensor.name
def register_node_id(node_name, node_id):
  """Record the id assigned to `node_name` in the module-level table."""
  # Mutating the module-level dict needs no `global` declaration.
  node_ids[node_name] = node_id
def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
  """Consume the first op in `unresolved_ops` and append it to `net_def`.

  Const ops become TensorProto entries carrying their evaluated data
  (requires an active TF session); every other op becomes an OperatorDef
  whose type is mapped through `dsp_ops`.  Placeholders additionally get an
  input_info entry and the op named `output_node` gets an output_info
  entry.  Exactly one op is removed from the front of `unresolved_ops`
  per call.
  """
  global node_count
  ops_count = len(unresolved_ops)  # NOTE(review): computed but never used
  resolved_count = 1  # always resolve exactly one op per call
  first_op = unresolved_ops[0]
  print ('Op: ', first_op.name, first_op.type, first_op.outputs[0].shape)
  if first_op.type == 'Const':
    # Constant: evaluate the tensor and copy its shape/data into a
    # TensorProto with a freshly assigned node id.
    print ('Add const node: ', first_op.name)
    tf_tensor = first_op.outputs[0].eval()
    tensor = net_def.tensors.add()
    tensor.name = first_op.outputs[0].name
    tensor.node_id = node_count
    node_count += 1
    register_node_id(tensor.name.split(':')[0], tensor.node_id)
    tensor.data_type = find_dtype(first_op.outputs[0].dtype)
    shape = list(tf_tensor.shape)
    if len(shape) > 0:
      tensor.dims.extend(shape)
    if first_op.outputs[0].dtype == tf.float32:
      tensor.float_data.extend(tf_tensor.astype(float).flat)
    elif first_op.outputs[0].dtype == tf.int32 or \
      first_op.outputs[0].dtype == tf.int8 or \
      first_op.outputs[0].dtype == tf.int16 or \
      first_op.outputs[0].dtype == tf.quint8 or \
      first_op.outputs[0].dtype == tf.quint16:
      # All integer-like dtypes are stored in the int32_data field.
      tensor.int32_data.extend(tf_tensor.astype(int).flat)
  else:
    # Regular op: translate the type and assign a node id.
    op_def = net_def.op.add()
    op_def.name = first_op.name
    op_def.type = dsp_ops.map_nn_op(first_op.type)
    op_def.node_id = node_count
    node_count += 1
    register_node_id(op_def.name, op_def.node_id)
    op_def.padding = padding_mode['NA']
    if has_padding_and_strides(first_op):
      # Conv/pool-style op: padding comes from the attr, and ksize/strides
      # are materialized as extra const inputs.
      op_def.padding = padding_mode[first_op.get_attr('padding')]
      op_def.input.extend([t.name for t in first_op.inputs])
      if 'ksize' in first_op.node_def.attr:
        ksize = first_op.get_attr('ksize')
        ksize_tensor = add_shape_const_node(net_def, first_op, ksize, 'ksize')
        op_def.input.extend([ksize_tensor])
      strides = first_op.get_attr('strides')
      strides_tensor = add_shape_const_node(net_def, first_op, strides, 'strides')
      op_def.input.extend([strides_tensor])
      op_def.out_max_byte_size.extend([max_elem_size(out) for out in first_op.outputs])
    elif is_node_flatten_reshape(first_op):
      # Rank-1 Reshape is rewritten to the DSP Flatten op.
      op_def.type = 'Flatten'
      op_def.input.extend([t.name for t in first_op.inputs])
      op_def.out_max_byte_size.extend([max_elem_size(out) for out in first_op.outputs])
    elif dsp_ops.has_op(first_op.type):
      op_def.input.extend([t.name for t in first_op.inputs])
      op_def.out_max_byte_size.extend([max_elem_size(out) for out in first_op.outputs])
      if first_op.type == 'Placeholder':
        # Graph input: record shape/type so the runtime can feed it.
        input_info = net_def.input_info.add()
        input_info.name = op_def.name
        input_info.node_id = op_def.node_id
        input_info.dims.extend(first_op.outputs[0].shape.as_list())
        input_info.max_byte_size = max_elem_size(first_op.outputs[0])
        input_info.data_type = find_dtype(first_op.outputs[0].dtype)
      elif first_op.name == output_node:
        # Graph output: record shape/type so the runtime can fetch it.
        output_info = net_def.output_info.add()
        output_info.name = op_def.name
        output_info.node_id = op_def.node_id
        output_info.dims.extend(first_op.outputs[0].shape.as_list())
        output_info.max_byte_size = max_elem_size(first_op.outputs[0])
        output_info.data_type = find_dtype(first_op.outputs[0].dtype)
    else:
      raise Exception('Unsupported op: ', first_op)
    print ('Add op node: ', first_op.name)
    # Wire node-level edges: each input 'name:port' becomes (node_id, port).
    for t in op_def.input:
      node, port = t.split(':')
      node_id = node_ids[node]
      node_input = op_def.node_input.add()
      node_input.node_id = node_id
      node_input.output_port = int(port)
  for i in range(resolved_count):
    del unresolved_ops[0]
def add_output_node(net_def, output_node):
  """Append the graph-terminating OUTPUT op, fed by `output_node` port 0."""
  global node_count
  output_op = net_def.op.add()
  output_op.name = 'output'
  output_op.type = 'OUTPUT'
  output_op.node_id = node_count
  node_count += 1
  register_node_id(output_op.name, output_op.node_id)
  output_op.input.extend([output_node + ':0'])
  edge = output_op.node_input.add()
  edge.node_id = node_ids[output_node]
  edge.output_port = 0
def convert_to_mace_pb(input_graph_def, input_dim, output_node):
  """Convert a frozen TF GraphDef into a mace NetDef for the DSP runtime.

  input_dim:   ';'-separated specs of the form 'name,d0,d1,...' giving the
               static shape of each Placeholder.
  output_node: name of the node whose output terminates the graph.
  """
  # Parse the input-shape specification string.
  input_shape = {}
  for spec in input_dim.split(';'):
    fields = spec.split(',')
    input_shape[fields[0]] = [int(d) for d in fields[1:]]

  net_def = mace_pb2.NetDef()

  # Placeholders in a frozen graph may have unknown shapes; pin them to the
  # user-supplied dims before importing the graph.
  for node in input_graph_def.node:
    if node.op == 'Placeholder':
      node.attr['shape'].shape.unknown_rank = False
      for d in input_shape[node.name]:
        dim = node.attr['shape'].shape.dim.add()
        dim.size = d

  with tf.Session() as session:
    with session.graph.as_default() as graph:
      tf.import_graph_def(input_graph_def, name="")
      unresolved_ops = list(graph.get_operations())
      dsp_ops = DspOps()
      # convert_ops pops one op off the front per call until none remain.
      while len(unresolved_ops) > 0:
        convert_ops(unresolved_ops, net_def, output_node, dsp_ops)
      add_output_node(net_def, output_node)
  return net_def
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册