Commit dd2cd8ee authored by Liangliang He

Rename TensorProto to ConstTensor

Parent ee725558
@@ -10,50 +10,48 @@
 namespace mace {
 
-TensorProto::TensorProto(const std::string &name,
+ConstTensor::ConstTensor(const std::string &name,
                          unsigned char *data,
                          const std::vector<int64_t> &dims,
                          const DataType data_type,
                          uint32_t node_id) :
     name_(name),
     data_(data),
-    data_size_(0),
     dims_(dims.begin(), dims.end()),
     data_type_(data_type),
-    node_id_(node_id) {
-  data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies<int64_t>());
-}
+    node_id_(node_id),
+    data_size_(std::accumulate(dims.begin(), dims.end(), 1,
+                               std::multiplies<int64_t>())) {}
 
-TensorProto::TensorProto(const std::string &name,
+ConstTensor::ConstTensor(const std::string &name,
                          unsigned char *data,
                          const std::vector<int64_t> &dims,
                          const int data_type,
                          uint32_t node_id) :
     name_(name),
     data_(data),
-    data_size_(0),
     dims_(dims.begin(), dims.end()),
     data_type_(static_cast<DataType>(data_type)),
-    node_id_(node_id) {
-  data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies<int64_t>());
-}
+    node_id_(node_id),
+    data_size_(std::accumulate(dims.begin(), dims.end(), 1,
+                               std::multiplies<int64_t>())) {}
 
-const std::string &TensorProto::name() const {
+const std::string &ConstTensor::name() const {
   return name_;
 }
 
-unsigned char *TensorProto::data() const {
+const unsigned char *ConstTensor::data() const {
   return data_;
 }
 
-const int64_t TensorProto::data_size() const {
+int64_t ConstTensor::data_size() const {
   return data_size_;
 }
 
-const std::vector<int64_t> &TensorProto::dims() const {
+const std::vector<int64_t> &ConstTensor::dims() const {
   return dims_;
 }
 
-DataType TensorProto::data_type() const {
+DataType ConstTensor::data_type() const {
   return data_type_;
 }
 
-uint32_t TensorProto::node_id() const {
+uint32_t ConstTensor::node_id() const {
   return node_id_;
 }
@@ -446,10 +444,10 @@ Argument *NetDef::add_arg() {
 std::vector<Argument> &NetDef::mutable_arg() {
   return arg_;
 }
 
-const std::vector<TensorProto> &NetDef::tensors() const {
+const std::vector<ConstTensor> &NetDef::tensors() const {
   return tensors_;
 }
 
-std::vector<TensorProto> &NetDef::mutable_tensors() {
+std::vector<ConstTensor> &NetDef::mutable_tensors() {
   return tensors_;
 }
 
 const MemoryArena &NetDef::mem_arena() const {
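
Note: as a quick, hedged illustration of the renamed API (not part of this commit), the following sketch constructs a ConstTensor and reads back its fields. The header path, buffer contents, and tensor name are assumptions made for the example; data_size() is the element count computed from dims, as the constructor above shows.

#include <cassert>
#include <vector>

#include "mace/core/mace.h"  // assumed header location for mace::ConstTensor

int main() {
  // Raw bytes backing the constant tensor; real code would point at weight data.
  alignas(4) static unsigned char buffer[6 * sizeof(float)] = {};
  const std::vector<int64_t> dims = {1, 2, 3, 1};

  mace::ConstTensor weights("weights", buffer, dims, mace::DataType::DT_FLOAT);

  // data_size() is the product of dims: 1 * 2 * 3 * 1 = 6 elements.
  assert(weights.data_size() == 6);
  assert(weights.data_type() == mace::DataType::DT_FLOAT);
  return 0;
}
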
@@ -38,33 +38,33 @@ enum DataType {
   DT_UINT32 = 22
 };
 
-class TensorProto {
+class ConstTensor {
  public:
-  TensorProto(const std::string &name,
+  ConstTensor(const std::string &name,
              unsigned char *data,
              const std::vector<int64_t> &dims,
              const DataType data_type = DT_FLOAT,
              uint32_t node_id = 0);
-  TensorProto(const std::string &name,
+  ConstTensor(const std::string &name,
              unsigned char *data,
              const std::vector<int64_t> &dims,
              const int data_type,
              uint32_t node_id = 0);
 
   const std::string &name() const;
-  unsigned char *data() const;
-  const int64_t data_size() const;
+  const unsigned char *data() const;
+  int64_t data_size() const;
   const std::vector<int64_t> &dims() const;
   DataType data_type() const;
   uint32_t node_id() const;
 
  private:
-  std::string name_;
-  unsigned char *data_;
-  int64_t data_size_;
-  std::vector<int64_t> dims_;
-  DataType data_type_;
-  uint32_t node_id_;
+  const std::string name_;
+  const unsigned char *data_;
+  const int64_t data_size_;
+  const std::vector<int64_t> dims_;
+  const DataType data_type_;
+  const uint32_t node_id_;
 };
 
 class Argument {
@@ -270,8 +270,8 @@ class NetDef {
   const std::vector<Argument> &arg() const;
   Argument *add_arg();
   std::vector<Argument> &mutable_arg();
-  const std::vector<TensorProto> &tensors() const;
-  std::vector<TensorProto> &mutable_tensors();
+  const std::vector<ConstTensor> &tensors() const;
+  std::vector<ConstTensor> &mutable_tensors();
   const MemoryArena &mem_arena() const;
   bool has_mem_arena() const;
   MemoryArena &mutable_mem_arena();
@@ -288,7 +288,7 @@ class NetDef {
   std::string version_;
   std::vector<OperatorDef> op_;
   std::vector<Argument> arg_;
-  std::vector<TensorProto> tensors_;
+  std::vector<ConstTensor> tensors_;
 
   // for mem optimization
   MemoryArena mem_arena_;
@@ -6,13 +6,13 @@
 namespace mace {
 
-unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor,
+unique_ptr<ConstTensor> Serializer::Serialize(const Tensor &tensor,
                                               const string &name) {
   MACE_NOT_IMPLEMENTED;
   return nullptr;
 }
 
-unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
+unique_ptr<Tensor> Serializer::Deserialize(const ConstTensor &proto,
                                            DeviceType type) {
   unique_ptr<Tensor> tensor(
       new Tensor(GetDeviceAllocator(type), proto.data_type()));
@@ -24,31 +24,40 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
   switch (proto.data_type()) {
     case DT_FLOAT:
-      tensor->Copy<float>(reinterpret_cast<float*>(proto.data()), proto.data_size());
+      tensor->Copy<float>(reinterpret_cast<const float *>(proto.data()),
+                          proto.data_size());
       break;
     case DT_DOUBLE:
-      tensor->Copy<double>(reinterpret_cast<double*>(proto.data()), proto.data_size());
+      tensor->Copy<double>(reinterpret_cast<const double *>(proto.data()),
+                           proto.data_size());
       break;
     case DT_INT32:
-      tensor->Copy<int32_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->Copy<int32_t>(reinterpret_cast<const int32_t *>(proto.data()),
+                            proto.data_size());
      break;
     case DT_INT64:
-      tensor->Copy<int64_t>(reinterpret_cast<int64_t*>(proto.data()), proto.data_size());
+      tensor->Copy<int64_t>(reinterpret_cast<const int64_t *>(proto.data()),
+                            proto.data_size());
      break;
    case DT_UINT8:
-      tensor->CopyWithCast<int32_t, uint8_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->CopyWithCast<int32_t, uint8_t>(
+          reinterpret_cast<const int32_t *>(proto.data()), proto.data_size());
      break;
    case DT_INT16:
-      tensor->CopyWithCast<int32_t, uint16_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->CopyWithCast<int32_t, uint16_t>(
+          reinterpret_cast<const int32_t *>(proto.data()), proto.data_size());
      break;
    case DT_INT8:
-      tensor->CopyWithCast<int32_t, int8_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->CopyWithCast<int32_t, int8_t>(
+          reinterpret_cast<const int32_t *>(proto.data()), proto.data_size());
      break;
    case DT_UINT16:
-      tensor->CopyWithCast<int32_t, int16_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->CopyWithCast<int32_t, int16_t>(
+          reinterpret_cast<const int32_t *>(proto.data()), proto.data_size());
      break;
    case DT_BOOL:
-      tensor->CopyWithCast<int32_t, bool>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
+      tensor->CopyWithCast<int32_t, bool>(
+          reinterpret_cast<const int32_t *>(proto.data()), proto.data_size());
      break;
    default:
      MACE_NOT_IMPLEMENTED;
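
Note: a minimal sketch of how Deserialize might be driven after this rename (illustrative only, not part of the commit). The header path and DeviceType::CPU value are assumptions, and the byte buffer is a placeholder; Deserialize copies proto.data_size() elements of the declared type into a newly allocated Tensor.

#include <memory>
#include <vector>

#include "mace/core/serializer.h"  // assumed header; pulls in ConstTensor/Tensor

std::unique_ptr<mace::Tensor> LoadWeights() {
  // Placeholder bytes for four floats.
  alignas(4) static unsigned char bytes[4 * sizeof(float)] = {};
  mace::ConstTensor proto("w", bytes, {1, 1, 2, 2}, mace::DataType::DT_FLOAT);

  mace::Serializer serializer;
  // DT_FLOAT case above: copies 4 floats into the new tensor.
  return serializer.Deserialize(proto, mace::DeviceType::CPU);
}
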
@@ -16,9 +16,9 @@ class Serializer {
   Serializer() {}
   ~Serializer() {}
 
-  unique_ptr<TensorProto> Serialize(const Tensor &tensor, const string &name);
+  unique_ptr<ConstTensor> Serialize(const Tensor &tensor, const string &name);
 
-  unique_ptr<Tensor> Deserialize(const TensorProto &proto, DeviceType type);
+  unique_ptr<Tensor> Deserialize(const ConstTensor &proto, DeviceType type);
 
   DISABLE_COPY_AND_ASSIGN(Serializer);
 };
@@ -45,7 +45,7 @@ int main() {
   alignas(4) unsigned char tensor_data[] = "012345678901234567890123";
   const std::vector<int64_t> dims = {1, 2, 3, 1};
 
-  TensorProto input("Input", tensor_data, dims, DataType::DT_FLOAT);
+  ConstTensor input("Input", tensor_data, dims, DataType::DT_FLOAT);
   net_def.mutable_tensors().push_back(input);
 
   // Create workspace and input tensor
@@ -13,8 +13,8 @@ alignas(4) unsigned char {{ tensor_info.name }}[] = {
 {% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%}
 };
 
-void Create{{tensor.name}}(std::vector<mace::TensorProto> &tensors) {
-  tensors.emplace_back(mace::TensorProto(
+void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
+  tensors.emplace_back(mace::ConstTensor(
       {{ tensor.name|tojson }}, {{ tensor.name }},
       { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }}));
 }
@@ -100,7 +100,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
 namespace {{tag}} {
 
 {% for tensor in tensors %}
-extern void Create{{ tensor.name }}(std::vector<mace::TensorProto> &tensors);
+extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors);
 {% endfor %}
@@ -159,7 +159,7 @@ static void CreateOperators(std::vector<mace::OperatorDef> &ops) {
 }
 
-static void CreateTensors(std::vector<mace::TensorProto> &tensors) {
+static void CreateTensors(std::vector<mace::ConstTensor> &tensors) {
   tensors.reserve({{ net.tensors|length }});
 
 {% for tensor in net.tensors %}
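
Note: to make the template change concrete, this is roughly what the generated code could expand to for a single, hypothetical tensor named op1_weights (the name, dims, byte values, and the rendering of data_type are invented for illustration; actual output depends on the generator's inputs).

#include <vector>

#include "mace/core/mace.h"  // assumed header location for mace::ConstTensor

alignas(4) unsigned char op1_weights[] = {
    0x00, 0x00, 0x80, 0x3F,  // 1.0f, little-endian IEEE-754
};

void Createop1_weights(std::vector<mace::ConstTensor> &tensors) {
  tensors.emplace_back(mace::ConstTensor(
      "op1_weights", op1_weights,
      {1, 1, 1, 1}, mace::DT_FLOAT, 0));
}
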