Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
dd2cd8ee
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dd2cd8ee
编写于
12月 22, 2017
作者:
L
Liangliang He
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename TensorProto to ConstTensor
上级
ee725558
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
58 addition
and
51 deletion
+58
-51
mace/core/mace.cc
mace/core/mace.cc
+16
-18
mace/core/public/mace.h
mace/core/public/mace.h
+14
-14
mace/core/serializer.cc
mace/core/serializer.cc
+21
-12
mace/core/serializer.h
mace/core/serializer.h
+2
-2
mace/examples/helloworld.cc
mace/examples/helloworld.cc
+1
-1
mace/python/tools/model.template
mace/python/tools/model.template
+4
-4
未找到文件。
mace/core/mace.cc
浏览文件 @
dd2cd8ee
...
...
@@ -10,50 +10,48 @@
namespace
mace
{
TensorProto
::
TensorProto
(
const
std
::
string
&
name
,
ConstTensor
::
ConstTensor
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
DataType
data_type
,
uint32_t
node_id
)
:
name_
(
name
),
data_
(
data
),
data_size_
(
0
),
dims_
(
dims
.
begin
(),
dims
.
end
()),
data_type_
(
data_type
),
node_id_
(
node_id
)
{
data_size_
=
std
::
accumulate
(
dims_
.
begin
(),
dims_
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
());
}
node_id_
(
node_id
)
,
data_size_
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
()))
{
}
TensorProto
::
TensorProto
(
const
std
::
string
&
name
,
ConstTensor
::
ConstTensor
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
int
data_type
,
uint32_t
node_id
)
:
name_
(
name
),
data_
(
data
),
data_size_
(
0
),
dims_
(
dims
.
begin
(),
dims
.
end
()),
data_type_
(
static_cast
<
DataType
>
(
data_type
)),
node_id_
(
node_id
)
{
data_size_
=
std
::
accumulate
(
dims_
.
begin
(),
dims_
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
());
}
node_id_
(
node_id
)
,
data_size_
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
()))
{
}
const
std
::
string
&
TensorProto
::
name
()
const
{
const
std
::
string
&
ConstTensor
::
name
()
const
{
return
name_
;
}
unsigned
char
*
TensorProto
::
data
()
const
{
const
unsigned
char
*
ConstTensor
::
data
()
const
{
return
data_
;
}
const
int64_t
TensorProto
::
data_size
()
const
{
int64_t
ConstTensor
::
data_size
()
const
{
return
data_size_
;
}
const
std
::
vector
<
int64_t
>
&
TensorProto
::
dims
()
const
{
const
std
::
vector
<
int64_t
>
&
ConstTensor
::
dims
()
const
{
return
dims_
;
}
DataType
TensorProto
::
data_type
()
const
{
DataType
ConstTensor
::
data_type
()
const
{
return
data_type_
;
}
uint32_t
TensorProto
::
node_id
()
const
{
uint32_t
ConstTensor
::
node_id
()
const
{
return
node_id_
;
}
...
...
@@ -446,10 +444,10 @@ Argument *NetDef::add_arg() {
std
::
vector
<
Argument
>
&
NetDef
::
mutable_arg
()
{
return
arg_
;
}
const
std
::
vector
<
TensorProto
>
&
NetDef
::
tensors
()
const
{
const
std
::
vector
<
ConstTensor
>
&
NetDef
::
tensors
()
const
{
return
tensors_
;
}
std
::
vector
<
TensorProto
>
&
NetDef
::
mutable_tensors
()
{
std
::
vector
<
ConstTensor
>
&
NetDef
::
mutable_tensors
()
{
return
tensors_
;
}
const
MemoryArena
&
NetDef
::
mem_arena
()
const
{
...
...
mace/core/public/mace.h
浏览文件 @
dd2cd8ee
...
...
@@ -38,33 +38,33 @@ enum DataType {
DT_UINT32
=
22
};
class
TensorProto
{
class
ConstTensor
{
public:
TensorProto
(
const
std
::
string
&
name
,
ConstTensor
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
DataType
data_type
=
DT_FLOAT
,
uint32_t
node_id
=
0
);
TensorProto
(
const
std
::
string
&
name
,
ConstTensor
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
int
data_type
,
uint32_t
node_id
=
0
);
const
std
::
string
&
name
()
const
;
unsigned
char
*
data
()
const
;
const
int64_t
data_size
()
const
;
const
unsigned
char
*
data
()
const
;
int64_t
data_size
()
const
;
const
std
::
vector
<
int64_t
>
&
dims
()
const
;
DataType
data_type
()
const
;
uint32_t
node_id
()
const
;
private:
std
::
string
name_
;
unsigned
char
*
data_
;
int64_t
data_size_
;
std
::
vector
<
int64_t
>
dims_
;
DataType
data_type_
;
uint32_t
node_id_
;
const
std
::
string
name_
;
const
unsigned
char
*
data_
;
const
int64_t
data_size_
;
const
std
::
vector
<
int64_t
>
dims_
;
const
DataType
data_type_
;
const
uint32_t
node_id_
;
};
class
Argument
{
...
...
@@ -270,8 +270,8 @@ class NetDef {
const
std
::
vector
<
Argument
>
&
arg
()
const
;
Argument
*
add_arg
();
std
::
vector
<
Argument
>
&
mutable_arg
();
const
std
::
vector
<
TensorProto
>
&
tensors
()
const
;
std
::
vector
<
TensorProto
>
&
mutable_tensors
();
const
std
::
vector
<
ConstTensor
>
&
tensors
()
const
;
std
::
vector
<
ConstTensor
>
&
mutable_tensors
();
const
MemoryArena
&
mem_arena
()
const
;
bool
has_mem_arena
()
const
;
MemoryArena
&
mutable_mem_arena
();
...
...
@@ -288,7 +288,7 @@ class NetDef {
std
::
string
version_
;
std
::
vector
<
OperatorDef
>
op_
;
std
::
vector
<
Argument
>
arg_
;
std
::
vector
<
TensorProto
>
tensors_
;
std
::
vector
<
ConstTensor
>
tensors_
;
// for mem optimization
MemoryArena
mem_arena_
;
...
...
mace/core/serializer.cc
浏览文件 @
dd2cd8ee
...
...
@@ -6,13 +6,13 @@
namespace
mace
{
unique_ptr
<
TensorProto
>
Serializer
::
Serialize
(
const
Tensor
&
tensor
,
unique_ptr
<
ConstTensor
>
Serializer
::
Serialize
(
const
Tensor
&
tensor
,
const
string
&
name
)
{
MACE_NOT_IMPLEMENTED
;
return
nullptr
;
}
unique_ptr
<
Tensor
>
Serializer
::
Deserialize
(
const
TensorProto
&
proto
,
unique_ptr
<
Tensor
>
Serializer
::
Deserialize
(
const
ConstTensor
&
proto
,
DeviceType
type
)
{
unique_ptr
<
Tensor
>
tensor
(
new
Tensor
(
GetDeviceAllocator
(
type
),
proto
.
data_type
()));
...
...
@@ -24,31 +24,40 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
switch
(
proto
.
data_type
())
{
case
DT_FLOAT
:
tensor
->
Copy
<
float
>
(
reinterpret_cast
<
float
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
Copy
<
float
>
(
reinterpret_cast
<
const
float
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_DOUBLE
:
tensor
->
Copy
<
double
>
(
reinterpret_cast
<
double
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
Copy
<
double
>
(
reinterpret_cast
<
const
double
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT32
:
tensor
->
Copy
<
int32_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
Copy
<
int32_t
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT64
:
tensor
->
Copy
<
int64_t
>
(
reinterpret_cast
<
int64_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
Copy
<
int64_t
>
(
reinterpret_cast
<
const
int64_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_UINT8
:
tensor
->
CopyWithCast
<
int32_t
,
uint8_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
CopyWithCast
<
int32_t
,
uint8_t
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT16
:
tensor
->
CopyWithCast
<
int32_t
,
uint16_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
CopyWithCast
<
int32_t
,
uint16_t
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT8
:
tensor
->
CopyWithCast
<
int32_t
,
int8_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
CopyWithCast
<
int32_t
,
int8_t
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_UINT16
:
tensor
->
CopyWithCast
<
int32_t
,
int16_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
CopyWithCast
<
int32_t
,
int16_t
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_BOOL
:
tensor
->
CopyWithCast
<
int32_t
,
bool
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
tensor
->
CopyWithCast
<
int32_t
,
bool
>
(
reinterpret_cast
<
const
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
default:
MACE_NOT_IMPLEMENTED
;
...
...
mace/core/serializer.h
浏览文件 @
dd2cd8ee
...
...
@@ -16,9 +16,9 @@ class Serializer {
Serializer
()
{}
~
Serializer
()
{}
unique_ptr
<
TensorProto
>
Serialize
(
const
Tensor
&
tensor
,
const
string
&
name
);
unique_ptr
<
ConstTensor
>
Serialize
(
const
Tensor
&
tensor
,
const
string
&
name
);
unique_ptr
<
Tensor
>
Deserialize
(
const
TensorProto
&
proto
,
DeviceType
type
);
unique_ptr
<
Tensor
>
Deserialize
(
const
ConstTensor
&
proto
,
DeviceType
type
);
DISABLE_COPY_AND_ASSIGN
(
Serializer
);
};
...
...
mace/examples/helloworld.cc
浏览文件 @
dd2cd8ee
...
...
@@ -45,7 +45,7 @@ int main() {
alignas
(
4
)
unsigned
char
tensor_data
[]
=
"012345678901234567890123"
;
const
std
::
vector
<
int64_t
>
dims
=
{
1
,
2
,
3
,
1
};
TensorProto
input
(
"Input"
,
tensor_data
,
dims
,
DataType
::
DT_FLOAT
);
ConstTensor
input
(
"Input"
,
tensor_data
,
dims
,
DataType
::
DT_FLOAT
);
net_def
.
mutable_tensors
().
push_back
(
input
);
// Create workspace and input tensor
...
...
mace/python/tools/model.template
浏览文件 @
dd2cd8ee
...
...
@@ -13,8 +13,8 @@ alignas(4) unsigned char {{ tensor_info.name }}[] = {
{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%}
};
void Create{{tensor.name}}(std::vector<mace::
TensorProto
> &tensors) {
tensors.emplace_back(mace::
TensorProto
(
void Create{{tensor.name}}(std::vector<mace::
ConstTensor
> &tensors) {
tensors.emplace_back(mace::
ConstTensor
(
{{ tensor.name|tojson }}, {{ tensor.name }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }}));
}
...
...
@@ -100,7 +100,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
namespace {{tag}} {
{% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::
TensorProto
> &tensors);
extern void Create{{ tensor.name }}(std::vector<mace::
ConstTensor
> &tensors);
{% endfor %}
...
...
@@ -159,7 +159,7 @@ static void CreateOperators(std::vector<mace::OperatorDef> &ops) {
}
static void CreateTensors(std::vector<mace::
TensorProto
> &tensors) {
static void CreateTensors(std::vector<mace::
ConstTensor
> &tensors) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录