提交 f7bb2064 编写于 作者: 李寅

Add deserializer to parse tensor from proto

上级 16086da9
......@@ -15,4 +15,13 @@ void SetCPUAllocator(CPUAllocator* alloc) {
g_cpu_allocator.reset(alloc);
}
// Returns the allocator for the given device type.
// Only DeviceType::CPU is supported; any other type trips REQUIRE
// (and, if execution continues past it, yields nullptr).
Allocator* GetDeviceAllocator(DeviceType type) {
  if (type != DeviceType::CPU) {
    REQUIRE(false, "device type ", type, " is not supported.");
    return nullptr;
  }
  return cpu_allocator();
}
} // namespace mace
......@@ -21,6 +21,7 @@ class Allocator {
virtual ~Allocator() noexcept {}
virtual void* New(size_t nbytes) = 0;
virtual void Delete(void* data) = 0;
virtual void CopyBytes(void* dst, const void* src, size_t size) = 0;
template <typename T>
T* New(size_t num_elements) {
......@@ -59,6 +60,10 @@ class CPUAllocator: public Allocator {
free(data);
}
#endif
void CopyBytes(void* dst, const void* src, size_t size) {
memcpy(dst, src, size);
}
};
// Get the CPU Alloctor.
......@@ -72,9 +77,10 @@ struct DeviceContext {};
template <>
struct DeviceContext<DeviceType::CPU> {
static Allocator* alloctor() { return cpu_allocator(); }
static Allocator* allocator() { return cpu_allocator(); }
};
Allocator* GetDeviceAllocator(DeviceType type);
} // namespace mace
......
......@@ -9,11 +9,11 @@
typedef signed char int8;
typedef short int16;
typedef int int32;
typedef long long int64;
typedef int64_t int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;
typedef uint64_t uint64;
#endif // MACE_CORE_INTEGRAL_TYPES_H_
......@@ -101,7 +101,7 @@ class Operator : public OperatorBase {
for (const string &output_str : operator_def.output()) {
outputs_.push_back(CHECK_NOTNULL(ws->CreateTensor(output_str,
DeviceContext<D>::alloctor(),
DeviceContext<D>::allocator(),
DataTypeToEnum<T>::v())));
}
}
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/serializer.h"
namespace mace {
// Serializes |tensor| into a TensorProto named |name|.
// Not implemented yet: trips MACE_NOT_IMPLEMENTED and returns nullptr.
unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor,
const string &name) {
MACE_NOT_IMPLEMENTED;
return nullptr;
}
// Builds a Tensor from |proto|: allocates storage via the allocator for
// device |type|, resizes to proto.dims(), then copies the proto payload
// field that corresponds to proto.data_type(). Unlisted data types fall
// through to MACE_NOT_IMPLEMENTED.
unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
DeviceType type) {
unique_ptr<Tensor> tensor(new Tensor(GetDeviceAllocator(type),
proto.data_type()));
vector<TIndex> dims;
for (const TIndex d : proto.dims()) {
dims.push_back(d);
}
tensor->Resize(dims);
switch (proto.data_type()) {
// Types with a dedicated repeated field in the proto are copied as-is;
// Tensor::Copy REQUIREs that the element count matches the tensor size.
case DT_FLOAT:
tensor->Copy<float>(proto.float_data().data(),
proto.float_data().size());
break;
case DT_DOUBLE:
tensor->Copy<double>(proto.double_data().data(),
proto.double_data().size());
break;
case DT_INT32:
// NOTE(review): the `template` keyword is not needed here (tensor's
// type is not dependent) but is harmless; kept as written.
tensor->template Copy<int32>(proto.int32_data().data(),
proto.int32_data().size());
break;
// Narrow integral types and bool are stored widened as int32 in the
// proto, so they are narrowed element-by-element on the way in.
case DT_UINT8:
tensor->CopyWithCast<int32, uint8>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_INT16:
tensor->CopyWithCast<int32, int16>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_INT8:
tensor->CopyWithCast<int32, int8>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_INT64:
tensor->Copy<int64>(proto.int64_data().data(),
proto.int64_data().size());
break;
case DT_UINT16:
tensor->CopyWithCast<int32, uint16>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_BOOL:
tensor->CopyWithCast<int32, bool>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_STRING: {
// Strings are non-POD, so they are assigned one element at a time.
// NOTE(review): unlike Copy/CopyWithCast there is no size check here;
// assumes proto.string_data().size() <= tensor element count — TODO
// confirm against Resize above.
string *content = tensor->mutable_data<string>();
for (int i = 0; i < proto.string_data().size(); ++i) {
content[i] = proto.string_data(i);
}
}
break;
default:
MACE_NOT_IMPLEMENTED;
break;
}
return tensor;
}
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_SERIALIZER_H_
#define MACE_CORE_SERIALIZER_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h"
#include "mace/core/tensor.h"
namespace mace {
// Converts tensors to and from their TensorProto wire representation.
// Stateless; construct, call, discard.
class Serializer {
public:
Serializer() {}
~Serializer() {}
// Serializes |tensor| into a proto named |name|. Currently unimplemented.
unique_ptr<TensorProto> Serialize(const Tensor& tensor, const string& name);
// Builds a Tensor from |proto|, allocating on the device given by |type|.
unique_ptr<Tensor> Deserialize(const TensorProto& proto, DeviceType type);
DISABLE_COPY_AND_ASSIGN(Serializer);
};
} // namespace mace
#endif // MACE_CORE_SERIALIZER_H_
......@@ -9,6 +9,7 @@
#include "mace/proto/mace.pb.h"
#include "mace/core/allocator.h"
#include "mace/core/types.h"
#include "mace/core/logging.h"
namespace mace {
......@@ -118,6 +119,41 @@ class Tensor {
Resize(other->shape());
}
// Copies |size| elements of type T from |src| into this tensor's buffer.
// |size| must equal the tensor's element count (enforced by REQUIRE).
template <typename T>
inline void Copy(const T* src, size_t size) {
REQUIRE(size == size_, "copy src and dst with different size.");
CopyBytes(static_cast<const void*>(src), sizeof(T) * size);
}
// Copies |size| elements from |src| into this tensor, converting each
// element from SrcType to DstType via static_cast, staged through a
// temporary buffer. |size| must equal the tensor's element count
// (enforced by REQUIRE).
template <typename SrcType, typename DstType>
inline void CopyWithCast(const SrcType* src, size_t size) {
  REQUIRE(size == size_, "copy src and dst with different size.");
  unique_ptr<DstType[]> buffer(new DstType[size]);
  // size_t index: the original `int i` compared signed against the
  // unsigned size_t bound and would overflow for > INT_MAX elements.
  for (size_t i = 0; i < size; ++i) {
    buffer[i] = static_cast<DstType>(src[i]);
  }
  CopyBytes(static_cast<const void*>(buffer.get()), sizeof(DstType) * size);
}
// Raw copy of |size| bytes from |src| into this tensor's buffer,
// delegated to the allocator so the copy is device-appropriate.
inline void CopyBytes(const void* src, size_t size) {
alloc_->CopyBytes(raw_mutable_data(), src, size);
}
// Logs the tensor's shape, data type, and every element value.
// Debug helper only — O(size_) string building, not for hot paths.
inline void DebugPrint() {
std::stringstream os;
for (int i: shape_) {
os << i << ", ";
}
LOG(INFO) << "Tensor shape: " << os.str() << " type: " << DataType_Name(dtype_);
// Reset the stream (contents and error flags) to reuse it for the dump.
os.str("");
os.clear();
for (int i = 0; i < size_; ++i) {
// CASES dispatches on dtype_, binding T to the matching C++ type
// so data<T>() is read with the correct element type.
CASES(dtype_, (os << this->data<T>()[i]) << ", ");
}
LOG(INFO) << os.str();
}
private:
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
......
......@@ -4,6 +4,7 @@
#include "mace/core/common.h"
#include "mace/core/workspace.h"
#include "mace/core/serializer.h"
namespace mace {
......@@ -48,6 +49,11 @@ Tensor* Workspace::GetTensor(const string& name) {
return const_cast<Tensor*>(static_cast<const Workspace*>(this)->GetTensor(name));
}
bool RunNet();
void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
Serializer serializer;
for (auto& tensor_proto: net_def.tensors()) {
tensor_map_[tensor_proto.name()] = serializer.Deserialize(tensor_proto, type);
}
}
} // namespace mace
\ No newline at end of file
......@@ -32,10 +32,12 @@ class Workspace {
Tensor* GetTensor(const string& name);
void LoadModelTensor(const NetDef& net_def, DeviceType type);
private:
TensorMap tensor_map_;
DISABLE_COPY_AND_ASSIGN(Workspace);
DISABLE_COPY_AND_ASSIGN(Workspace);
};
} // namespace mace
......
......@@ -7,7 +7,7 @@ cc_binary(
"helloworld.cc",
],
deps = [
"//mace/core:core",
"//mace/ops:ops",
],
copts = ['-std=c++11'],
linkopts = if_android(["-pie", "-llog"]),
......
#include "mace/core/logging.h"
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/net.h"
using namespace mace;
int main() {
LOG(INFO) << "Hello World";
// Construct graph
OperatorDef op_def_0;
op_def_0.add_input("Input");
op_def_0.add_output("Output0");
op_def_0.set_name("ReluTest0");
op_def_0.set_type("Relu");
auto arg_0 = op_def_0.add_arg();
arg_0->set_name("arg0");
arg_0->set_f(0.5);
OperatorDef op_def_1;
op_def_1.add_input("Input");
op_def_1.add_output("Output1");
op_def_1.set_name("ReluTest1");
op_def_1.set_type("Relu");
auto arg_1 = op_def_1.add_arg();
arg_1->set_name("arg0");
arg_1->set_f(1.5);
OperatorDef op_def_2;
op_def_2.add_input("Output1");
op_def_2.add_output("Output2");
op_def_2.set_name("ReluTest2");
op_def_2.set_type("Relu");
auto arg_2 = op_def_2.add_arg();
arg_2->set_name("arg0");
arg_2->set_f(2.5);
NetDef net_def;
net_def.set_name("NetTest");
net_def.add_op()->CopyFrom(op_def_0);
net_def.add_op()->CopyFrom(op_def_1);
net_def.add_op()->CopyFrom(op_def_2);
auto input = net_def.add_tensors();
input->set_name("Input");
input->set_data_type(DataType::DT_FLOAT);
input->add_dims(2);
input->add_dims(3);
for (int i = 0; i < 6; ++i) {
input->add_float_data(i - 3);
}
VLOG(0) << net_def.DebugString();
// Create workspace and input tensor
Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU);
// Create Net & run
auto net = CreateNet(net_def, &ws, DeviceType::CPU);
net->Run();
auto out_tensor = ws.GetTensor("Output2");
out_tensor->DebugPrint();
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册