From f05649a463809084d0574a2427ca2b7b8b30162a Mon Sep 17 00:00:00 2001 From: liuqi Date: Fri, 15 Dec 2017 18:19:55 +0800 Subject: [PATCH] Replace pb with code. --- mace/core/BUILD | 1 - mace/core/allocator.h | 2 +- mace/core/mace.cc | 404 ++++++++++++++++++++++++++++++ mace/core/mace.h | 341 +++++++++++++++++++++++++ mace/core/net.cc | 2 +- mace/core/net.h | 2 +- mace/core/operator.h | 2 +- mace/core/proto_utils.cc | 240 +----------------- mace/core/proto_utils.h | 179 +------------ mace/core/serializer.cc | 36 +-- mace/core/serializer.h | 2 +- mace/core/tensor.h | 2 +- mace/core/types.h | 2 +- mace/core/workspace.cc | 2 +- mace/core/workspace.h | 2 +- mace/examples/mace_run.cc | 26 +- mace/kernels/batch_norm.h | 2 +- mace/kernels/bias_add.h | 2 +- mace/kernels/concat.h | 2 +- mace/kernels/depthwise_conv2d.h | 2 +- mace/kernels/space_to_batch.h | 2 +- mace/ops/BUILD | 1 - mace/ops/concat.h | 2 +- mace/ops/ops_test_util.h | 1 - mace/python/tools/model.template | 139 ++++++++++ mace/python/tools/tf_converter.py | 60 ++++- 26 files changed, 987 insertions(+), 471 deletions(-) create mode 100644 mace/core/mace.cc create mode 100644 mace/core/mace.h create mode 100644 mace/python/tools/model.template diff --git a/mace/core/BUILD b/mace/core/BUILD index 6f1af8a5..9f5ca2cb 100644 --- a/mace/core/BUILD +++ b/mace/core/BUILD @@ -62,7 +62,6 @@ cc_library( ]), deps = [ ":logging", - "//mace/proto:cc_proto", "//mace/proto:stats_proto", "//mace/utils", ":opencl_runtime", diff --git a/mace/core/allocator.h b/mace/core/allocator.h index bc803879..0f30d4ad 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -9,7 +9,7 @@ #include #include "mace/core/common.h" #include "mace/core/registry.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/types.h" namespace mace { diff --git a/mace/core/mace.cc b/mace/core/mace.cc new file mode 100644 index 00000000..abd65717 --- /dev/null +++ b/mace/core/mace.cc @@ -0,0 +1,404 @@ +// +// Copyright (c) 
2017 XiaoMi All rights reserved. +// + +#include "mace/core/mace.h" +#include "mace/core/logging.h" + +namespace mace { + +TensorProto::TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const DataType data_type, + uint32_t node_id) : + name_(name), + data_(data), + dims_(dims), + data_type_(data_type), + node_id_(node_id) {} + +TensorProto::TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const int data_type, + uint32_t node_id) : + name_(name), + data_(data), + dims_(dims), + data_type_(static_cast(data_type)), + node_id_(node_id) {} + +const std::string &TensorProto::name() const { + return name_; +} +unsigned char *TensorProto::data() const { + return data_; +} +const int TensorProto::data_size() const { + return data_size_; +} +const std::vector &TensorProto::dims() const { + return dims_; +} +DataType TensorProto::data_type() const { + return data_type_; +} +uint32_t TensorProto::node_id() const { + return node_id_; +} + +Argument::Argument() : has_bits_(0) {} + +void Argument::CopyFrom(const Argument &from) { + this->name_ = from.name(); + this->f_ = from.f(); + this->i_ = from.i(); + this->s_ = from.s(); + auto floats = from.floats(); + this->floats_.resize(floats.size()); + std::copy(floats.begin(), floats.end(), this->floats_.begin()); + auto ints = from.ints(); + this->ints_.resize(ints.size()); + std::copy(ints.begin(), ints.end(), this->ints_.begin()); + auto strings = from.floats(); + this->strings_.resize(strings.size()); + std::copy(floats.begin(), floats.end(), this->floats_.begin()); + + this->has_bits_ = from.has_bits_; +} +const std::string &Argument::name() const { + return name_; +} +void Argument::set_name(const std::string &value) { + name_ = value; +} +bool Argument::has_f() const { + return (has_bits_ & 0x00000001u) != 0; +} +void Argument::set_has_f() { + has_bits_ |= 0x00000001u; +} +float Argument::f() const { + return f_; +} +void Argument::set_f(float value) 
{ + set_has_f(); + f_ = value; +} +bool Argument::has_i() const { + return (has_bits_ & 0x00000002u) != 0; +} +void Argument::set_has_i() { + has_bits_ |= 0x00000002u; +} +int64_t Argument::i() const { + return i_; +} +void Argument::set_i(int64_t value) { + set_has_i(); + i_ = value; +} +bool Argument::has_s() const { + return (has_bits_ & 0x00000004u) != 0; +} +void Argument::set_has_s() { + has_bits_ |= 0x00000004u; +} +std::string Argument::s() const { + return s_; +} +void Argument::set_s(const std::string &value) { + set_has_s(); + s_ = value; +} +const std::vector &Argument::floats() const { + return floats_; +} +void Argument::add_floats(float value) { + floats_.push_back(value); +} +void Argument::set_floats(const std::vector &value) { + floats_.reserve(value.size()); + std::copy(value.begin(), value.end(), floats_.begin()); +} +const std::vector &Argument::ints() const { + return ints_; +} +void Argument::add_ints(int64_t value) { + ints_.push_back(value); +} +void Argument::set_ints(const std::vector &value) { + ints_.reserve(value.size()); + std::copy(value.begin(), value.end(), ints_.begin()); +} +const std::vector &Argument::strings() const { + return strings_; +} +void Argument::add_strings(const ::std::string &value) { + strings_.push_back(value); +} +void Argument::set_strings(const std::vector &value) { + strings_.reserve(value.size()); + std::copy(value.begin(), value.end(), strings_.begin()); +} + +void OperatorDef::CopyFrom(const OperatorDef &from) { + name_ = from.name(); + type_ = from.type(); + + auto from_input = from.input(); + input_.resize(from_input.size()); + std::copy(from_input.begin(), from_input.end(), input_.begin()); + auto from_output = from.output(); + output_.resize(from_output.size()); + std::copy(from_output.begin(), from_output.end(), output_.begin()); + auto from_arg = from.arg(); + arg_.resize(from_arg.size()); + for (int i = 0; i < from_arg.size(); ++i) { + arg_[i].CopyFrom(from_arg[i]); + } + auto from_output_shape = 
from.output_shape(); + output_shape_.resize(from_output_shape.size()); + for (int i = 0; i < from_output_shape.size(); ++i) { + output_shape_[i].CopyFrom(from_output_shape[i]); + } + auto from_data_type = from.output_type(); + output_type_.resize(from_data_type.size()); + std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin()); + + mem_id_ = from.mem_id(); + + // nnlib + node_id_ = from.node_id(); + op_id_ = from.op_id(); + padding_ = from.padding(); + auto from_node_input = from.node_input(); + node_input_.resize(from_node_input.size()); + for (int i = 0; i < from_node_input.size(); ++i) { + node_input_[i].CopyFrom(from_node_input[i]); + } + auto from_out_max_byte_size = from.out_max_byte_size(); + out_max_byte_size_.resize(from_out_max_byte_size.size()); + std::copy(from_out_max_byte_size.begin(), from_out_max_byte_size.end(), out_max_byte_size_.begin()); + + has_bits_ = from.has_bits_; + +} + +const std::string &OperatorDef::name() const { + return name_; +} +void OperatorDef::set_name(const std::string &name_) { + set_has_name(); + OperatorDef::name_ = name_; +} +bool OperatorDef::has_name() const { + return (has_bits_ & 0x00000001u) != 0; +} +void OperatorDef::set_has_name() { + has_bits_ |= 0x00000001u; +} +const std::string &OperatorDef::type() const { + return type_; +} +void OperatorDef::set_type(const std::string &type_) { + set_has_type(); + OperatorDef::type_ = type_; +} +bool OperatorDef::has_type() const { + return (has_bits_ & 0x00000002u) != 0; +} +void OperatorDef::set_has_type() { + has_bits_ |= 0x00000002u; +} +int OperatorDef::mem_id() const { + return mem_id_; +} +void OperatorDef::set_mem_id(const int mem_id) { + set_has_mem_id(); + mem_id_ = mem_id; +} +bool OperatorDef::has_mem_id() const { + return (has_bits_ & 0x00000004u) != 0; +} +void OperatorDef::set_has_mem_id() { + has_bits_ |= 0x00000004u; +} +uint32_t OperatorDef::node_id() const { + return node_id_; +} +uint32_t OperatorDef::op_id() const { + return op_id_; 
+} +uint32_t OperatorDef::padding() const { + return padding_; +} +const std::vector &OperatorDef::node_input() const { + return node_input_; +} +const std::vector &OperatorDef::out_max_byte_size() const { + return out_max_byte_size_; +} +const std::vector &OperatorDef::input() const { + return input_; +} +const std::string &OperatorDef::input(int index) const { + MACE_CHECK(0 <= index && index <= input_.size()); + return input_[index]; +} +std::string *OperatorDef::add_input() { + input_.push_back(""); + return &input_.back(); +} +void OperatorDef::add_input(const ::std::string &value) { + input_.push_back(value); +} +void OperatorDef::add_input(::std::string &&value) { + input_.push_back(value); +} +void OperatorDef::set_input(const std::vector &value) { + input_.reserve(value.size()); + std::copy(value.begin(), value.end(), input_.begin()); +} +const std::vector &OperatorDef::output() const { + return output_; +} +const std::string &OperatorDef::output(int index) const { + MACE_CHECK(0 <= index && index <= output_.size()); + return output_[index]; +} +std::string *OperatorDef::add_output() { + output_.push_back(""); + return &output_.back(); +} +void OperatorDef::add_output(const ::std::string &value) { + output_.push_back(value); +} +void OperatorDef::add_output(::std::string &&value) { + output_.push_back(value); +} +void OperatorDef::set_output(const std::vector &value) { + output_.reserve(value.size()); + std::copy(value.begin(), value.end(), output_.begin()); +} +const std::vector &OperatorDef::arg() const { + return arg_; +} +Argument *OperatorDef::add_arg() { + arg_.emplace_back(Argument()); + return &arg_.back(); +} +const std::vector &OperatorDef::output_shape() const { + return output_shape_; +} +void OperatorDef::set_output_shape(const std::vector &value) { + output_shape_.reserve(value.size()); + for (int i = 0; i < value.size(); ++i) { + output_shape_[i].CopyFrom(value[i]); + } +} +const std::vector &OperatorDef::output_type() const { + return 
output_type_; +} +void OperatorDef::set_output_type(const std::vector &value) { + output_type_.resize(value.size()); + std::copy(value.begin(), value.end(), output_type_.begin()); +} + +MemoryBlock::MemoryBlock(int mem_id, uint32_t x, uint32_t y) : + mem_id_(mem_id), x_(x), y_(y) {} + +int MemoryBlock::mem_id() const { + return mem_id_; +} +uint32_t MemoryBlock::x() const { + return x_; +} +uint32_t MemoryBlock::y() const { + return y_; +} + +NetDef::NetDef() : has_bits_(0) {} + +const std::string &NetDef::name() const { + return name_; +} +void NetDef::set_name(const std::string &value) { + set_has_name(); + name_ = value; +} +bool NetDef::has_name() const { + return (has_bits_ & 0x00000001u) != 0; +} +void NetDef::set_has_name() { + has_bits_ |= 0x00000001u; +} +const std::string &NetDef::version() const { + return version_; +} +void NetDef::set_version(const std::string &value) { + set_has_version(); + version_ = value; +} +bool NetDef::has_version() const { + return (has_bits_ & 0x00000002u) != 0; +} +void NetDef::set_has_version() { + has_bits_ |= 0x00000002u; +} +const std::vector &NetDef::op() const { + return op_; +} +OperatorDef *NetDef::add_op() { + op_.emplace_back(OperatorDef()); + return &op_.back(); +} +std::vector &NetDef::mutable_op() { + return op_; +} +const std::vector &NetDef::arg() const { + return arg_; +} +Argument *NetDef::add_arg() { + arg_.emplace_back(Argument()); + return &arg_.back(); +} +std::vector &NetDef::mutable_arg() { + return arg_; +} +const std::vector &NetDef::tensors() const { + return tensors_; +} +std::vector &NetDef::mutable_tensors() { + return tensors_; +} +const MemoryArena &NetDef::mem_arena() const { + return mem_arena_; +} +MemoryArena &NetDef::mutable_mem_arena() { + set_has_mem_arena(); + return mem_arena_; +} +bool NetDef::has_mem_arena() const { + return (has_bits_ & 0x00000004u) != 0; +} +void NetDef::set_has_mem_arena() { + has_bits_ |= 0x00000004u; +} +const std::vector &NetDef::input_info() const { + return 
input_info_; +} +const std::vector &NetDef::output_info() const { + return output_info_; +} + +int NetDef::op_size() const { + return op_.size(); +} + +const OperatorDef &NetDef::op(const int idx) const { + MACE_CHECK(0 <= idx && idx < op_size()); + return op_[idx]; +} +} // namespace mace diff --git a/mace/core/mace.h b/mace/core/mace.h new file mode 100644 index 00000000..451e39e7 --- /dev/null +++ b/mace/core/mace.h @@ -0,0 +1,341 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_MACE_H_ +#define MACE_CORE_MACE_H_ +#include +#include +#include +#include "mace/core/logging.h" + +namespace mace { + +enum NetMode { + INIT = 0, + NORMAL = 1 +}; + +enum DeviceType { + CPU = 0, + NEON = 1, + OPENCL = 2 +}; + +enum DataType { + DT_INVALID = 0, + DT_FLOAT = 1, + DT_DOUBLE = 2, + DT_INT32 = 3, + DT_UINT8 = 4, + DT_INT16 = 5, + DT_INT8 = 6, + DT_STRING = 7, + DT_INT64 = 8, + DT_UINT16 = 9, + DT_BOOL = 10, + DT_HALF = 19, + DT_UINT32 = 22 +}; + +class TensorProto { + public: + TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const DataType data_type = DT_FLOAT, + uint32_t node_id = 0); + TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const int data_type, + uint32_t node_id = 0); + + const std::string &name() const; + unsigned char *data() const; + const int data_size() const; + const std::vector &dims() const; + DataType data_type() const; + uint32_t node_id() const; + + private: + std::string name_; + unsigned char *data_; + int data_size_; + std::vector dims_; + DataType data_type_; + uint32_t node_id_; +}; + +class Argument { + public: + Argument(); + void CopyFrom(const Argument &from) ; + public: + const std::string &name() const; + void set_name(const std::string& value); + bool has_f() const; + float f() const ; + void set_f(float value) ; + bool has_i() const ; + int64_t i() const ; + void set_i(int64_t value); + bool has_s() const ; + std::string s() 
const ; + void set_s(const std::string& value) ; + const std::vector &floats() const ; + void add_floats(float value) ; + void set_floats(const std::vector &value); + const std::vector &ints() const ; + void add_ints(int64_t value) ; + void set_ints(const std::vector &value); + const std::vector &strings() const ; + void add_strings(const ::std::string& value) ; + void set_strings(const std::vector &value); + + private: + void set_has_f() ; + void set_has_i() ; + void set_has_s() ; + + private: + std::string name_; + float f_; + int64_t i_; + std::string s_; + std::vector floats_; + std::vector ints_; + std::vector strings_; + uint32_t has_bits_; +}; + +class NodeInput { + public: + void CopyFrom(const NodeInput &from) { + node_id_ = from.node_id(); + output_port_ = from.output_port(); + } + public: + int node_id() const { + return node_id_; + } + int output_port() const { + return output_port_; + } + private: + int node_id_; + int output_port_; +}; + +class OutputShape { + public: + void CopyFrom(const OutputShape &from) { + auto from_dims = from.dims(); + dims_.resize(from_dims.size()); + std::copy(from_dims.begin(), from_dims.end(), dims_.begin()); + } + public: + const std::vector &dims() const { + return dims_; + } + private: + std::vector dims_; +}; + +class OperatorDef { + public: + void CopyFrom(const OperatorDef &from); + + public: + const std::string &name() const; + void set_name(const std::string &name_); + bool has_name() const; + const std::string &type() const; + void set_type(const std::string &type_); + bool has_type() const; + int mem_id() const; + void set_mem_id(const int mem_id); + bool has_mem_id() const; + uint32_t node_id() const; + uint32_t op_id() const; + uint32_t padding() const; + const std::vector &node_input() const; + const std::vector &out_max_byte_size() const; + const std::vector &input() const; + const std::string& input(int index) const; + std::string* add_input(); + void add_input(const ::std::string& value); + void 
add_input(::std::string&& value); + void set_input(const std::vector &value); + const std::vector &output() const; + const std::string& output(int index) const; + std::string* add_output(); + void add_output(const ::std::string& value); + void add_output(::std::string&& value); + void set_output(const std::vector &value); + const std::vector &arg() const; + Argument* add_arg(); + const std::vector &output_shape() const; + void set_output_shape(const std::vector &value); + const std::vector &output_type() const; + void set_output_type(const std::vector &value); + + private: + void set_has_name(); + void set_has_type(); + void set_has_mem_id(); + + private: + std::string name_; + std::string type_; + + std::vector input_; + std::vector output_; + std::vector arg_; + std::vector output_shape_; + std::vector output_type_; + + int mem_id_; + + // nnlib + uint32_t node_id_; + uint32_t op_id_; + uint32_t padding_; + std::vector node_input_; + std::vector out_max_byte_size_; + + uint32_t has_bits_; +}; + +class MemoryBlock { + public: + MemoryBlock(int mem_id, uint32_t x, uint32_t y); + public: + int mem_id() const; + uint32_t x() const; + uint32_t y() const; + private: + int mem_id_; + uint32_t x_; + uint32_t y_; +}; + +class MemoryArena { + public: + inline const std::vector &mem_block() const { + return mem_block_; + } + inline std::vector &mutable_mem_block() { + return mem_block_; + } + inline int mem_block_size() const { + return mem_block_.size(); + } + private: + std::vector mem_block_; + +}; + +// for hexagon mace-nnlib +class InputInfo { + public: + const std::string &name() const { + return name_; + } + int32_t node_id() const { + return node_id_; + } + int32_t max_byte_size() const { + return max_byte_size_; + } + DataType data_type() const { + return data_type_; + } + const std::vector &dims() const { + return dims_; + } + private: + std::string name_; + int32_t node_id_; + int32_t max_byte_size_; // only support 32-bit len + DataType data_type_; + std::vector 
dims_; +}; + +class OutputInfo { + public: + const std::string &name() const { + return name_; + } + int32_t node_id() const { + return node_id_; + } + int32_t max_byte_size() const { + return max_byte_size_; + } + DataType data_type() const { + return data_type_; + } + const std::vector &dims() const { + return dims_; + } + private: + std::string name_; + int32_t node_id_; + int32_t max_byte_size_; // only support 32-bit len + DataType data_type_; + std::vector dims_; +}; + +class NetDef { + public: + NetDef(); + int op_size() const; + + const OperatorDef &op(const int idx) const; + public: + const std::string &name() const; + bool has_name() const; + void set_name(const std::string& value); + const std::string &version() const; + bool has_version() const; + void set_version(const std::string& value); + + const std::vector &op() const; + OperatorDef* add_op(); + std::vector &mutable_op(); + const std::vector &arg() const; + Argument *add_arg(); + std::vector &mutable_arg(); + const std::vector &tensors() const; + std::vector &mutable_tensors(); + const MemoryArena &mem_arena() const; + bool has_mem_arena() const; + MemoryArena &mutable_mem_arena(); + const std::vector &input_info() const; + const std::vector &output_info() const; + + private: + void set_has_name(); + void set_has_version(); + void set_has_mem_arena(); + + private: + std::string name_; + std::string version_; + std::vector op_; + std::vector arg_; + std::vector tensors_; + + // for mem optimization + MemoryArena mem_arena_; + + // for hexagon mace-nnlib + std::vector input_info_; + std::vector output_info_; + + uint32_t has_bits_; +}; + +} // namespace mace +#endif // MACE_CORE_MACE_H_ diff --git a/mace/core/net.cc b/mace/core/net.cc index e1b16a03..55f3c5f6 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -50,7 +50,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { } } if (!op->Run()) { - LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); + LOG(ERROR) << "Operator 
failed: " << op->debug_def().name(); return false; } diff --git a/mace/core/net.h b/mace/core/net.h index 67c954f3..109d2d66 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -8,7 +8,7 @@ #include "mace/core/common.h" #include "mace/core/operator.h" #include "mace/core/workspace.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/proto/stats.pb.h" namespace mace { diff --git a/mace/core/operator.h b/mace/core/operator.h index 4cd52fb3..e5fc44f0 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -10,7 +10,7 @@ #include "mace/core/registry.h" #include "mace/core/tensor.h" #include "mace/core/workspace.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/proto_utils.cc b/mace/core/proto_utils.cc index 064e9b53..064f5b28 100644 --- a/mace/core/proto_utils.cc +++ b/mace/core/proto_utils.cc @@ -4,163 +4,12 @@ #include "mace/core/proto_utils.h" -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" - -#ifndef MACE_USE_LITE_PROTO -#include "google/protobuf/text_format.h" -#endif // !MACE_USE_LITE_PROTO - namespace mace { -bool ReadStringFromFile(const char *filename, string *str) { - std::ifstream ifs(filename, std::ios::in); - if (!ifs) { - VLOG(1) << "File cannot be opened: " << filename - << " error: " << ifs.rdstate(); - return false; - } - ifs.seekg(0, std::ios::end); - size_t n = ifs.tellg(); - str->resize(n); - ifs.seekg(0); - ifs.read(&(*str)[0], n); - return true; -} - -bool WriteStringToFile(const string &str, const char *filename) { - std::ofstream ofs(filename, std::ios::out | std::ios::trunc); - if (!ofs.is_open()) { - VLOG(1) << "File cannot be created: " << filename - << " error: " << ofs.rdstate(); - return false; - } - ofs << str; - return true; -} - -// IO-specific proto functions: we will deal with the protocol buffer lite and -// full versions differently. 
- -#ifdef MACE_USE_LITE_PROTO - -// Lite runtime. - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const string &filename) - : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void *buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(filename)); - stream.SetOwnsCopyingStream(true); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -void WriteProtoToBinaryFile(const MessageLite & /*proto*/, - const char * /*filename*/) { - LOG(FATAL) << "Not implemented yet."; -} - -#else // MACE_USE_LITE_PROTO - -// Full protocol buffer. 
- -using ::google::protobuf::io::FileInputStream; -using ::google::protobuf::io::FileOutputStream; -using ::google::protobuf::io::ZeroCopyInputStream; -using ::google::protobuf::io::CodedInputStream; -using ::google::protobuf::io::ZeroCopyOutputStream; -using ::google::protobuf::io::CodedOutputStream; - -bool ReadProtoFromTextFile(const char *filename, Message *proto) { - int fd = open(filename, O_RDONLY); - MACE_CHECK(fd != -1, "File not found: ", filename); - FileInputStream *input = new FileInputStream(fd); - bool success = google::protobuf::TextFormat::Parse(input, proto); - delete input; - close(fd); - return success; -} - -void WriteProtoToTextFile(const Message &proto, const char *filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - FileOutputStream *output = new FileOutputStream(fd); - MACE_CHECK(google::protobuf::TextFormat::Print(proto, output)); - delete output; - close(fd); -} - -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) { -#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified - int fd = open(filename, O_RDONLY | O_BINARY); -#else - int fd = open(filename, O_RDONLY); -#endif - MACE_CHECK(fd != -1, "File not found: ", filename); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input( - new CodedInputStream(raw_input.get())); - // A hack to manually allow using very large protocol buffers. 
- coded_input->SetTotalBytesLimit(1073741824, 536870912); - bool success = proto->ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - close(fd); - return success; -} - -void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - MACE_CHECK(fd != -1, "File cannot be created: ", filename, " error number: ", - errno); - std::unique_ptr raw_output(new FileOutputStream(fd)); - std::unique_ptr coded_output( - new CodedOutputStream(raw_output.get())); - MACE_CHECK(proto.SerializeToCodedStream(coded_output.get())); - coded_output.reset(); - raw_output.reset(); - close(fd); -} - -#endif // MACE_USE_LITE_PROTO - ArgumentHelper::ArgumentHelper(const OperatorDef &def) { for (auto &arg : def.arg()) { if (arg_map_.find(arg.name()) != arg_map_.end()) { - MACE_CHECK( - arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(), - "Found argument of the same name '", arg.name(), - "' but with different contents: ", ProtoDebugString(def)); - - LOG(WARNING) << "Duplicated argument name found in operator def: " - << ProtoDebugString(def) - << ", arg: " << ProtoDebugString(arg); + LOG(WARNING) << "Duplicated argument name found in operator def."; } arg_map_[arg.name()] = arg; @@ -170,8 +19,7 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) { ArgumentHelper::ArgumentHelper(const NetDef &netdef) { for (auto &arg : netdef.arg()) { MACE_CHECK(arg_map_.count(arg.name()) == 0, - "Duplicated argument name found in net def: ", - ProtoDebugString(netdef)); + "Duplicated argument name found in net def."); arg_map_[arg.name()] = arg; } } @@ -265,88 +113,4 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) #undef INSTANTIATE_GET_REPEATED_ARGUMENT -#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ - template <> \ - Argument MakeArgument(const string &name, const T &value) { \ - Argument arg; \ - 
arg.set_name(name); \ - arg.set_##fieldname(value); \ - return arg; \ - } - -MACE_MAKE_SINGULAR_ARGUMENT(bool, i) -MACE_MAKE_SINGULAR_ARGUMENT(float, f) -MACE_MAKE_SINGULAR_ARGUMENT(int, i) -MACE_MAKE_SINGULAR_ARGUMENT(int64_t, i) -MACE_MAKE_SINGULAR_ARGUMENT(string, s) -#undef MACE_MAKE_SINGULAR_ARGUMENT - -template <> -Argument MakeArgument(const string &name, const MessageLite &value) { - Argument arg; - arg.set_name(name); - arg.set_s(value.SerializeAsString()); - return arg; -} - -#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \ - template <> \ - Argument MakeArgument(const string &name, const vector &value) { \ - Argument arg; \ - arg.set_name(name); \ - for (const auto &v : value) { \ - arg.add_##fieldname(v); \ - } \ - return arg; \ - } - -MACE_MAKE_REPEATED_ARGUMENT(float, floats) -MACE_MAKE_REPEATED_ARGUMENT(int, ints) -MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints) -MACE_MAKE_REPEATED_ARGUMENT(string, strings) -#undef MACE_MAKE_REPEATED_ARGUMENT - -const Argument &GetArgument(const OperatorDef &def, const string &name) { - for (const Argument &arg : def.arg()) { - if (arg.name() == name) { - return arg; - } - } - MACE_CHECK(false, "Argument named ", name, "does not exist in operator ", - ProtoDebugString(def)); - // should not reach here, just make compiler happy - return std::move(Argument()); -} - -bool GetFlagArgument(const OperatorDef &def, - const string &name, - bool def_value) { - for (const Argument &arg : def.arg()) { - if (arg.name() == name) { - MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ", - ProtoDebugString(arg)); - return arg.i(); - } - } - return def_value; -} - -Argument *GetMutableArgument(const string &name, - const bool create_if_missing, - OperatorDef *def) { - for (int i = 0; i < def->arg_size(); ++i) { - if (def->arg(i).name() == name) { - return def->mutable_arg(i); - } - } - // If no argument of the right name is found... 
- if (create_if_missing) { - Argument *arg = def->add_arg(); - arg->set_name(name); - return arg; - } else { - return nullptr; - } -} - } // namespace mace diff --git a/mace/core/proto_utils.h b/mace/core/proto_utils.h index 90747a41..6ccf98a4 100644 --- a/mace/core/proto_utils.h +++ b/mace/core/proto_utils.h @@ -7,137 +7,12 @@ #include -#include "google/protobuf/message_lite.h" -#ifndef MACE_USE_LITE_PROTO -#include "google/protobuf/message.h" -#endif // !MACE_USE_LITE_PROTO - #include "mace/core/common.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { using std::string; -using ::google::protobuf::MessageLite; - -// Common interfaces that reads file contents into a string. -bool ReadStringFromFile(const char *filename, string *str); -bool WriteStringToFile(const string &str, const char *filename); - -// Common interfaces that are supported by both lite and full protobuf. -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto); -inline bool ReadProtoFromBinaryFile(const string filename, MessageLite *proto) { - return ReadProtoFromBinaryFile(filename.c_str(), proto); -} - -void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename); -inline void WriteProtoToBinaryFile(const MessageLite &proto, - const string &filename) { - return WriteProtoToBinaryFile(proto, filename.c_str()); -} - -#ifdef MACE_USE_LITE_PROTO - -inline string ProtoDebugString(const MessageLite &proto) { - return proto.SerializeAsString(); -} - -// Text format MessageLite wrappers: these functions do nothing but just -// allowing things to compile. It will produce a runtime error if you are using -// MessageLite but still want text support. -inline bool ReadProtoFromTextFile(const char * /*filename*/, - MessageLite * /*proto*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; // Just to suppress compiler warning. 
-} -inline bool ReadProtoFromTextFile(const string filename, MessageLite *proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -inline void WriteProtoToTextFile(const MessageLite & /*proto*/, - const char * /*filename*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; -} -inline void WriteProtoToTextFile(const MessageLite &proto, - const string &filename) { - return WriteProtoToTextFile(proto, filename.c_str()); -} - -inline bool ReadProtoFromFile(const char *filename, MessageLite *proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string &filename, MessageLite *proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#else // MACE_USE_LITE_PROTO - -using ::google::protobuf::Message; - -inline string ProtoDebugString(const Message &proto) { - return proto.ShortDebugString(); -} - -bool ReadProtoFromTextFile(const char *filename, Message *proto); -inline bool ReadProtoFromTextFile(const string filename, Message *proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -void WriteProtoToTextFile(const Message &proto, const char *filename); -inline void WriteProtoToTextFile(const Message &proto, const string &filename) { - return WriteProtoToTextFile(proto, filename.c_str()); -} - -// Read Proto from a file, letting the code figure out if it is text or binary. 
-inline bool ReadProtoFromFile(const char *filename, Message *proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string &filename, Message *proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#endif // MACE_USE_LITE_PROTO - -template , - class IterableOutputs = std::initializer_list, - class IterableArgs = std::initializer_list> -OperatorDef CreateOperatorDef(const string &type, - const string &name, - const IterableInputs &inputs, - const IterableOutputs &outputs, - const IterableArgs &args) { - OperatorDef def; - def.set_type(type); - def.set_name(name); - for (const string &in : inputs) { - def.add_input(in); - } - for (const string &out : outputs) { - def.add_output(out); - } - for (const Argument &arg : args) { - def.add_arg()->CopyFrom(arg); - } - return def; -} - -// A simplified version compared to the full CreateOperator, if you do not need -// to specify args. -template , - class IterableOutputs = std::initializer_list> -inline OperatorDef CreateOperatorDef(const string &type, - const string &name, - const IterableInputs &inputs, - const IterableOutputs &outputs) { - return CreateOperatorDef(type, name, inputs, outputs, - std::vector()); -} /** * @brief A helper class to index into arguments. 
@@ -174,17 +49,6 @@ class ArgumentHelper { return ArgumentHelper(def).GetRepeatedArgument(name, default_value); } - template - static MessageType GetMessageArgument(const Def &def, const string &name) { - return ArgumentHelper(def).GetMessageArgument(name); - } - - template - static vector GetRepeatedMessageArgument(const Def &def, - const string &name) { - return ArgumentHelper(def).GetRepeatedMessageArgument(name); - } - explicit ArgumentHelper(const OperatorDef &def); explicit ArgumentHelper(const NetDef &netdef); bool HasArgument(const string &name) const; @@ -198,51 +62,10 @@ class ArgumentHelper { const string &name, const std::vector &default_value = std::vector()) const; - template - MessageType GetMessageArgument(const string &name) const { - MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); - MessageType message; - if (arg_map_.at(name).has_s()) { - MACE_CHECK(message.ParseFromString(arg_map_.at(name).s()), - "Faild to parse content from the string"); - } else { - VLOG(1) << "Return empty message for parameter " << name; - } - return message; - } - - template - vector GetRepeatedMessageArgument(const string &name) const { - MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); - vector messages(arg_map_.at(name).strings_size()); - for (int i = 0; i < messages.size(); ++i) { - MACE_CHECK(messages[i].ParseFromString(arg_map_.at(name).strings(i)), - "Faild to parse content from the string"); - } - return messages; - } - private: std::map arg_map_; }; -const Argument &GetArgument(const OperatorDef &def, const string &name); -bool GetFlagArgument(const OperatorDef &def, - const string &name, - bool def_value = false); - -Argument *GetMutableArgument(const string &name, - const bool create_if_missing, - OperatorDef *def); - -template -Argument MakeArgument(const string &name, const T &value); - -template -inline void AddArgument(const string &name, const T &value, OperatorDef *def) { - GetMutableArgument(name, true, 
def)->CopyFrom(MakeArgument(name, value)); -} - } // namespace mace #endif // MACE_CORE_PROTO_UTILS_H_ diff --git a/mace/core/serializer.cc b/mace/core/serializer.cc index cfe2d935..568e39fa 100644 --- a/mace/core/serializer.cc +++ b/mace/core/serializer.cc @@ -24,46 +24,32 @@ unique_ptr Serializer::Deserialize(const TensorProto &proto, switch (proto.data_type()) { case DT_FLOAT: - tensor->Copy(proto.float_data().data(), proto.float_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_DOUBLE: - tensor->Copy(proto.double_data().data(), - proto.double_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT32: - tensor->template Copy(proto.int32_data().data(), - proto.int32_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + break; + case DT_INT64: + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_UINT8: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT16: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT8: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); - break; - case DT_INT64: - tensor->Copy(proto.int64_data().data(), - proto.int64_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_UINT16: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_BOOL: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; - case DT_STRING: { - string *content = tensor->mutable_data(); 
- for (int i = 0; i < proto.string_data().size(); ++i) { - content[i] = proto.string_data(i); - } - } break; default: MACE_NOT_IMPLEMENTED; break; diff --git a/mace/core/serializer.h b/mace/core/serializer.h index 107d9f4e..9bfeea08 100644 --- a/mace/core/serializer.h +++ b/mace/core/serializer.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 94f95228..d2d634e6 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -9,7 +9,7 @@ #include "mace/core/common.h" #include "mace/core/logging.h" #include "mace/core/types.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/types.h b/mace/core/types.h index 616e40b2..2d1e94cb 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -6,7 +6,7 @@ #define MACE_CORE_TYPES_H_ #include "mace/core/common.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/half.h" diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index e8fc98f9..cc575ddd 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -69,7 +69,7 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) { } void Workspace::CreateImageOutputTensor(const NetDef &net_def) { - if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) { + if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) { return; } std::map> mem_tensor_map; diff --git a/mace/core/workspace.h b/mace/core/workspace.h index 8a706b87..6aea528a 100644 --- a/mace/core/workspace.h +++ b/mace/core/workspace.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 73fa6767..3a394452 100644 ---
a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -20,6 +20,8 @@ using namespace std; using namespace mace; +extern NetDef CreateNet() ; + void ParseShape(const string &str, vector *shape) { string tmp = str; while (!tmp.empty()) { @@ -34,6 +36,18 @@ void ParseShape(const string &str, vector *shape) { } } +DeviceType ParseDeviceType(const string &device_str) { + if(device_str.compare("CPU") == 0) { + return DeviceType::CPU; + } else if (device_str.compare("NEON") == 0) { + return DeviceType::NEON; + } else if (device_str.compare("OPENCL") == 0) { + return DeviceType::OPENCL; + } else { + return DeviceType::CPU; + } +} + int main(int argc, char **argv) { string model_file; string input_node; @@ -76,13 +90,13 @@ int main(int argc, char **argv) { ParseShape(input_shape, &shape); // load model - ifstream file_stream(model_file, ios::in | ios::binary); - NetDef net_def; - net_def.ParseFromIstream(&file_stream); - file_stream.close(); +// ifstream file_stream(model_file, ios::in | ios::binary); +// NetDef net_def; +// net_def.ParseFromIstream(&file_stream); +// file_stream.close(); + NetDef net_def = CreateNet(); - DeviceType device_type; - DeviceType_Parse(device, &device_type); + DeviceType device_type = ParseDeviceType(device); VLOG(0) << device_type; Workspace ws; ws.LoadModelTensor(net_def, device_type); diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index e56302e4..9c009b86 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_BATCH_NORM_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/kernels/bias_add.h b/mace/kernels/bias_add.h index a1e05cae..c738502a 100644 --- a/mace/kernels/bias_add.h +++ b/mace/kernels/bias_add.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_BIAS_ADD_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { 
namespace kernels { diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h index e70b4e73..5c3b22ab 100644 --- a/mace/kernels/concat.h +++ b/mace/kernels/concat.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/types.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/tensor.h" namespace mace { diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index 840ce727..c1f1d076 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/kernels/conv_pool_2d_util.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/kernels/space_to_batch.h b/mace/kernels/space_to_batch.h index ebf5994b..4f7bd1af 100644 --- a/mace/kernels/space_to_batch.h +++ b/mace/kernels/space_to_batch.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_CONV_2D_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 89f6b8b7..a1ba1eeb 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -40,7 +40,6 @@ cc_library( ], deps = [ "//mace/kernels", - "//mace/proto:cc_proto", ], alwayslink = 1, ) diff --git a/mace/ops/concat.h b/mace/ops/concat.h index 77e43030..0edf3455 100644 --- a/mace/ops/concat.h +++ b/mace/ops/concat.h @@ -7,7 +7,7 @@ #include "mace/core/operator.h" #include "mace/kernels/concat.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { template diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 8d593940..3d8c809b 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -159,7 +159,6 @@ class OpsTestNet { for (auto &op_def_ : op_defs_) { net_def.add_op()->CopyFrom(op_def_); } - VLOG(3) << net_def.DebugString(); net_ = CreateNet(net_def, &ws_, device); device_ = device; return net_->Run(); diff 
--git a/mace/python/tools/model.template b/mace/python/tools/model.template new file mode 100644 index 00000000..c5a3bdb5 --- /dev/null +++ b/mace/python/tools/model.template @@ -0,0 +1,139 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include +#include "mace/core/mace.h" +namespace mace { + +{% for tensor in tensors %} +static unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = { +{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%} +}; +{% endfor %} + +static void CreateNetArg(NetDef &net_def) { + net_def.mutable_arg().reserve({{ net.arg|length }}); + Argument *arg = nullptr; + {% for arg in net.arg %} + + arg = net_def.add_arg(); + arg->set_name({{ arg.name|tojson }}); + + {% if arg.has_f %} + arg->set_f({{ arg.f }}); + {% endif %} + + {% if arg.has_i %} + arg->set_i({{ arg.i }}); + {% endif %} + + {% if arg.has_s %} + arg->set_s({{ arg.s|tojson }}); + {% endif %} + + arg->set_floats({ {{ arg.floats|join(', ') }} }); + arg->set_ints({ {{ arg.ints|join(', ') }} }); + arg->set_strings({ {{ arg.strings|stringfy() }} }); + + {% endfor %} + +} + +static void UpdateOp(OperatorDef &op, + const std::string &name, + const std::string &type, + const int mem_id, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &output_shapes, + const std::vector &output_types) { + op.set_name(name); + op.set_type(type); + op.set_input(inputs); + op.set_output(outputs); + op.set_mem_id(mem_id); + op.set_output_shape(output_shapes); + op.set_output_type(output_types); +} + +static void CreateOperators(std::vector &ops) { + ops.resize({{ net.op|length }}); + Argument *arg = nullptr; + {% for i in range(net.op|length) %} + {% for arg in net.op[i].arg %} + + arg = ops[{{i}}].add_arg(); + arg->set_name({{ arg.name|tojson }}); + + {%- if arg.HasField('f') %} + arg->set_f({{ arg.f }}); + {%- endif %} + {%- if arg.HasField('i') %} + arg->set_i({{ arg.i }}); + {%- endif %} + {%- if arg.HasField('s') %} + arg->set_s({{ 
arg.s|tojson }}); + {%- endif %} + + arg->set_floats({ {{ arg.floats|join(', ') }} }); + arg->set_ints({ {{ arg.ints|join(', ') }} }); + arg->set_strings({ {{ arg.strings|stringfy() }} }); + {% endfor %} + + UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }}, + { {{ net.op[i].input|stringfy }} }, + { {{ net.op[i].output|stringfy }} }, + { {{ net.op[i].output_shape.dims|join(', ') }} }, + { {{ net.op[i].output_type|join(', ') }} }); + + {% endfor %} + +} + +static void CreateTensors(std::vector &tensors) { + tensors.reserve({{ net.tensors|length }}); + + {% for tensor in net.tensors %} + + tensors.emplace_back(TensorProto( + {{ tensor.name|tojson }}, {{ "_" + tensor.name[:-2].replace("/", "_") }}, + { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, + {{ tensor.node_id }} + )); + + {% endfor %} + +} + + +static void CreateMemoryArena(MemoryArena &mem_arena) { + auto mem_block = mem_arena.mutable_mem_block(); + mem_block.reserve({{ net.mem_arena.mem_block|length }}); + + {% for mem_blk in net.mem_arena.mem_block %} + mem_block.emplace_back(MemoryBlock({{ mem_blk.mem_id }}, + {{mem_blk.x}}, + {{mem_blk.y}})); + {% endfor %} + +} + +NetDef CreateNet() { + NetDef net_def; + net_def.set_name("{{ net.name}}"); + net_def.set_version("{{ net.version }}"); + + CreateNetArg(net_def); + + CreateOperators(net_def.mutable_op()); + + CreateTensors(net_def.mutable_tensors()); + + CreateMemoryArena(net_def.mutable_mem_arena()); + + return net_def; +} + +} // namespace mace diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index 886999d3..7599cd5b 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -2,13 +2,45 @@ import argparse import sys import tensorflow as tf from tensorflow import gfile +from mace.proto import mace_pb2 from mace.python.tools import tf_converter_lib from mace.python.tools import tf_dsp_converter_lib +import struct +from jinja2 import 
Environment, FileSystemLoader +import os # ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3 FLAGS = None +class TensorInfo: + def __init__(self, t): + self.name = t.name + if t.data_type == mace_pb2.DT_FLOAT: + self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data)) + elif t.data_type == mace_pb2.DT_INT32: + self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data)) + +def stringfy(value): + return ', '.join('"{0}"'.format(w) for w in value) + +def convert_to_source(net_def): + # Capture our current directory + template_dir = os.path.dirname(FLAGS.template) + template_name = os.path.basename(FLAGS.template) + print template_dir + + # Create the jinja2 environment. + # Notice the use of trim_blocks, which greatly helps control whitespace. + j2_env = Environment(loader=FileSystemLoader(template_dir), + trim_blocks=True) + j2_env.filters['stringfy'] = stringfy + tensors = [TensorInfo(t) for t in net_def.tensors] + return j2_env.get_template(template_name).render( + tensors = tensors, + net = net_def + ) + def main(unused_args): if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") @@ -19,6 +51,7 @@ def main(unused_args): data = f.read() input_graph_def.ParseFromString(data) + print 'done' if FLAGS.runtime == 'dsp': output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize) @@ -26,11 +59,16 @@ def main(unused_args): output_graph_def = tf_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) - with gfile.GFile(FLAGS.output, "wb") as f: - f.write(output_graph_def.SerializeToString()) - with gfile.GFile(FLAGS.output + '_txt', "wb") as f: - # output_graph_def.ClearField('tensors') - f.write(str(output_graph_def)) + if FLAGS.output_type == 'source': + source 
= convert_to_source(output_graph_def) + with gfile.GFile(FLAGS.output, "wb") as f: + f.write(source) + else: + with gfile.GFile(FLAGS.output, "wb") as f: + f.write(output_graph_def.SerializeToString()) + with gfile.GFile(FLAGS.output + '_txt', "wb") as f: + # output_graph_def.ClearField('tensors') + f.write(str(output_graph_def)) def parse_args(): @@ -51,7 +89,7 @@ def parse_args(): "--runtime", type=str, default="cpu", - help="Runtime: cpu/gpu/dsp.") + help="Runtime: cpu/gpu/dsp") parser.add_argument( "--input_node", type=str, @@ -72,6 +110,16 @@ def parse_args(): type=str, default='DT_FLOAT', help="e.g., DT_HALF/DT_FLOAT") + parser.add_argument( + "--output_type", + type=str, + default="source", + help="output type: source/pb") + parser.add_argument( + "--template", + type=str, + default="", + help="template path") return parser.parse_known_args() -- GitLab