diff --git a/mace/core/BUILD b/mace/core/BUILD index 6f1af8a54e3dbab2f14d30c1b6116aabe1bf183e..9f5ca2cb44e810da944bd72fe7db33dff7ab636b 100644 --- a/mace/core/BUILD +++ b/mace/core/BUILD @@ -62,7 +62,6 @@ cc_library( ]), deps = [ ":logging", - "//mace/proto:cc_proto", "//mace/proto:stats_proto", "//mace/utils", ":opencl_runtime", diff --git a/mace/core/allocator.h b/mace/core/allocator.h index bc803879a5b1a8353f655577e846d3425107f33a..0f30d4adcea0b000ae912a923c6c3d3b9ea1c507 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -9,7 +9,7 @@ #include #include "mace/core/common.h" #include "mace/core/registry.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/types.h" namespace mace { diff --git a/mace/core/arg_helper.cc b/mace/core/arg_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..d2230c44ca3247b532854f194b07ea2c071ded6a --- /dev/null +++ b/mace/core/arg_helper.cc @@ -0,0 +1,116 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/arg_helper.h" + +namespace mace { + +ArgumentHelper::ArgumentHelper(const OperatorDef &def) { + for (auto &arg : def.arg()) { + if (arg_map_.find(arg.name()) != arg_map_.end()) { + LOG(WARNING) << "Duplicated argument name found in operator def."; + } + + arg_map_[arg.name()] = arg; + } +} + +ArgumentHelper::ArgumentHelper(const NetDef &netdef) { + for (auto &arg : netdef.arg()) { + MACE_CHECK(arg_map_.count(arg.name()) == 0, + "Duplicated argument name found in net def."); + arg_map_[arg.name()] = arg; + } +} + +bool ArgumentHelper::HasArgument(const string &name) const { + return arg_map_.count(name); +} + +namespace { +// Helper function to verify that conversion between types won't loose any +// significant bit. +template +bool SupportsLosslessConversion(const InputType &value) { + return static_cast(static_cast(value)) == value; +} +} + +#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \ + enforce_lossless_conversion) \ + template <> \ + T ArgumentHelper::GetSingleArgument(const string &name, \ + const T &default_value) const { \ + if (arg_map_.count(name) == 0) { \ + VLOG(1) << "Using default parameter value " << default_value \ + << " for parameter " << name; \ + return default_value; \ + } \ + MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name, \ + " does not have the right field: expected field " #fieldname); \ + auto value = arg_map_.at(name).fieldname(); \ + if (enforce_lossless_conversion) { \ + auto supportsConversion = \ + SupportsLosslessConversion(value); \ + MACE_CHECK(supportsConversion, "Value", value, " of argument ", name, \ + "cannot be represented correctly in a target type"); \ + } \ + return value; \ + } \ + template <> \ + bool ArgumentHelper::HasSingleArgumentOfType(const string &name) const { \ + if (arg_map_.count(name) == 0) { \ + return false; \ + } \ + return arg_map_.at(name).has_##fieldname(); \ + } + +INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) +#undef INSTANTIATE_GET_SINGLE_ARGUMENT + +#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, \ + enforce_lossless_conversion) \ + template <> \ + vector ArgumentHelper::GetRepeatedArgument( \ + const string &name, const std::vector &default_value) const { \ + if (arg_map_.count(name) == 0) { \ + return default_value; \ + } \ + vector values; \ + for (const auto &v : arg_map_.at(name).fieldname()) { \ + if (enforce_lossless_conversion) { \ + auto supportsConversion = \ + SupportsLosslessConversion(v); \ + MACE_CHECK(supportsConversion, "Value", v, " of argument ", name, \ + "cannot be represented correctly in a target type"); \ + } \ + values.push_back(v); \ + } \ + return values; \ + } + +INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) +#undef INSTANTIATE_GET_REPEATED_ARGUMENT + +} // namespace mace diff --git a/mace/core/arg_helper.h b/mace/core/arg_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..a7c66173a98032a5decc8cdf7983b4cd91e0fa4a --- /dev/null +++ b/mace/core/arg_helper.h @@ -0,0 +1,71 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_ARG_HELPER_H_ +#define MACE_CORE_ARG_HELPER_H_ + +#include + +#include "mace/core/common.h" +#include "mace/core/mace.h" + +namespace mace { + +using std::string; + +/** + * @brief A helper class to index into arguments. + * + * This helper helps us to more easily index into a set of arguments + * that are present in the operator. To save memory, the argument helper + * does not copy the operator def, so one would need to make sure that the + * lifetime of the OperatorDef object outlives that of the ArgumentHelper. + */ +class ArgumentHelper { + public: + template + static bool HasArgument(const Def &def, const string &name) { + return ArgumentHelper(def).HasArgument(name); + } + + template + static T GetSingleArgument(const Def &def, + const string &name, + const T &default_value) { + return ArgumentHelper(def).GetSingleArgument(name, default_value); + } + + template + static bool HasSingleArgumentOfType(const Def &def, const string &name) { + return ArgumentHelper(def).HasSingleArgumentOfType(name); + } + + template + static vector GetRepeatedArgument( + const Def &def, + const string &name, + const std::vector &default_value = std::vector()) { + return ArgumentHelper(def).GetRepeatedArgument(name, default_value); + } + + explicit ArgumentHelper(const OperatorDef &def); + explicit ArgumentHelper(const NetDef &netdef); + bool HasArgument(const string &name) const; + + template + T GetSingleArgument(const string &name, const T &default_value) const; + template + bool HasSingleArgumentOfType(const string &name) const; + template + vector GetRepeatedArgument( + const string &name, + const std::vector &default_value = std::vector()) const; + + private: + std::map arg_map_; +}; + +} // namespace mace + +#endif // MACE_CORE_ARG_HELPER_H_ diff --git a/mace/core/mace.cc b/mace/core/mace.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb181d200965577e2b2cf77302d94509b683e57d --- /dev/null +++ b/mace/core/mace.cc @@ -0,0 +1,424 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/mace.h" +#include "mace/core/types.h" + +namespace mace { + +TensorProto::TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const DataType data_type, + uint32_t node_id) : + name_(name), + data_(data), + data_size_(0), + dims_(dims.begin(), dims.end()), + data_type_(data_type), + node_id_(node_id) { + data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); +} + +TensorProto::TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const int data_type, + uint32_t node_id) : + name_(name), + data_(data), + data_size_(0), + dims_(dims.begin(), dims.end()), + data_type_(static_cast(data_type)), + node_id_(node_id) { + data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); +} + +const std::string &TensorProto::name() const { + return name_; +} +unsigned char *TensorProto::data() const { + return data_; +} +const int64_t TensorProto::data_size() const { + return data_size_; +} +const std::vector &TensorProto::dims() const { + return dims_; +} +DataType TensorProto::data_type() const { + return data_type_; +} +uint32_t TensorProto::node_id() const { + return node_id_; +} + +Argument::Argument() : has_bits_(0) {} + +void Argument::CopyFrom(const Argument &from) { + this->name_ = from.name(); + this->f_ = from.f(); + this->i_ = from.i(); + this->s_ = from.s(); + auto floats = from.floats(); + this->floats_.resize(floats.size()); + std::copy(floats.begin(), floats.end(), this->floats_.begin()); + auto ints = from.ints(); + this->ints_.resize(ints.size()); + std::copy(ints.begin(), ints.end(), this->ints_.begin()); + auto strings = from.floats(); + this->strings_.resize(strings.size()); + std::copy(floats.begin(), floats.end(), this->floats_.begin()); + + this->has_bits_ = from.has_bits_; +} +const std::string &Argument::name() const { + return name_; +} +void Argument::set_name(const std::string &value) { + name_ = value; +} +bool Argument::has_f() const { + return (has_bits_ & 0x00000001u) != 0; +} +void Argument::set_has_f() { + has_bits_ |= 0x00000001u; +} +float Argument::f() const { + return f_; +} +void Argument::set_f(float value) { + set_has_f(); + f_ = value; +} +bool Argument::has_i() const { + return (has_bits_ & 0x00000002u) != 0; +} +void Argument::set_has_i() { + has_bits_ |= 0x00000002u; +} +int64_t Argument::i() const { + return i_; +} +void Argument::set_i(int64_t value) { + set_has_i(); + i_ = value; +} +bool Argument::has_s() const { + return (has_bits_ & 0x00000004u) != 0; +} +void Argument::set_has_s() { + has_bits_ |= 0x00000004u; +} +std::string Argument::s() const { + return s_; +} +void Argument::set_s(const std::string &value) { + set_has_s(); + s_ = value; +} +const std::vector &Argument::floats() const { + return floats_; +} +void Argument::add_floats(float value) { + floats_.push_back(value); +} +void Argument::set_floats(const std::vector &value) { + floats_.resize(value.size()); + std::copy(value.begin(), value.end(), floats_.begin()); +} +const std::vector &Argument::ints() const { + return ints_; +} +void Argument::add_ints(int64_t value) { + ints_.push_back(value); +} +void Argument::set_ints(const std::vector &value) { + ints_.resize(value.size()); + std::copy(value.begin(), value.end(), ints_.begin()); +} +const std::vector &Argument::strings() const { + return strings_; +} +void Argument::add_strings(const ::std::string &value) { + strings_.push_back(value); +} +void Argument::set_strings(const std::vector &value) { + strings_.resize(value.size()); + std::copy(value.begin(), value.end(), strings_.begin()); +} + + +// OutputShape +OutputShape::OutputShape() {} +OutputShape::OutputShape(const std::vector &dims): + dims_(dims.begin(), dims.end()) {} +void OutputShape::CopyFrom(const OutputShape &from) { + auto from_dims = from.dims(); + dims_.resize(from_dims.size()); + std::copy(from_dims.begin(), from_dims.end(), dims_.begin()); +} +const std::vector &OutputShape::dims() const { + return dims_; +} + +// Operator Def +void OperatorDef::CopyFrom(const OperatorDef &from) { + name_ = from.name(); + type_ = from.type(); + + auto from_input = from.input(); + input_.resize(from_input.size()); + std::copy(from_input.begin(), from_input.end(), input_.begin()); + auto from_output = from.output(); + output_.resize(from_output.size()); + std::copy(from_output.begin(), from_output.end(), output_.begin()); + auto from_arg = from.arg(); + arg_.resize(from_arg.size()); + for (int i = 0; i < from_arg.size(); ++i) { + arg_[i].CopyFrom(from_arg[i]); + } + auto from_output_shape = from.output_shape(); + output_shape_.resize(from_output_shape.size()); + for (int i = 0; i < from_output_shape.size(); ++i) { + output_shape_[i].CopyFrom(from_output_shape[i]); + } + auto from_data_type = from.output_type(); + output_type_.resize(from_data_type.size()); + std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin()); + + mem_id_ = from.mem_id(); + + // nnlib + node_id_ = from.node_id(); + op_id_ = from.op_id(); + padding_ = from.padding(); + auto from_node_input = from.node_input(); + node_input_.resize(from_node_input.size()); + for (int i = 0; i < from_node_input.size(); ++i) { + node_input_[i].CopyFrom(from_node_input[i]); + } + auto from_out_max_byte_size = from.out_max_byte_size(); + out_max_byte_size_.resize(from_out_max_byte_size.size()); + std::copy(from_out_max_byte_size.begin(), from_out_max_byte_size.end(), out_max_byte_size_.begin()); + + has_bits_ = from.has_bits_; + +} + +const std::string &OperatorDef::name() const { + return name_; +} +void OperatorDef::set_name(const std::string &name_) { + set_has_name(); + OperatorDef::name_ = name_; +} +bool OperatorDef::has_name() const { + return (has_bits_ & 0x00000001u) != 0; +} +void OperatorDef::set_has_name() { + has_bits_ |= 0x00000001u; +} +const std::string &OperatorDef::type() const { + return type_; +} +void OperatorDef::set_type(const std::string &type_) { + set_has_type(); + OperatorDef::type_ = type_; +} +bool OperatorDef::has_type() const { + return (has_bits_ & 0x00000002u) != 0; +} +void OperatorDef::set_has_type() { + has_bits_ |= 0x00000002u; +} +int OperatorDef::mem_id() const { + return mem_id_; +} +void OperatorDef::set_mem_id(const int mem_id) { + set_has_mem_id(); + mem_id_ = mem_id; +} +bool OperatorDef::has_mem_id() const { + return (has_bits_ & 0x00000004u) != 0; +} +void OperatorDef::set_has_mem_id() { + has_bits_ |= 0x00000004u; +} +uint32_t OperatorDef::node_id() const { + return node_id_; +} +uint32_t OperatorDef::op_id() const { + return op_id_; +} +uint32_t OperatorDef::padding() const { + return padding_; +} +const std::vector &OperatorDef::node_input() const { + return node_input_; +} +const std::vector &OperatorDef::out_max_byte_size() const { + return out_max_byte_size_; +} +const std::vector &OperatorDef::input() const { + return input_; +} +const std::string &OperatorDef::input(int index) const { + MACE_CHECK(0 <= index && index <= input_.size()); + return input_[index]; +} +std::string *OperatorDef::add_input() { + input_.push_back(""); + return &input_.back(); +} +void OperatorDef::add_input(const ::std::string &value) { + input_.push_back(value); +} +void OperatorDef::add_input(::std::string &&value) { + input_.push_back(value); +} +void OperatorDef::set_input(const std::vector &value) { + input_.resize(value.size()); + std::copy(value.begin(), value.end(), input_.begin()); +} +const std::vector &OperatorDef::output() const { + return output_; +} +const std::string &OperatorDef::output(int index) const { + MACE_CHECK(0 <= index && index <= output_.size()); + return output_[index]; +} +std::string *OperatorDef::add_output() { + output_.push_back(""); + return &output_.back(); +} +void OperatorDef::add_output(const ::std::string &value) { + output_.push_back(value); +} +void OperatorDef::add_output(::std::string &&value) { + output_.push_back(value); +} +void OperatorDef::set_output(const std::vector &value) { + output_.resize(value.size()); + std::copy(value.begin(), value.end(), output_.begin()); +} +const std::vector &OperatorDef::arg() const { + return arg_; +} +Argument *OperatorDef::add_arg() { + arg_.emplace_back(Argument()); + return &arg_.back(); +} +const std::vector &OperatorDef::output_shape() const { + return output_shape_; +} +void OperatorDef::add_output_shape(const OutputShape &value) { + output_shape_.push_back(value); +} +const std::vector &OperatorDef::output_type() const { + return output_type_; +} +void OperatorDef::set_output_type(const std::vector &value) { + output_type_.resize(value.size()); + std::copy(value.begin(), value.end(), output_type_.begin()); +} + +// MemoryBlock +MemoryBlock::MemoryBlock(int mem_id, uint32_t x, uint32_t y) : + mem_id_(mem_id), x_(x), y_(y) {} + +int MemoryBlock::mem_id() const { + return mem_id_; +} +uint32_t MemoryBlock::x() const { + return x_; +} +uint32_t MemoryBlock::y() const { + return y_; +} + +// NetDef +NetDef::NetDef() : has_bits_(0) {} + +const std::string &NetDef::name() const { + return name_; +} +void NetDef::set_name(const std::string &value) { + set_has_name(); + name_ = value; +} +bool NetDef::has_name() const { + return (has_bits_ & 0x00000001u) != 0; +} +void NetDef::set_has_name() { + has_bits_ |= 0x00000001u; +} +const std::string &NetDef::version() const { + return version_; +} +void NetDef::set_version(const std::string &value) { + set_has_version(); + version_ = value; +} +bool NetDef::has_version() const { + return (has_bits_ & 0x00000002u) != 0; +} +void NetDef::set_has_version() { + has_bits_ |= 0x00000002u; +} +const std::vector &NetDef::op() const { + return op_; +} +OperatorDef *NetDef::add_op() { + op_.emplace_back(OperatorDef()); + return &op_.back(); +} +std::vector &NetDef::mutable_op() { + return op_; +} +const std::vector &NetDef::arg() const { + return arg_; +} +Argument *NetDef::add_arg() { + arg_.emplace_back(Argument()); + return &arg_.back(); +} +std::vector &NetDef::mutable_arg() { + return arg_; +} +const std::vector &NetDef::tensors() const { + return tensors_; +} +std::vector &NetDef::mutable_tensors() { + return tensors_; +} +const MemoryArena &NetDef::mem_arena() const { + return mem_arena_; +} +MemoryArena &NetDef::mutable_mem_arena() { + set_has_mem_arena(); + return mem_arena_; +} +bool NetDef::has_mem_arena() const { + return (has_bits_ & 0x00000004u) != 0; +} +void NetDef::set_has_mem_arena() { + has_bits_ |= 0x00000004u; +} +const std::vector &NetDef::input_info() const { + return input_info_; +} +const std::vector &NetDef::output_info() const { + return output_info_; +} + +int NetDef::op_size() const { + return op_.size(); +} + +const OperatorDef &NetDef::op(const int idx) const { + MACE_CHECK(0 <= idx && idx < op_size()); + return op_[idx]; +} +} // namespace mace diff --git a/mace/core/mace.h b/mace/core/mace.h new file mode 100644 index 0000000000000000000000000000000000000000..855860c221d0f4410586ffc2f5e2703a94927bb3 --- /dev/null +++ b/mace/core/mace.h @@ -0,0 +1,337 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_MACE_H_ +#define MACE_CORE_MACE_H_ +#include +#include +#include +#include "mace/core/logging.h" + +namespace mace { + +enum NetMode { + INIT = 0, + NORMAL = 1 +}; + +enum DeviceType { + CPU = 0, + NEON = 1, + OPENCL = 2 +}; + +enum DataType { + DT_INVALID = 0, + DT_FLOAT = 1, + DT_DOUBLE = 2, + DT_INT32 = 3, + DT_UINT8 = 4, + DT_INT16 = 5, + DT_INT8 = 6, + DT_STRING = 7, + DT_INT64 = 8, + DT_UINT16 = 9, + DT_BOOL = 10, + DT_HALF = 19, + DT_UINT32 = 22 +}; + +class TensorProto { + public: + TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const DataType data_type = DT_FLOAT, + uint32_t node_id = 0); + TensorProto(const std::string &name, + unsigned char *data, + const std::vector &dims, + const int data_type, + uint32_t node_id = 0); + + const std::string &name() const; + unsigned char *data() const; + const int64_t data_size() const; + const std::vector &dims() const; + DataType data_type() const; + uint32_t node_id() const; + + private: + std::string name_; + unsigned char *data_; + int64_t data_size_; + std::vector dims_; + DataType data_type_; + uint32_t node_id_; +}; + +class Argument { + public: + Argument(); + void CopyFrom(const Argument &from) ; + public: + const std::string &name() const; + void set_name(const std::string& value); + bool has_f() const; + float f() const ; + void set_f(float value) ; + bool has_i() const ; + int64_t i() const ; + void set_i(int64_t value); + bool has_s() const ; + std::string s() const ; + void set_s(const std::string& value) ; + const std::vector &floats() const ; + void add_floats(float value) ; + void set_floats(const std::vector &value); + const std::vector &ints() const ; + void add_ints(int64_t value) ; + void set_ints(const std::vector &value); + const std::vector &strings() const ; + void add_strings(const ::std::string& value) ; + void set_strings(const std::vector &value); + + private: + void set_has_f() ; + void set_has_i() ; + void set_has_s() ; + + private: + std::string name_; + float f_; + int64_t i_; + std::string s_; + std::vector floats_; + std::vector ints_; + std::vector strings_; + uint32_t has_bits_; +}; + +class NodeInput { + public: + void CopyFrom(const NodeInput &from) { + node_id_ = from.node_id(); + output_port_ = from.output_port(); + } + public: + int node_id() const { + return node_id_; + } + int output_port() const { + return output_port_; + } + private: + int node_id_; + int output_port_; +}; + +class OutputShape { + public: + OutputShape(); + OutputShape(const std::vector &dims); + void CopyFrom(const OutputShape &from); + public: + const std::vector &dims() const; + private: + std::vector dims_; +}; + +class OperatorDef { + public: + void CopyFrom(const OperatorDef &from); + + public: + const std::string &name() const; + void set_name(const std::string &name_); + bool has_name() const; + const std::string &type() const; + void set_type(const std::string &type_); + bool has_type() const; + int mem_id() const; + void set_mem_id(const int mem_id); + bool has_mem_id() const; + uint32_t node_id() const; + uint32_t op_id() const; + uint32_t padding() const; + const std::vector &node_input() const; + const std::vector &out_max_byte_size() const; + const std::vector &input() const; + const std::string& input(int index) const; + std::string* add_input(); + void add_input(const ::std::string& value); + void add_input(::std::string&& value); + void set_input(const std::vector &value); + const std::vector &output() const; + const std::string& output(int index) const; + std::string* add_output(); + void add_output(const ::std::string& value); + void add_output(::std::string&& value); + void set_output(const std::vector &value); + const std::vector &arg() const; + Argument* add_arg(); + const std::vector &output_shape() const; + void add_output_shape(const OutputShape &value); + const std::vector &output_type() const; + void set_output_type(const std::vector &value); + + private: + void set_has_name(); + void set_has_type(); + void set_has_mem_id(); + + private: + std::string name_; + std::string type_; + + std::vector input_; + std::vector output_; + std::vector arg_; + std::vector output_shape_; + std::vector output_type_; + + int mem_id_; + + // nnlib + uint32_t node_id_; + uint32_t op_id_; + uint32_t padding_; + std::vector node_input_; + std::vector out_max_byte_size_; + + uint32_t has_bits_; +}; + +class MemoryBlock { + public: + MemoryBlock(int mem_id, uint32_t x, uint32_t y); + public: + int mem_id() const; + uint32_t x() const; + uint32_t y() const; + private: + int mem_id_; + uint32_t x_; + uint32_t y_; +}; + +class MemoryArena { + public: + inline const std::vector &mem_block() const { + return mem_block_; + } + inline std::vector &mutable_mem_block() { + return mem_block_; + } + inline int mem_block_size() const { + return mem_block_.size(); + } + private: + std::vector mem_block_; + +}; + +// for hexagon mace-nnlib +class InputInfo { + public: + const std::string &name() const { + return name_; + } + int32_t node_id() const { + return node_id_; + } + int32_t max_byte_size() const { + return max_byte_size_; + } + DataType data_type() const { + return data_type_; + } + const std::vector &dims() const { + return dims_; + } + private: + std::string name_; + int32_t node_id_; + int32_t max_byte_size_; // only support 32-bit len + DataType data_type_; + std::vector dims_; +}; + +class OutputInfo { + public: + const std::string &name() const { + return name_; + } + int32_t node_id() const { + return node_id_; + } + int32_t max_byte_size() const { + return max_byte_size_; + } + DataType data_type() const { + return data_type_; + } + const std::vector &dims() const { + return dims_; + } + private: + std::string name_; + int32_t node_id_; + int32_t max_byte_size_; // only support 32-bit len + DataType data_type_; + std::vector dims_; +}; + +class NetDef { + public: + NetDef(); + int op_size() const; + + const OperatorDef &op(const int idx) const; + public: + const std::string &name() const; + bool has_name() const; + void set_name(const std::string& value); + const std::string &version() const; + bool has_version() const; + void set_version(const std::string& value); + + const std::vector &op() const; + OperatorDef* add_op(); + std::vector &mutable_op(); + const std::vector &arg() const; + Argument *add_arg(); + std::vector &mutable_arg(); + const std::vector &tensors() const; + std::vector &mutable_tensors(); + const MemoryArena &mem_arena() const; + bool has_mem_arena() const; + MemoryArena &mutable_mem_arena(); + const std::vector &input_info() const; + const std::vector &output_info() const; + + private: + void set_has_name(); + void set_has_version(); + void set_has_mem_arena(); + + private: + std::string name_; + std::string version_; + std::vector op_; + std::vector arg_; + std::vector tensors_; + + // for mem optimization + MemoryArena mem_arena_; + + // for hexagon mace-nnlib + std::vector input_info_; + std::vector output_info_; + + uint32_t has_bits_; +}; + +} // namespace mace +#endif // MACE_CORE_MACE_H_ diff --git a/mace/core/net.cc b/mace/core/net.cc index e1b16a032f3961f62786cd32595e6c3f87789800..55f3c5f6497f761aa04ff0fe6c638d977c626473 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -50,7 +50,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { } } if (!op->Run()) { - LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); + LOG(ERROR) << "Operator failed: " << op->debug_def().name(); return false; } diff --git a/mace/core/net.h b/mace/core/net.h index 67c954f3e59c9bc1f8c8c46a6ce23858f94c1675..109d2d66c340bddd42799d38c7ecb6fadd66e746 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -8,7 +8,7 @@ #include "mace/core/common.h" #include "mace/core/operator.h" #include "mace/core/workspace.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/proto/stats.pb.h" namespace mace { diff --git a/mace/core/operator.h b/mace/core/operator.h index 4cd52fb366753e21557b38f0baf675d5937cc0dd..ef0cd7bd560fba0a8de62b55d821bb1a812cce26 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -6,11 +6,11 @@ #define MACE_CORE_OPERATOR_H #include "mace/core/common.h" -#include "mace/core/proto_utils.h" +#include "mace/core/arg_helper.h" #include "mace/core/registry.h" #include "mace/core/tensor.h" #include "mace/core/workspace.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/proto_utils.cc b/mace/core/proto_utils.cc deleted file mode 100644 index 064e9b5308036adac33162dae8f5ca9599015f4b..0000000000000000000000000000000000000000 --- a/mace/core/proto_utils.cc +++ /dev/null @@ -1,352 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/core/proto_utils.h" - -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" - -#ifndef MACE_USE_LITE_PROTO -#include "google/protobuf/text_format.h" -#endif // !MACE_USE_LITE_PROTO - -namespace mace { - -bool ReadStringFromFile(const char *filename, string *str) { - std::ifstream ifs(filename, std::ios::in); - if (!ifs) { - VLOG(1) << "File cannot be opened: " << filename - << " error: " << ifs.rdstate(); - return false; - } - ifs.seekg(0, std::ios::end); - size_t n = ifs.tellg(); - str->resize(n); - ifs.seekg(0); - ifs.read(&(*str)[0], n); - return true; -} - -bool WriteStringToFile(const string &str, const char *filename) { - std::ofstream ofs(filename, std::ios::out | std::ios::trunc); - if (!ofs.is_open()) { - VLOG(1) << "File cannot be created: " << filename - << " error: " << ofs.rdstate(); - return false; - } - ofs << str; - return true; -} - -// IO-specific proto functions: we will deal with the protocol buffer lite and -// full versions differently. - -#ifdef MACE_USE_LITE_PROTO - -// Lite runtime. - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const string &filename) - : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void *buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(filename)); - stream.SetOwnsCopyingStream(true); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -void WriteProtoToBinaryFile(const MessageLite & /*proto*/, - const char * /*filename*/) { - LOG(FATAL) << "Not implemented yet."; -} - -#else // MACE_USE_LITE_PROTO - -// Full protocol buffer. - -using ::google::protobuf::io::FileInputStream; -using ::google::protobuf::io::FileOutputStream; -using ::google::protobuf::io::ZeroCopyInputStream; -using ::google::protobuf::io::CodedInputStream; -using ::google::protobuf::io::ZeroCopyOutputStream; -using ::google::protobuf::io::CodedOutputStream; - -bool ReadProtoFromTextFile(const char *filename, Message *proto) { - int fd = open(filename, O_RDONLY); - MACE_CHECK(fd != -1, "File not found: ", filename); - FileInputStream *input = new FileInputStream(fd); - bool success = google::protobuf::TextFormat::Parse(input, proto); - delete input; - close(fd); - return success; -} - -void WriteProtoToTextFile(const Message &proto, const char *filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - FileOutputStream *output = new FileOutputStream(fd); - MACE_CHECK(google::protobuf::TextFormat::Print(proto, output)); - delete output; - close(fd); -} - -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) { -#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified - int fd = open(filename, O_RDONLY | O_BINARY); -#else - int fd = open(filename, O_RDONLY); -#endif - MACE_CHECK(fd != -1, "File not found: ", filename); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input( - new CodedInputStream(raw_input.get())); - // A hack to manually allow using very large protocol buffers. - coded_input->SetTotalBytesLimit(1073741824, 536870912); - bool success = proto->ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - close(fd); - return success; -} - -void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - MACE_CHECK(fd != -1, "File cannot be created: ", filename, " error number: ", - errno); - std::unique_ptr raw_output(new FileOutputStream(fd)); - std::unique_ptr coded_output( - new CodedOutputStream(raw_output.get())); - MACE_CHECK(proto.SerializeToCodedStream(coded_output.get())); - coded_output.reset(); - raw_output.reset(); - close(fd); -} - -#endif // MACE_USE_LITE_PROTO - -ArgumentHelper::ArgumentHelper(const OperatorDef &def) { - for (auto &arg : def.arg()) { - if (arg_map_.find(arg.name()) != arg_map_.end()) { - MACE_CHECK( - arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(), - "Found argument of the same name '", arg.name(), - "' but with different contents: ", ProtoDebugString(def)); - - LOG(WARNING) << "Duplicated argument name found in operator def: " - << ProtoDebugString(def) - << ", arg: " << ProtoDebugString(arg); - } - - arg_map_[arg.name()] = arg; - } -} - -ArgumentHelper::ArgumentHelper(const NetDef &netdef) { - for (auto &arg : netdef.arg()) { - MACE_CHECK(arg_map_.count(arg.name()) == 0, - "Duplicated argument name found in net def: ", - ProtoDebugString(netdef)); - arg_map_[arg.name()] = arg; - } -} - -bool ArgumentHelper::HasArgument(const string &name) const { - return arg_map_.count(name); -} - -namespace { -// Helper function to verify that conversion between types won't loose any -// significant bit. -template -bool SupportsLosslessConversion(const InputType &value) { - return static_cast(static_cast(value)) == value; -} -} - -#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \ - enforce_lossless_conversion) \ - template <> \ - T ArgumentHelper::GetSingleArgument(const string &name, \ - const T &default_value) const { \ - if (arg_map_.count(name) == 0) { \ - VLOG(1) << "Using default parameter value " << default_value \ - << " for parameter " << name; \ - return default_value; \ - } \ - MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name, \ - " does not have the right field: expected field " #fieldname); \ - auto value = arg_map_.at(name).fieldname(); \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(value); \ - MACE_CHECK(supportsConversion, "Value", value, " of argument ", name, \ - "cannot be represented correctly in a target type"); \ - } \ - return value; \ - } \ - template <> \ - bool ArgumentHelper::HasSingleArgumentOfType(const string &name) const { \ - if (arg_map_.count(name) == 0) { \ - return false; \ - } \ - return arg_map_.at(name).has_##fieldname(); \ - } - -INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) -#undef INSTANTIATE_GET_SINGLE_ARGUMENT - -#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, \ - enforce_lossless_conversion) \ - template <> \ - vector ArgumentHelper::GetRepeatedArgument( \ - const string &name, const std::vector &default_value) const { \ - if (arg_map_.count(name) == 0) { \ - return default_value; \ - } \ - vector values; \ - for (const auto &v : arg_map_.at(name).fieldname()) { \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(v); \ - MACE_CHECK(supportsConversion, "Value", v, " of argument ", name, \ - "cannot be represented correctly in a target type"); \ - } \ - values.push_back(v); \ - } \ - return values; \ - } - -INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) -#undef INSTANTIATE_GET_REPEATED_ARGUMENT - -#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ - template <> \ - Argument MakeArgument(const string &name, const T &value) { \ - Argument arg; \ - arg.set_name(name); \ - arg.set_##fieldname(value); \ - return arg; \ - } - -MACE_MAKE_SINGULAR_ARGUMENT(bool, i) -MACE_MAKE_SINGULAR_ARGUMENT(float, f) -MACE_MAKE_SINGULAR_ARGUMENT(int, i) -MACE_MAKE_SINGULAR_ARGUMENT(int64_t, i) -MACE_MAKE_SINGULAR_ARGUMENT(string, s) -#undef MACE_MAKE_SINGULAR_ARGUMENT - -template <> -Argument MakeArgument(const string &name, const MessageLite &value) { - Argument arg; - arg.set_name(name); - arg.set_s(value.SerializeAsString()); - return arg; -} - -#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \ - template <> \ - Argument MakeArgument(const string &name, const vector &value) { \ - Argument arg; \ - arg.set_name(name); \ - for (const auto &v : value) { \ - arg.add_##fieldname(v); \ - } \ - return arg; \ - } - -MACE_MAKE_REPEATED_ARGUMENT(float, floats) -MACE_MAKE_REPEATED_ARGUMENT(int, ints) -MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints) -MACE_MAKE_REPEATED_ARGUMENT(string, strings) -#undef MACE_MAKE_REPEATED_ARGUMENT - -const Argument &GetArgument(const OperatorDef &def, const string &name) { - for (const Argument &arg : def.arg()) { - if (arg.name() == name) { - return arg; - } - } - MACE_CHECK(false, "Argument named ", name, "does not exist in operator ", - ProtoDebugString(def)); - // should not reach here, just make compiler happy - return std::move(Argument()); -} - -bool GetFlagArgument(const OperatorDef &def, - const string &name, - bool def_value) { - for (const Argument &arg : def.arg()) { - if (arg.name() == name) { - MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ", - ProtoDebugString(arg)); - return arg.i(); - } - } - return def_value; -} - -Argument *GetMutableArgument(const string &name, - const bool create_if_missing, - OperatorDef *def) { - for (int i = 0; i < def->arg_size(); ++i) { - if (def->arg(i).name() == name) { - return def->mutable_arg(i); - } - } - // If no argument of the right name is found... - if (create_if_missing) { - Argument *arg = def->add_arg(); - arg->set_name(name); - return arg; - } else { - return nullptr; - } -} - -} // namespace mace diff --git a/mace/core/proto_utils.h b/mace/core/proto_utils.h deleted file mode 100644 index 90747a41a36f4415d352ea15be0ca3c2b3da3630..0000000000000000000000000000000000000000 --- a/mace/core/proto_utils.h +++ /dev/null @@ -1,248 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#ifndef MACE_CORE_PROTO_UTILS_H_ -#define MACE_CORE_PROTO_UTILS_H_ - -#include - -#include "google/protobuf/message_lite.h" -#ifndef MACE_USE_LITE_PROTO -#include "google/protobuf/message.h" -#endif // !MACE_USE_LITE_PROTO - -#include "mace/core/common.h" -#include "mace/proto/mace.pb.h" - -namespace mace { - -using std::string; -using ::google::protobuf::MessageLite; - -// Common interfaces that reads file contents into a string. -bool ReadStringFromFile(const char *filename, string *str); -bool WriteStringToFile(const string &str, const char *filename); - -// Common interfaces that are supported by both lite and full protobuf. -bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto); -inline bool ReadProtoFromBinaryFile(const string filename, MessageLite *proto) { - return ReadProtoFromBinaryFile(filename.c_str(), proto); -} - -void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename); -inline void WriteProtoToBinaryFile(const MessageLite &proto, - const string &filename) { - return WriteProtoToBinaryFile(proto, filename.c_str()); -} - -#ifdef MACE_USE_LITE_PROTO - -inline string ProtoDebugString(const MessageLite &proto) { - return proto.SerializeAsString(); -} - -// Text format MessageLite wrappers: these functions do nothing but just -// allowing things to compile. It will produce a runtime error if you are using -// MessageLite but still want text support. -inline bool ReadProtoFromTextFile(const char * /*filename*/, - MessageLite * /*proto*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; // Just to suppress compiler warning. -} -inline bool ReadProtoFromTextFile(const string filename, MessageLite *proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -inline void WriteProtoToTextFile(const MessageLite & /*proto*/, - const char * /*filename*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; -} -inline void WriteProtoToTextFile(const MessageLite &proto, - const string &filename) { - return WriteProtoToTextFile(proto, filename.c_str()); -} - -inline bool ReadProtoFromFile(const char *filename, MessageLite *proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string &filename, MessageLite *proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#else // MACE_USE_LITE_PROTO - -using ::google::protobuf::Message; - -inline string ProtoDebugString(const Message &proto) { - return proto.ShortDebugString(); -} - -bool ReadProtoFromTextFile(const char *filename, Message *proto); -inline bool ReadProtoFromTextFile(const string filename, Message *proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -void WriteProtoToTextFile(const Message &proto, const char *filename); -inline void WriteProtoToTextFile(const Message &proto, const string &filename) { - return WriteProtoToTextFile(proto, filename.c_str()); -} - -// Read Proto from a file, letting the code figure out if it is text or binary. -inline bool ReadProtoFromFile(const char *filename, Message *proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string &filename, Message *proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#endif // MACE_USE_LITE_PROTO - -template , - class IterableOutputs = std::initializer_list, - class IterableArgs = std::initializer_list> -OperatorDef CreateOperatorDef(const string &type, - const string &name, - const IterableInputs &inputs, - const IterableOutputs &outputs, - const IterableArgs &args) { - OperatorDef def; - def.set_type(type); - def.set_name(name); - for (const string &in : inputs) { - def.add_input(in); - } - for (const string &out : outputs) { - def.add_output(out); - } - for (const Argument &arg : args) { - def.add_arg()->CopyFrom(arg); - } - return def; -} - -// A simplified version compared to the full CreateOperator, if you do not need -// to specify args. -template , - class IterableOutputs = std::initializer_list> -inline OperatorDef CreateOperatorDef(const string &type, - const string &name, - const IterableInputs &inputs, - const IterableOutputs &outputs) { - return CreateOperatorDef(type, name, inputs, outputs, - std::vector()); -} - -/** - * @brief A helper class to index into arguments. - * - * This helper helps us to more easily index into a set of arguments - * that are present in the operator. To save memory, the argument helper - * does not copy the operator def, so one would need to make sure that the - * lifetime of the OperatorDef object outlives that of the ArgumentHelper. - */ -class ArgumentHelper { - public: - template - static bool HasArgument(const Def &def, const string &name) { - return ArgumentHelper(def).HasArgument(name); - } - - template - static T GetSingleArgument(const Def &def, - const string &name, - const T &default_value) { - return ArgumentHelper(def).GetSingleArgument(name, default_value); - } - - template - static bool HasSingleArgumentOfType(const Def &def, const string &name) { - return ArgumentHelper(def).HasSingleArgumentOfType(name); - } - - template - static vector GetRepeatedArgument( - const Def &def, - const string &name, - const std::vector &default_value = std::vector()) { - return ArgumentHelper(def).GetRepeatedArgument(name, default_value); - } - - template - static MessageType GetMessageArgument(const Def &def, const string &name) { - return ArgumentHelper(def).GetMessageArgument(name); - } - - template - static vector GetRepeatedMessageArgument(const Def &def, - const string &name) { - return ArgumentHelper(def).GetRepeatedMessageArgument(name); - } - - explicit ArgumentHelper(const OperatorDef &def); - explicit ArgumentHelper(const NetDef &netdef); - bool HasArgument(const string &name) const; - - template - T GetSingleArgument(const string &name, const T &default_value) const; - template - bool HasSingleArgumentOfType(const string &name) const; - template - vector GetRepeatedArgument( - const string &name, - const std::vector &default_value = std::vector()) const; - - template - MessageType GetMessageArgument(const string &name) const { - MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); - MessageType message; - if (arg_map_.at(name).has_s()) { - MACE_CHECK(message.ParseFromString(arg_map_.at(name).s()), - "Faild to parse content from the string"); - } else { - VLOG(1) << "Return empty message for parameter " << name; - } - return message; - } - - template - vector GetRepeatedMessageArgument(const string &name) const { - MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); - vector messages(arg_map_.at(name).strings_size()); - for (int i = 0; i < messages.size(); ++i) { - MACE_CHECK(messages[i].ParseFromString(arg_map_.at(name).strings(i)), - "Faild to parse content from the string"); - } - return messages; - } - - private: - std::map arg_map_; -}; - -const Argument &GetArgument(const OperatorDef &def, const string &name); -bool GetFlagArgument(const OperatorDef &def, - const string &name, - bool def_value = false); - -Argument *GetMutableArgument(const string &name, - const bool create_if_missing, - OperatorDef *def); - -template -Argument MakeArgument(const string &name, const T &value); - -template -inline void AddArgument(const string &name, const T &value, OperatorDef *def) { - GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value)); -} - -} // namespace mace - -#endif // MACE_CORE_PROTO_UTILS_H_ diff --git a/mace/core/serializer.cc b/mace/core/serializer.cc index cfe2d935c54590c7769e44a9eefc498ed254f751..568e39fa4fa07628627e3bac7fa53240b597ef86 100644 --- a/mace/core/serializer.cc +++ b/mace/core/serializer.cc @@ -24,46 +24,32 @@ unique_ptr Serializer::Deserialize(const TensorProto &proto, switch (proto.data_type()) { case DT_FLOAT: - tensor->Copy(proto.float_data().data(), proto.float_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_DOUBLE: - tensor->Copy(proto.double_data().data(), - proto.double_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT32: - tensor->template Copy(proto.int32_data().data(), - proto.int32_data().size()); + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + break; + case DT_INT64: + tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_UINT8: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT16: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT8: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); - break; - case DT_INT64: - tensor->Copy(proto.int64_data().data(), - proto.int64_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_UINT16: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; case DT_BOOL: - tensor->CopyWithCast(proto.int32_data().data(), - proto.int32_data().size()); + tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); break; - case DT_STRING: { - string *content = tensor->mutable_data(); - for (int i = 0; i < proto.string_data().size(); ++i) { - content[i] = proto.string_data(i); - } - } break; default: MACE_NOT_IMPLEMENTED; break; diff --git a/mace/core/serializer.h b/mace/core/serializer.h index 107d9f4ed9e8f259dfc4779d7185881cf2aa01a1..9bfeea08192b780bee6d43b72f4c7b18a5d4cdd3 100644 --- a/mace/core/serializer.h +++ b/mace/core/serializer.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 94f9522805c62a484157941b9c87b559c02fca6f..d2d634e66e11e498d6f7c549f8dafb651de81ba0 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -9,7 +9,7 @@ #include "mace/core/common.h" #include "mace/core/logging.h" #include "mace/core/types.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/core/types.h b/mace/core/types.h index 616e40b2aeba81a1eca0ddbe28b7acf4c56b2b0a..2d1e94cb5e5d21797fe83676d1953f0bb2d7f015 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -6,7 +6,7 @@ #define MACE_CORE_TYPES_H_ #include "mace/core/common.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/half.h" diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index e8fc98f9933e9c389cdf98e652c6ac11ee3dab87..b21fcf6db84f065d305b68b28e741d21c5d84b5d 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -4,7 +4,7 @@ #include "mace/core/workspace.h" #include "mace/core/serializer.h" -#include "mace/core/proto_utils.h" +#include "mace/core/arg_helper.h" namespace mace { diff --git a/mace/core/workspace.h b/mace/core/workspace.h index 8a706b876ecae7affbc9288f2b242de627c2725a..6aea528a33a51f6bca3af7e1f0c182b79ea68936 100644 --- a/mace/core/workspace.h +++ b/mace/core/workspace.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 73fa676766c12f1cb49a79b0e930f5e441fd8f5b..97d27a5e120cd54a6f7f1f233afa1a7579e70ff2 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -20,6 +20,9 @@ using namespace std; using namespace mace; +namespace mace { +extern NetDef CreateNet(); +} void ParseShape(const string &str, vector *shape) { string tmp = str; while (!tmp.empty()) { @@ -34,6 +37,18 @@ void ParseShape(const string &str, vector *shape) { } } +DeviceType ParseDeviceType(const string &device_str) { + if(device_str.compare("CPU") == 0) { + return DeviceType::CPU; + } else if (device_str.compare("NEON") == 0) { + return DeviceType::NEON; + } else if (device_str.compare("OPENCL") == 0) { + return DeviceType::OPENCL; + } else { + return DeviceType::CPU; + } +} + int main(int argc, char **argv) { string model_file; string input_node; @@ -76,13 +91,13 @@ int main(int argc, char **argv) { ParseShape(input_shape, &shape); // load model - ifstream file_stream(model_file, ios::in | ios::binary); - NetDef net_def; - net_def.ParseFromIstream(&file_stream); - file_stream.close(); +// ifstream file_stream(model_file, ios::in | ios::binary); +// NetDef net_def; +// net_def.ParseFromIstream(&file_stream); +// file_stream.close(); + NetDef net_def = mace::CreateNet(); - DeviceType device_type; - DeviceType_Parse(device, &device_type); + DeviceType device_type = ParseDeviceType(device); VLOG(0) << device_type; Workspace ws; ws.LoadModelTensor(net_def, device_type); diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index e56302e461af155ed5dbd82451e2d35dc195ed8c..9c009b86f59a2bc9807bfc6696b23f491898947a 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_BATCH_NORM_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/kernels/bias_add.h b/mace/kernels/bias_add.h index a1e05caed004c303c423eb5decba285eade0b2fd..c738502a0811524154586ca2f3669e0f967d39ad 100644 --- a/mace/kernels/bias_add.h +++ b/mace/kernels/bias_add.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_BIAS_ADD_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h index e70b4e73c977b9d8da0735784219739c5dbd468a..5c3b22ab1a97d4d8e1f1b28a67186bba89d8f5dc 100644 --- a/mace/kernels/concat.h +++ b/mace/kernels/concat.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/core/types.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" #include "mace/core/tensor.h" namespace mace { diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index 840ce727570f108f365bfcf0b6402e030e18d7d2..c1f1d076ed05f0490ee3724339b6637af84d3a95 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -7,7 +7,7 @@ #include "mace/core/common.h" #include "mace/kernels/conv_pool_2d_util.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/kernels/space_to_batch.h b/mace/kernels/space_to_batch.h index ebf5994b6c84819ec2e08fb7fb45b2eecf7f072b..4f7bd1afe644e52423ca8688ee04289ca014f64d 100644 --- a/mace/kernels/space_to_batch.h +++ b/mace/kernels/space_to_batch.h @@ -6,7 +6,7 @@ #define MACE_KERNELS_CONV_2D_H_ #include "mace/core/tensor.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { namespace kernels { diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 89f6b8b78bb1f31a912817140756d562275a61ce..a1ba1eebe3a06559dccab4d3dcedd2b32f1600fa 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -40,7 +40,6 @@ cc_library( ], deps = [ "//mace/kernels", - "//mace/proto:cc_proto", ], alwayslink = 1, ) diff --git a/mace/ops/concat.h b/mace/ops/concat.h index 77e430304c93341c176dd732c30559f0721e4f8a..0edf34551b1718e365873efb7758ea79d71fe797 100644 --- a/mace/ops/concat.h +++ b/mace/ops/concat.h @@ -7,7 +7,7 @@ #include "mace/core/operator.h" #include "mace/kernels/concat.h" -#include "mace/proto/mace.pb.h" +#include "mace/core/mace.h" namespace mace { template diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 8d593940cf0c5059d5064a27c7edb3558b9f559b..3d8c809be96c2cd92dd136a3ca14a24eeadad780 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -159,7 +159,6 @@ class OpsTestNet { for (auto &op_def_ : op_defs_) { net_def.add_op()->CopyFrom(op_def_); } - VLOG(3) << net_def.DebugString(); net_ = CreateNet(net_def, &ws_, device); device_ = device; return net_->Run(); diff --git a/mace/proto/BUILD b/mace/proto/BUILD index 593fdea5a7740b445f639977a3ed9d9fcad3e8ff..e166ade2bf1b5d1628b815279bc627d3460beac6 100644 --- a/mace/proto/BUILD +++ b/mace/proto/BUILD @@ -10,16 +10,6 @@ licenses(["notice"]) # Apache 2.0 load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") -proto_library( - name = "proto", - srcs = ["mace.proto"], -) - -cc_proto_library( - name = "cc_proto", - deps = [":proto"], -) - proto_library( name = "stats", srcs = ["stats.proto"], diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD index 675f12acb73ee99e810c9add14087ebc63408812..fbe406d33df00839fdd53d0db72790f6eee3e424 100644 --- a/mace/python/tools/BUILD +++ b/mace/python/tools/BUILD @@ -13,12 +13,24 @@ py_library( ], ) +py_library( + name = "source_converter_lib", + srcs = [ + "source_converter_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//mace/proto:mace_py", + ], +) + py_binary( name = "tf_converter", srcs = ["tf_converter.py"], srcs_version = "PY2AND3", deps = [ ":tf_converter_lib", + ":source_converter_lib", "@six_archive//:six", ], ) diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template new file mode 100644 index 0000000000000000000000000000000000000000..0fcbcde420dbb8be87727352ee5c63dc7c68f391 --- /dev/null +++ b/mace/python/tools/model.template @@ -0,0 +1,179 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// Generated by the mace converter. DO NOT EDIT! +// + +{% if mode == 0 %} + +namespace {{tag}}{ + +alignas(4) unsigned char {{ tensor.name }}[] = { +{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%} +}; + +} // namespace {{tag}} +{% else %} +#include +#include +#include "mace/core/mace.h" + +namespace {{tag}} { + +{% for tensor in tensors %} +extern unsigned char {{ tensor.name }}[]; +{% endfor %} + +} // namespace {{ tag }} + +namespace { + +{% if net.arg|length != 0 %} +static void CreateNetArg(mace::NetDef &net_def) { + net_def.mutable_arg().reserve({{ net.arg|length }}); + mace::Argument *arg = nullptr; + {% for arg in net.arg %} + + arg = net_def.add_arg(); + arg->set_name({{ arg.name|tojson }}); + + {%- if arg.HasField('f') %} + arg->set_f({{ arg.f }}); + {% endif %} + + {%- if arg.HasField('i') %} + arg->set_i({{ arg.i }}); + {% endif %} + + {%- if arg.HasField('s') %} + arg->set_s({{ arg.s|tojson }}); + {% endif %} + + {% if arg.floats|length != 0 %} + arg->set_floats({ {{ arg.floats|join(', ') }} }); + {% endif %} + {% if arg.ints|length != 0 %} + arg->set_ints({ {{ arg.ints|join(', ') }} }); + {% endif %} + {% if arg.strings|length != 0 %} + arg->set_strings({ {{ arg.strings|stringfy() }} }); + {% endif %} + + {% endfor %} + +} +{% endif %} + +static void UpdateOp(mace::OperatorDef &op, + const std::string &name, + const std::string &type, + const int mem_id, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &output_types) { + op.set_name(name); + op.set_type(type); + op.set_input(inputs); + op.set_output(outputs); + op.set_mem_id(mem_id); + op.set_output_type(output_types); +} + +static void CreateOperators(std::vector &ops) { + ops.resize({{ net.op|length }}); + mace::Argument *arg = nullptr; + {% for i in range(net.op|length) %} + {% for arg in net.op[i].arg %} + + arg = ops[{{i}}].add_arg(); + arg->set_name({{ arg.name|tojson }}); + + {%- if arg.HasField('f') %} + arg->set_f({{ arg.f }}); + {%- endif %} + {%- if arg.HasField('i') %} + arg->set_i({{ arg.i }}); + {%- endif %} + {%- if arg.HasField('s') %} + arg->set_s({{ arg.s|tojson }}); + {%- endif %} + + {% if arg.floats|length != 0 %} + arg->set_floats({ {{ arg.floats|join(', ') }} }); + {% endif %} + {% if arg.ints|length != 0 %} + arg->set_ints({ {{ arg.ints|join(', ') }} }); + {% endif %} + {% if arg.strings|length != 0 %} + arg->set_strings({ {{ arg.strings|stringfy() }} }); + {% endif %} + {% endfor %} + + {% for shape in net.op[i].output_shape %} + ops[{{i}}].add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} })); + {% endfor %} + + UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }}, + { {{ net.op[i].input|stringfy }} }, + { {{ net.op[i].output|stringfy }} }, + { {{ net.op[i].output_type|join(', ') }} }); + + {% endfor %} + +} + +static void CreateTensors(std::vector &tensors) { + tensors.reserve({{ net.tensors|length }}); + + {% for tensor in net.tensors %} + + tensors.emplace_back(mace::TensorProto( + {{ tensor.name|tojson }}, {{ tag + '::' + tensor.name }}, + { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, + {{ tensor.node_id }} + )); + + {% endfor %} + +} + + +{% if net.mem_arena.mem_block|length != 0 %} +static void CreateMemoryArena(mace::MemoryArena &mem_arena) { + auto mem_block = mem_arena.mutable_mem_block(); + mem_block.reserve({{ net.mem_arena.mem_block|length }}); + + {% for mem_blk in net.mem_arena.mem_block %} + mem_block.emplace_back(mace::MemoryBlock({{ mem_blk.mem_id }}, + {{mem_blk.x}}, + {{mem_blk.y}})); + {% endfor %} + +} +{% endif %} + +} + +namespace mace { + +NetDef {{'Create' + tag}}() { + NetDef net_def; + net_def.set_name("{{ net.name}}"); + net_def.set_version("{{ net.version }}"); + + {% if net.arg|length != 0 %} + CreateNetArg(net_def); + {% endif %} + + CreateOperators(net_def.mutable_op()); + + CreateTensors(net_def.mutable_tensors()); + + {% if net.mem_arena.mem_block|length != 0 %} + CreateMemoryArena(net_def.mutable_mem_arena()); + {% endif %} + + return net_def; +} + +} // namespace mace +{% endif %} diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6495fd18b4305730fd9e575fa59c8212b3cf98 --- /dev/null +++ b/mace/python/tools/source_converter_lib.py @@ -0,0 +1,122 @@ +import struct +import os +import uuid + +from tensorflow import gfile +from mace.proto import mace_pb2 +from jinja2 import Environment, FileSystemLoader + + +GENERATED_NAME = set() + +def generate_random_name(): + name = '_' + uuid.uuid4().hex[:7].upper() + while name in GENERATED_NAME: + name = '_' + uuid.uuid4().hex[:7].upper() + GENERATED_NAME.add(name) + return name + +def generate_tensor_map(tensors): + tensor_map = {} + for t in tensors: + if not tensor_map.has_key(t.name): + tensor_map[t.name] = generate_random_name() + return tensor_map + +def generate_in_out_map(ops, tensor_map): + in_out_map = {} + for op in ops: + op.name = generate_random_name() + for input_name in op.input: + if not in_out_map.has_key(input_name): + if tensor_map.has_key(input_name): + in_out_map[input_name] = tensor_map[input_name] + else: + in_out_map[input_name] = generate_random_name() + for output_name in op.output: + if not in_out_map.has_key(output_name): + if tensor_map.has_key(output_name): + in_out_map[output_name] = tensor_map[output_name] + else: + in_out_map[output_name] = generate_random_name() + return in_out_map + +def confuse_name(net_def): + input_node = "mace_input_node" + output_node = "mace_output_node" + tensor_map = generate_tensor_map(net_def.tensors) + in_out_map = generate_in_out_map(net_def.op, tensor_map) + for t in net_def.tensors: + if input_node not in t.name and output_node not in t.name: + t.name = tensor_map[t.name] + for op in net_def.op: + for i in range(len(op.input)): + if input_node not in op.input[i]: + op.input[i] = in_out_map[op.input[i]] + for i in range(len(op.output)): + if output_node not in op.output[i]: + op.output[i] = in_out_map[op.output[i]] + +def rename_tensor(net_def): + tensor_map = {} + for t in net_def.tensors: + if not tensor_map.has_key(t.name): + tensor_map[t.name] = "_" + t.name[:-2].replace("/", "_") + t.name = tensor_map[t.name] + for op in net_def.op: + for i in range(len(op.input)): + if tensor_map.has_key(op.input[i]): + op.input[i] = tensor_map[op.input[i]] + for i in range(len(op.output)): + if tensor_map.has_key(op.output[i]): + op.output[i] = tensor_map[op.output[i]] + +class TensorInfo: + def __init__(self, t): + self.name = t.name + if t.data_type == mace_pb2.DT_FLOAT: + self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data)) + elif t.data_type == mace_pb2.DT_INT32: + self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data)) + +def stringfy(value): + return ', '.join('"{0}"'.format(w) for w in value) + +def convert_to_source(net_def, template, confuse, model_tag, output): + if confuse: + confuse_name(net_def) + else: + rename_tensor(net_def) + + # Capture our current directory + template_dir = os.path.dirname(template) + template_name = os.path.basename(template) + print template_dir + + # Create the jinja2 environment. + j2_env = Environment(loader=FileSystemLoader(template_dir), + trim_blocks=True) + j2_env.filters['stringfy'] = stringfy + counter = 0 + output_dir = os.path.dirname(output) + '/' + # generate tensor source files + for t in net_def.tensors: + source = j2_env.get_template(template_name).render( + tensor = TensorInfo(t), + tag = model_tag, + mode = 0, + ) + with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f: + f.write(source) + counter += 1 + + # generate model source files + tensors = [TensorInfo(t) for t in net_def.tensors] + source = j2_env.get_template(template_name).render( + tensors = tensors, + net = net_def, + tag = model_tag, + mode = 1 + ) + with gfile.GFile(output, "wb") as f: + f.write(source) diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index 886999d3f59bfb2f49f5db8bf598c8b462f64b17..1251bf55f61c5b674b6bab538e36f485cad383b8 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -2,8 +2,10 @@ import argparse import sys import tensorflow as tf from tensorflow import gfile +from mace.proto import mace_pb2 from mace.python.tools import tf_converter_lib from mace.python.tools import tf_dsp_converter_lib +from mace.python.tools import source_converter_lib # ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3 @@ -19,6 +21,7 @@ def main(unused_args): data = f.read() input_graph_def.ParseFromString(data) + print 'done' if FLAGS.runtime == 'dsp': output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize) @@ -26,11 +29,15 @@ def main(unused_args): output_graph_def = tf_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) - with gfile.GFile(FLAGS.output, "wb") as f: - f.write(output_graph_def.SerializeToString()) - with gfile.GFile(FLAGS.output + '_txt', "wb") as f: - # output_graph_def.ClearField('tensors') - f.write(str(output_graph_def)) + if FLAGS.output_type == 'source': + source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.confuse, + FLAGS.model_tag, FLAGS.output) + else: + with gfile.GFile(FLAGS.output, "wb") as f: + f.write(output_graph_def.SerializeToString()) + with gfile.GFile(FLAGS.output + '_txt', "wb") as f: + # output_graph_def.ClearField('tensors') + f.write(str(output_graph_def)) def parse_args(): @@ -51,7 +58,7 @@ def parse_args(): "--runtime", type=str, default="cpu", - help="Runtime: cpu/gpu/dsp.") + help="Runtime: cpu/gpu/dsp") parser.add_argument( "--input_node", type=str, @@ -72,6 +79,26 @@ def parse_args(): type=str, default='DT_FLOAT', help="e.g., DT_HALF/DT_FLOAT") + parser.add_argument( + "--output_type", + type=str, + default="source", + help="output type: source/pb") + parser.add_argument( + "--template", + type=str, + default="", + help="template path") + parser.add_argument( + "--confuse", + type=bool, + default=False, + help="confuse model names") + parser.add_argument( + "--model_tag", + type=str, + default="", + help="model tag for generated function and namespace") return parser.parse_known_args() diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh index 275f1bfb827f7f713d8b93c892c5080ba2cbd6b6..3817094fa0bd4021a36f2ae5853fa5902e0b6e21 100644 --- a/tools/validate_gcn.sh +++ b/tools/validate_gcn.sh @@ -27,14 +27,14 @@ python tools/validate.py --generate_data true --random_seed 1 \ --input_shape="${IMAGE_SIZE},${IMAGE_SIZE},3" # Step 2: convert tf model to mace model -echo "Step 2: convert tf model to mace model and optimize memory" -bazel build //mace/python/tools:tf_converter -bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \ - --output=${MODEL_DIR}/${MACE_MODEL_NAME} \ - --input_node=input \ - --output_node=GCN/br_result_2/fcn_br \ - --data_type=DT_HALF \ - --runtime=gpu +#echo "Step 2: convert tf model to mace model and optimize memory" +#bazel build //mace/python/tools:tf_converter +#bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \ +# --output=${MODEL_DIR}/${MACE_MODEL_NAME} \ +# --input_node=input \ +# --output_node=GCN/br_result_2/fcn_br \ +# --data_type=DT_HALF \ +# --runtime=gpu # Step 3: Run model on the phone echo "Step 3: Run model on the phone" @@ -46,7 +46,7 @@ bazel build -c opt --strip always mace/examples:mace_run \ adb shell "mkdir -p ${PHONE_DATA_DIR}" adb shell "mkdir -p ${KERNEL_DIR}" adb push mace/kernels/opencl/cl/* ${KERNEL_DIR} -adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR} +#adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR} adb push ${MODEL_DIR}/${INPUT_FILE_NAME} ${PHONE_DATA_DIR} adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR}