From dd2cd8ee6a0e72c590f7befb22a1060009d7415b Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Fri, 22 Dec 2017 19:02:31 +0800 Subject: [PATCH] Rename TensorProto to ConstTensor --- mace/core/mace.cc | 34 +++++++++++++++----------------- mace/core/public/mace.h | 28 +++++++++++++------------- mace/core/serializer.cc | 33 ++++++++++++++++++++----------- mace/core/serializer.h | 4 ++-- mace/examples/helloworld.cc | 2 +- mace/python/tools/model.template | 8 ++++---- 6 files changed, 58 insertions(+), 51 deletions(-) diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 98478d36..c71d8b45 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -10,50 +10,48 @@ namespace mace { -TensorProto::TensorProto(const std::string &name, +ConstTensor::ConstTensor(const std::string &name, unsigned char *data, const std::vector &dims, const DataType data_type, uint32_t node_id) : name_(name), data_(data), - data_size_(0), dims_(dims.begin(), dims.end()), data_type_(data_type), - node_id_(node_id) { - data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); -} + node_id_(node_id), + data_size_(std::accumulate(dims.begin(), dims.end(), 1, + std::multiplies())) {} -TensorProto::TensorProto(const std::string &name, +ConstTensor::ConstTensor(const std::string &name, unsigned char *data, const std::vector &dims, const int data_type, uint32_t node_id) : name_(name), data_(data), - data_size_(0), dims_(dims.begin(), dims.end()), data_type_(static_cast(data_type)), - node_id_(node_id) { - data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); -} + node_id_(node_id), + data_size_(std::accumulate(dims.begin(), dims.end(), 1, + std::multiplies())) {} -const std::string &TensorProto::name() const { +const std::string &ConstTensor::name() const { return name_; } -unsigned char *TensorProto::data() const { +const unsigned char *ConstTensor::data() const { return data_; } -const int64_t TensorProto::data_size() const { +int64_t ConstTensor::data_size() const { return data_size_; } -const std::vector &TensorProto::dims() const { +const std::vector &ConstTensor::dims() const { return dims_; } -DataType TensorProto::data_type() const { +DataType ConstTensor::data_type() const { return data_type_; } -uint32_t TensorProto::node_id() const { +uint32_t ConstTensor::node_id() const { return node_id_; } @@ -446,10 +444,10 @@ Argument *NetDef::add_arg() { std::vector &NetDef::mutable_arg() { return arg_; } -const std::vector &NetDef::tensors() const { +const std::vector &NetDef::tensors() const { return tensors_; } -std::vector &NetDef::mutable_tensors() { +std::vector &NetDef::mutable_tensors() { return tensors_; } const MemoryArena &NetDef::mem_arena() const { diff --git a/mace/core/public/mace.h b/mace/core/public/mace.h index c56a53b5..caf5b311 100644 --- a/mace/core/public/mace.h +++ b/mace/core/public/mace.h @@ -38,33 +38,33 @@ enum DataType { DT_UINT32 = 22 }; -class TensorProto { +class ConstTensor { public: - TensorProto(const std::string &name, + ConstTensor(const std::string &name, unsigned char *data, const std::vector &dims, const DataType data_type = DT_FLOAT, uint32_t node_id = 0); - TensorProto(const std::string &name, + ConstTensor(const std::string &name, unsigned char *data, const std::vector &dims, const int data_type, uint32_t node_id = 0); const std::string &name() const; - unsigned char *data() const; - const int64_t data_size() const; + const unsigned char *data() const; + int64_t data_size() const; const std::vector &dims() const; DataType data_type() const; uint32_t node_id() const; private: - std::string name_; - unsigned char *data_; - int64_t data_size_; - std::vector dims_; - DataType data_type_; - uint32_t node_id_; + const std::string name_; + const unsigned char *data_; + const int64_t data_size_; + const std::vector dims_; + const DataType data_type_; + const uint32_t node_id_; }; class Argument { @@ -270,8 +270,8 @@ class NetDef { const std::vector &arg() const; Argument *add_arg(); std::vector &mutable_arg(); - const std::vector &tensors() const; - std::vector &mutable_tensors(); + const std::vector &tensors() const; + std::vector &mutable_tensors(); const MemoryArena &mem_arena() const; bool has_mem_arena() const; MemoryArena &mutable_mem_arena(); @@ -288,7 +288,7 @@ class NetDef { std::string version_; std::vector op_; std::vector arg_; - std::vector tensors_; + std::vector tensors_; // for mem optimization MemoryArena mem_arena_; diff --git a/mace/core/serializer.cc b/mace/core/serializer.cc index 568e39fa..a0e3bc59 100644 --- a/mace/core/serializer.cc +++ b/mace/core/serializer.cc @@ -6,13 +6,13 @@ namespace mace { -unique_ptr Serializer::Serialize(const Tensor &tensor, +unique_ptr Serializer::Serialize(const Tensor &tensor, const string &name) { MACE_NOT_IMPLEMENTED; return nullptr; } -unique_ptr Serializer::Deserialize(const TensorProto &proto, +unique_ptr Serializer::Deserialize(const ConstTensor &proto, DeviceType type) { unique_ptr tensor( new Tensor(GetDeviceAllocator(type), proto.data_type())); @@ -24,31 +24,40 @@ unique_ptr Serializer::Deserialize(const TensorProto &proto, switch (proto.data_type()) { case DT_FLOAT: - tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + tensor->Copy(reinterpret_cast(proto.data()), + proto.data_size()); break; case DT_DOUBLE: - tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + tensor->Copy(reinterpret_cast(proto.data()), + proto.data_size()); break; case DT_INT32: - tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + tensor->Copy(reinterpret_cast(proto.data()), + proto.data_size()); break; case DT_INT64: - tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); + tensor->Copy(reinterpret_cast(proto.data()), + proto.data_size()); break; case DT_UINT8: - tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); + tensor->CopyWithCast( + reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT16: - tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); + tensor->CopyWithCast( + reinterpret_cast(proto.data()), proto.data_size()); break; case DT_INT8: - tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); + tensor->CopyWithCast( + reinterpret_cast(proto.data()), proto.data_size()); break; case DT_UINT16: - tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); + tensor->CopyWithCast( + reinterpret_cast(proto.data()), proto.data_size()); break; case DT_BOOL: - tensor->CopyWithCast(reinterpret_cast(proto.data()), proto.data_size()); + tensor->CopyWithCast( + reinterpret_cast(proto.data()), proto.data_size()); break; default: MACE_NOT_IMPLEMENTED; @@ -58,4 +67,4 @@ unique_ptr Serializer::Deserialize(const TensorProto &proto, return tensor; } -} // namespace mace \ No newline at end of file +} // namespace mace diff --git a/mace/core/serializer.h b/mace/core/serializer.h index 64c33b6d..c615300a 100644 --- a/mace/core/serializer.h +++ b/mace/core/serializer.h @@ -16,9 +16,9 @@ class Serializer { Serializer() {} ~Serializer() {} - unique_ptr Serialize(const Tensor &tensor, const string &name); + unique_ptr Serialize(const Tensor &tensor, const string &name); - unique_ptr Deserialize(const TensorProto &proto, DeviceType type); + unique_ptr Deserialize(const ConstTensor &proto, DeviceType type); DISABLE_COPY_AND_ASSIGN(Serializer); }; diff --git a/mace/examples/helloworld.cc b/mace/examples/helloworld.cc index a8904509..90e04317 100644 --- a/mace/examples/helloworld.cc +++ b/mace/examples/helloworld.cc @@ -45,7 +45,7 @@ int main() { alignas(4) unsigned char tensor_data[] = "012345678901234567890123"; const std::vector dims = {1, 2, 3, 1}; - TensorProto input("Input", tensor_data, dims, DataType::DT_FLOAT); + ConstTensor input("Input", tensor_data, dims, DataType::DT_FLOAT); net_def.mutable_tensors().push_back(input); // Create workspace and input tensor diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template index 810b2518..6588d9a2 100644 --- a/mace/python/tools/model.template +++ b/mace/python/tools/model.template @@ -13,8 +13,8 @@ alignas(4) unsigned char {{ tensor_info.name }}[] = { {% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%} }; -void Create{{tensor.name}}(std::vector &tensors) { - tensors.emplace_back(mace::TensorProto( +void Create{{tensor.name}}(std::vector &tensors) { + tensors.emplace_back(mace::ConstTensor( {{ tensor.name|tojson }}, {{ tensor.name }}, { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }})); } @@ -100,7 +100,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) { namespace {{tag}} { {% for tensor in tensors %} -extern void Create{{ tensor.name }}(std::vector &tensors); +extern void Create{{ tensor.name }}(std::vector &tensors); {% endfor %} @@ -159,7 +159,7 @@ static void CreateOperators(std::vector &ops) { } -static void CreateTensors(std::vector &tensors) { +static void CreateTensors(std::vector &tensors) { tensors.reserve({{ net.tensors|length }}); {% for tensor in net.tensors %} -- GitLab