From f7bb2064e0e2ef74e316354ea2b9bc30d8dd892d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Wed, 30 Aug 2017 17:51:07 +0800
Subject: [PATCH] Add deserializer to parse tensor from proto
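
Introduce a Serializer that rebuilds a Tensor from a TensorProto, a
Workspace::LoadModelTensor() helper that deserializes every tensor in a
NetDef, and Tensor copy helpers (Copy, CopyWithCast, CopyBytes) backed by
the new Allocator::CopyBytes(). Serializer::Serialize is left as
MACE_NOT_IMPLEMENTED for now.

Rough usage sketch (mirrors mace/examples/helloworld.cc in this patch;
net_def is assumed to be a NetDef whose tensors() carry the model weights):

    Workspace ws;
    ws.LoadModelTensor(net_def, DeviceType::CPU);  // deserializes net_def.tensors()

    auto net = CreateNet(net_def, &ws, DeviceType::CPU);
    net->Run();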
---
 mace/core/allocator.cc      |  9 +++++
 mace/core/allocator.h       |  8 +++-
 mace/core/integral_types.h  |  4 +-
 mace/core/operator.h        |  2 +-
 mace/core/serializer.cc     | 78 +++++++++++++++++++++++++++++++++++++
 mace/core/serializer.h      | 28 +++++++++++++
 mace/core/tensor.h          | 36 +++++++++++++++++
 mace/core/workspace.cc      |  8 +++-
 mace/core/workspace.h       |  4 +-
 mace/examples/BUILD         |  2 +-
 mace/examples/helloworld.cc | 64 +++++++++++++++++++++++++++++-
 11 files changed, 234 insertions(+), 9 deletions(-)
 create mode 100644 mace/core/serializer.cc
 create mode 100644 mace/core/serializer.h

diff --git a/mace/core/allocator.cc b/mace/core/allocator.cc
index 61e28e9d..fd1f50c3 100644
--- a/mace/core/allocator.cc
+++ b/mace/core/allocator.cc
@@ -15,4 +15,13 @@ void SetCPUAllocator(CPUAllocator* alloc) {
   g_cpu_allocator.reset(alloc);
 }
 
+Allocator* GetDeviceAllocator(DeviceType type) {
+  if (type == DeviceType::CPU) {
+    return cpu_allocator();
+  } else {
+    REQUIRE(false, "device type ", type, " is not supported.");
+  }
+  return nullptr;
+}
+
 } // namespace mace
diff --git a/mace/core/allocator.h b/mace/core/allocator.h
index fa4f1889..110b012b 100644
--- a/mace/core/allocator.h
+++ b/mace/core/allocator.h
@@ -21,6 +21,7 @@ class Allocator {
   virtual ~Allocator() noexcept {}
   virtual void* New(size_t nbytes) = 0;
   virtual void Delete(void* data) = 0;
+  virtual void CopyBytes(void* dst, const void* src, size_t size) = 0;
 
   template <typename T>
   T* New(size_t num_elements) {
@@ -59,6 +60,10 @@ class CPUAllocator: public Allocator {
     free(data);
   }
 #endif
+
+  void CopyBytes(void* dst, const void* src, size_t size) {
+    memcpy(dst, src, size);
+  }
 };
 
 // Get the CPU Alloctor.
@@ -72,9 +77,10 @@ struct DeviceContext {};
 
 template <>
 struct DeviceContext<DeviceType::CPU> {
-  static Allocator* alloctor() { return cpu_allocator(); }
+  static Allocator* allocator() { return cpu_allocator(); }
 };
 
+Allocator* GetDeviceAllocator(DeviceType type);
 
 } // namespace mace
 
diff --git a/mace/core/integral_types.h b/mace/core/integral_types.h
index 10a33053..ac4c8803 100644
--- a/mace/core/integral_types.h
+++ b/mace/core/integral_types.h
@@ -9,11 +9,11 @@
 typedef signed char int8;
 typedef short int16;
 typedef int int32;
-typedef long long int64;
+typedef int64_t int64;
 
 typedef unsigned char uint8;
 typedef unsigned short uint16;
 typedef unsigned int uint32;
-typedef unsigned long long uint64;
+typedef uint64_t uint64;
 
 #endif // MACE_CORE_INTEGRAL_TYPES_H_
diff --git a/mace/core/operator.h b/mace/core/operator.h
index 4b755526..27e1fa16 100644
--- a/mace/core/operator.h
+++ b/mace/core/operator.h
@@ -101,7 +101,7 @@ class Operator : public OperatorBase {
 
     for (const string &output_str : operator_def.output()) {
       outputs_.push_back(CHECK_NOTNULL(ws->CreateTensor(output_str,
-          DeviceContext<D>::alloctor(),
+          DeviceContext<D>::allocator(),
           DataTypeToEnum<T>::v())));
     }
   }
diff --git a/mace/core/serializer.cc b/mace/core/serializer.cc
new file mode 100644
index 00000000..310e7629
--- /dev/null
+++ b/mace/core/serializer.cc
@@ -0,0 +1,78 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/serializer.h"
+
+
+namespace mace {
+
+unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor,
+                                              const string &name) {
+  MACE_NOT_IMPLEMENTED;
+  return nullptr;
+}
+
+unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
+                                           DeviceType type) {
+  unique_ptr<Tensor> tensor(new Tensor(GetDeviceAllocator(type),
+                                       proto.data_type()));
+  vector<TIndex> dims;
+  for (const TIndex d : proto.dims()) {
+    dims.push_back(d);
+  }
+  tensor->Resize(dims);
+
+  switch (proto.data_type()) {
+    case DT_FLOAT:
+      tensor->Copy<float>(proto.float_data().data(),
+                          proto.float_data().size());
+      break;
+    case DT_DOUBLE:
+      tensor->Copy<double>(proto.double_data().data(),
+                           proto.double_data().size());
+      break;
+    case DT_INT32:
+      tensor->template Copy<int32>(proto.int32_data().data(),
+                                   proto.int32_data().size());
+      break;
+    case DT_UINT8:
+      tensor->CopyWithCast<int32, uint8>(proto.int32_data().data(),
+                                         proto.int32_data().size());
+      break;
+    case DT_INT16:
+      tensor->CopyWithCast<int32, int16>(proto.int32_data().data(),
+                                         proto.int32_data().size());
+      break;
+    case DT_INT8:
+      tensor->CopyWithCast<int32, int8>(proto.int32_data().data(),
+                                        proto.int32_data().size());
+      break;
+    case DT_INT64:
+      tensor->Copy<int64>(proto.int64_data().data(),
+                          proto.int64_data().size());
+      break;
+    case DT_UINT16:
+      tensor->CopyWithCast<int32, uint16>(proto.int32_data().data(),
+                                          proto.int32_data().size());
+      break;
+    case DT_BOOL:
+      tensor->CopyWithCast<int32, bool>(proto.int32_data().data(),
+                                        proto.int32_data().size());
+      break;
+    case DT_STRING: {
+        string *content = tensor->mutable_data<string>();
+        for (int i = 0; i < proto.string_data().size(); ++i) {
+          content[i] = proto.string_data(i);
+        }
+      }
+      break;
+    default:
+      MACE_NOT_IMPLEMENTED;
+      break;
+  }
+
+  return tensor;
+}
+
+} // namespace mace
\ No newline at end of file
diff --git a/mace/core/serializer.h b/mace/core/serializer.h
new file mode 100644
index 00000000..01f20748
--- /dev/null
+++ b/mace/core/serializer.h
@@ -0,0 +1,28 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_CORE_SERIALIZER_H_
+#define MACE_CORE_SERIALIZER_H_
+
+#include "mace/proto/mace.pb.h"
+#include "mace/core/common.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+
+class Serializer {
+ public:
+  Serializer() {}
+  ~Serializer() {}
+
+  unique_ptr<TensorProto> Serialize(const Tensor& tensor, const string& name);
+
+  unique_ptr<Tensor> Deserialize(const TensorProto& proto, DeviceType type);
+
+  DISABLE_COPY_AND_ASSIGN(Serializer);
+};
+
+} // namespace mace
+
+#endif // MACE_CORE_SERIALIZER_H_
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 1e15b425..fb34d581 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -9,6 +9,7 @@
 #include "mace/proto/mace.pb.h"
 #include "mace/core/allocator.h"
 #include "mace/core/types.h"
+#include "mace/core/logging.h"
 
 namespace mace {
 
@@ -118,6 +119,41 @@ class Tensor {
     Resize(other->shape());
   }
 
+  template <typename T>
+  inline void Copy(const T* src, size_t size) {
+    REQUIRE(size == size_, "copy src and dst with different size.");
+    CopyBytes(static_cast<const void*>(src), sizeof(T) * size);
+  }
+
+  template <typename SrcType, typename DstType>
+  inline void CopyWithCast(const SrcType* src, size_t size) {
+    REQUIRE(size == size_, "copy src and dst with different size.");
+    unique_ptr<DstType[]> buffer(new DstType[size]);
+    for (int i = 0; i < size; ++i) {
+      buffer[i] = static_cast<DstType>(src[i]);
+    }
+    CopyBytes(static_cast<const void*>(buffer.get()), sizeof(DstType) * size);
+  }
+
+  inline void CopyBytes(const void* src, size_t size) {
+    alloc_->CopyBytes(raw_mutable_data(), src, size);
+  }
+
+  inline void DebugPrint() {
+    std::stringstream os;
+    for (int i : shape_) {
+      os << i << ", ";
+    }
+    LOG(INFO) << "Tensor shape: " << os.str() << " type: " << DataType_Name(dtype_);
+
+    os.str("");
+    os.clear();
+    for (int i = 0; i < size_; ++i) {
+      CASES(dtype_, (os << this->data<T>()[i]) << ", ");
+    }
+    LOG(INFO) << os.str();
+  }
+
  private:
   inline int64 NumElements() const {
     return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 14431bc6..ae28d2df 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -4,6 +4,7 @@
 
 #include "mace/core/common.h"
 #include "mace/core/workspace.h"
+#include "mace/core/serializer.h"
 
 namespace mace {
 
@@ -48,6 +49,11 @@ Tensor* Workspace::GetTensor(const string& name) {
   return const_cast<Tensor*>(static_cast<const Workspace*>(this)->GetTensor(name));
 }
 
-bool RunNet();
+void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
+  Serializer serializer;
+  for (auto& tensor_proto: net_def.tensors()) {
+    tensor_map_[tensor_proto.name()] = serializer.Deserialize(tensor_proto, type);
+  }
+}
 
 } // namespace mace
\ No newline at end of file
diff --git a/mace/core/workspace.h b/mace/core/workspace.h
index 93043744..7de345bc 100644
--- a/mace/core/workspace.h
+++ b/mace/core/workspace.h
@@ -32,10 +32,12 @@ class Workspace {
 
   Tensor* GetTensor(const string& name);
 
+  void LoadModelTensor(const NetDef& net_def, DeviceType type);
+
  private:
   TensorMap tensor_map_;
 
-   DISABLE_COPY_AND_ASSIGN(Workspace);
+  DISABLE_COPY_AND_ASSIGN(Workspace);
 };
 
 } // namespace mace
diff --git a/mace/examples/BUILD b/mace/examples/BUILD
index 41362a03..a674593b 100644
--- a/mace/examples/BUILD
+++ b/mace/examples/BUILD
@@ -7,7 +7,7 @@ cc_binary(
         "helloworld.cc",
     ],
     deps = [
-        "//mace/core:core",
+        "//mace/ops:ops",
     ],
     copts = ['-std=c++11'],
     linkopts = if_android(["-pie", "-llog"]),
diff --git a/mace/examples/helloworld.cc b/mace/examples/helloworld.cc
index 0ba6d38e..2e9eb1e2 100644
--- a/mace/examples/helloworld.cc
+++ b/mace/examples/helloworld.cc
@@ -1,7 +1,67 @@
-#include "mace/core/logging.h" +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/net.h" + +using namespace mace; int main() { - LOG(INFO) << "Hello World"; + // Construct graph + OperatorDef op_def_0; + op_def_0.add_input("Input"); + op_def_0.add_output("Output0"); + op_def_0.set_name("ReluTest0"); + op_def_0.set_type("Relu"); + auto arg_0 = op_def_0.add_arg(); + arg_0->set_name("arg0"); + arg_0->set_f(0.5); + + OperatorDef op_def_1; + op_def_1.add_input("Input"); + op_def_1.add_output("Output1"); + op_def_1.set_name("ReluTest1"); + op_def_1.set_type("Relu"); + auto arg_1 = op_def_1.add_arg(); + arg_1->set_name("arg0"); + arg_1->set_f(1.5); + + OperatorDef op_def_2; + op_def_2.add_input("Output1"); + op_def_2.add_output("Output2"); + op_def_2.set_name("ReluTest2"); + op_def_2.set_type("Relu"); + auto arg_2 = op_def_2.add_arg(); + arg_2->set_name("arg0"); + arg_2->set_f(2.5); + + NetDef net_def; + net_def.set_name("NetTest"); + net_def.add_op()->CopyFrom(op_def_0); + net_def.add_op()->CopyFrom(op_def_1); + net_def.add_op()->CopyFrom(op_def_2); + + auto input = net_def.add_tensors(); + input->set_name("Input"); + input->set_data_type(DataType::DT_FLOAT); + input->add_dims(2); + input->add_dims(3); + for (int i = 0; i < 6; ++i) { + input->add_float_data(i - 3); + } + + VLOG(0) << net_def.DebugString(); + + // Create workspace and input tensor + Workspace ws; + ws.LoadModelTensor(net_def, DeviceType::CPU); + + // Create Net & run + auto net = CreateNet(net_def, &ws, DeviceType::CPU); + net->Run(); + + auto out_tensor = ws.GetTensor("Output2"); + out_tensor->DebugPrint(); return 0; } -- GitLab