Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
as350144
Mace
提交
f05649a4
Mace
项目概览
as350144
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
2
Star
1
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f05649a4
编写于
12月 15, 2017
作者:
L
liuqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace pb with code.
上级
23238388
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
987 addition
and
471 deletion
+987
-471
mace/core/BUILD
mace/core/BUILD
+0
-1
mace/core/allocator.h
mace/core/allocator.h
+1
-1
mace/core/mace.cc
mace/core/mace.cc
+404
-0
mace/core/mace.h
mace/core/mace.h
+341
-0
mace/core/net.cc
mace/core/net.cc
+1
-1
mace/core/net.h
mace/core/net.h
+1
-1
mace/core/operator.h
mace/core/operator.h
+1
-1
mace/core/proto_utils.cc
mace/core/proto_utils.cc
+2
-238
mace/core/proto_utils.h
mace/core/proto_utils.h
+1
-178
mace/core/serializer.cc
mace/core/serializer.cc
+11
-25
mace/core/serializer.h
mace/core/serializer.h
+1
-1
mace/core/tensor.h
mace/core/tensor.h
+1
-1
mace/core/types.h
mace/core/types.h
+1
-1
mace/core/workspace.cc
mace/core/workspace.cc
+1
-1
mace/core/workspace.h
mace/core/workspace.h
+1
-1
mace/examples/mace_run.cc
mace/examples/mace_run.cc
+20
-6
mace/kernels/batch_norm.h
mace/kernels/batch_norm.h
+1
-1
mace/kernels/bias_add.h
mace/kernels/bias_add.h
+1
-1
mace/kernels/concat.h
mace/kernels/concat.h
+1
-1
mace/kernels/depthwise_conv2d.h
mace/kernels/depthwise_conv2d.h
+1
-1
mace/kernels/space_to_batch.h
mace/kernels/space_to_batch.h
+1
-1
mace/ops/BUILD
mace/ops/BUILD
+0
-1
mace/ops/concat.h
mace/ops/concat.h
+1
-1
mace/ops/ops_test_util.h
mace/ops/ops_test_util.h
+0
-1
mace/python/tools/model.template
mace/python/tools/model.template
+139
-0
mace/python/tools/tf_converter.py
mace/python/tools/tf_converter.py
+54
-6
未找到文件。
mace/core/BUILD
浏览文件 @
f05649a4
...
...
@@ -62,7 +62,6 @@ cc_library(
]),
deps
=
[
":logging"
,
"//mace/proto:cc_proto"
,
"//mace/proto:stats_proto"
,
"//mace/utils"
,
":opencl_runtime"
,
...
...
mace/core/allocator.h
浏览文件 @
f05649a4
...
...
@@ -9,7 +9,7 @@
#include <malloc.h>
#include "mace/core/common.h"
#include "mace/core/registry.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
#include "mace/core/types.h"
namespace
mace
{
...
...
mace/core/mace.cc
0 → 100644
浏览文件 @
f05649a4
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/mace.h"
#include "mace/core/logging.h"
namespace
mace
{
TensorProto
::
TensorProto
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
DataType
data_type
,
uint32_t
node_id
)
:
name_
(
name
),
data_
(
data
),
dims_
(
dims
),
data_type_
(
data_type
),
node_id_
(
node_id
)
{}
TensorProto
::
TensorProto
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
int
data_type
,
uint32_t
node_id
)
:
name_
(
name
),
data_
(
data
),
dims_
(
dims
),
data_type_
(
static_cast
<
DataType
>
(
data_type
)),
node_id_
(
node_id
)
{}
const
std
::
string
&
TensorProto
::
name
()
const
{
return
name_
;
}
unsigned
char
*
TensorProto
::
data
()
const
{
return
data_
;
}
const
int
TensorProto
::
data_size
()
const
{
return
data_size_
;
}
const
std
::
vector
<
int64_t
>
&
TensorProto
::
dims
()
const
{
return
dims_
;
}
DataType
TensorProto
::
data_type
()
const
{
return
data_type_
;
}
uint32_t
TensorProto
::
node_id
()
const
{
return
node_id_
;
}
Argument
::
Argument
()
:
has_bits_
(
0
)
{}
void
Argument
::
CopyFrom
(
const
Argument
&
from
)
{
this
->
name_
=
from
.
name
();
this
->
f_
=
from
.
f
();
this
->
i_
=
from
.
i
();
this
->
s_
=
from
.
s
();
auto
floats
=
from
.
floats
();
this
->
floats_
.
resize
(
floats
.
size
());
std
::
copy
(
floats
.
begin
(),
floats
.
end
(),
this
->
floats_
.
begin
());
auto
ints
=
from
.
ints
();
this
->
ints_
.
resize
(
ints
.
size
());
std
::
copy
(
ints
.
begin
(),
ints
.
end
(),
this
->
ints_
.
begin
());
auto
strings
=
from
.
floats
();
this
->
strings_
.
resize
(
strings
.
size
());
std
::
copy
(
floats
.
begin
(),
floats
.
end
(),
this
->
floats_
.
begin
());
this
->
has_bits_
=
from
.
has_bits_
;
}
const
std
::
string
&
Argument
::
name
()
const
{
return
name_
;
}
void
Argument
::
set_name
(
const
std
::
string
&
value
)
{
name_
=
value
;
}
bool
Argument
::
has_f
()
const
{
return
(
has_bits_
&
0x00000001u
)
!=
0
;
}
void
Argument
::
set_has_f
()
{
has_bits_
|=
0x00000001u
;
}
float
Argument
::
f
()
const
{
return
f_
;
}
void
Argument
::
set_f
(
float
value
)
{
set_has_f
();
f_
=
value
;
}
bool
Argument
::
has_i
()
const
{
return
(
has_bits_
&
0x00000002u
)
!=
0
;
}
void
Argument
::
set_has_i
()
{
has_bits_
|=
0x00000002u
;
}
int64_t
Argument
::
i
()
const
{
return
i_
;
}
void
Argument
::
set_i
(
int64_t
value
)
{
set_has_i
();
i_
=
value
;
}
bool
Argument
::
has_s
()
const
{
return
(
has_bits_
&
0x00000004u
)
!=
0
;
}
void
Argument
::
set_has_s
()
{
has_bits_
|=
0x00000004u
;
}
std
::
string
Argument
::
s
()
const
{
return
s_
;
}
void
Argument
::
set_s
(
const
std
::
string
&
value
)
{
set_has_s
();
s_
=
value
;
}
const
std
::
vector
<
float
>
&
Argument
::
floats
()
const
{
return
floats_
;
}
void
Argument
::
add_floats
(
float
value
)
{
floats_
.
push_back
(
value
);
}
void
Argument
::
set_floats
(
const
std
::
vector
<
float
>
&
value
)
{
floats_
.
reserve
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
floats_
.
begin
());
}
const
std
::
vector
<
int64_t
>
&
Argument
::
ints
()
const
{
return
ints_
;
}
void
Argument
::
add_ints
(
int64_t
value
)
{
ints_
.
push_back
(
value
);
}
void
Argument
::
set_ints
(
const
std
::
vector
<
int64_t
>
&
value
)
{
ints_
.
reserve
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
ints_
.
begin
());
}
const
std
::
vector
<
std
::
string
>
&
Argument
::
strings
()
const
{
return
strings_
;
}
void
Argument
::
add_strings
(
const
::
std
::
string
&
value
)
{
strings_
.
push_back
(
value
);
}
void
Argument
::
set_strings
(
const
std
::
vector
<
std
::
string
>
&
value
)
{
strings_
.
reserve
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
strings_
.
begin
());
}
void
OperatorDef
::
CopyFrom
(
const
OperatorDef
&
from
)
{
name_
=
from
.
name
();
type_
=
from
.
type
();
auto
from_input
=
from
.
input
();
input_
.
resize
(
from_input
.
size
());
std
::
copy
(
from_input
.
begin
(),
from_input
.
end
(),
input_
.
begin
());
auto
from_output
=
from
.
output
();
output_
.
resize
(
from_output
.
size
());
std
::
copy
(
from_output
.
begin
(),
from_output
.
end
(),
output_
.
begin
());
auto
from_arg
=
from
.
arg
();
arg_
.
resize
(
from_arg
.
size
());
for
(
int
i
=
0
;
i
<
from_arg
.
size
();
++
i
)
{
arg_
[
i
].
CopyFrom
(
from_arg
[
i
]);
}
auto
from_output_shape
=
from
.
output_shape
();
output_shape_
.
resize
(
from_output_shape
.
size
());
for
(
int
i
=
0
;
i
<
from_output_shape
.
size
();
++
i
)
{
output_shape_
[
i
].
CopyFrom
(
from_output_shape
[
i
]);
}
auto
from_data_type
=
from
.
output_type
();
output_type_
.
resize
(
from_data_type
.
size
());
std
::
copy
(
from_data_type
.
begin
(),
from_data_type
.
end
(),
output_type_
.
begin
());
mem_id_
=
from
.
mem_id
();
// nnlib
node_id_
=
from
.
node_id
();
op_id_
=
from
.
op_id
();
padding_
=
from
.
padding
();
auto
from_node_input
=
from
.
node_input
();
node_input_
.
resize
(
from_node_input
.
size
());
for
(
int
i
=
0
;
i
<
from_node_input
.
size
();
++
i
)
{
node_input_
[
i
].
CopyFrom
(
from_node_input
[
i
]);
}
auto
from_out_max_byte_size
=
from
.
out_max_byte_size
();
out_max_byte_size_
.
resize
(
from_out_max_byte_size
.
size
());
std
::
copy
(
from_out_max_byte_size
.
begin
(),
from_out_max_byte_size
.
end
(),
out_max_byte_size_
.
begin
());
has_bits_
=
from
.
has_bits_
;
}
const
std
::
string
&
OperatorDef
::
name
()
const
{
return
name_
;
}
void
OperatorDef
::
set_name
(
const
std
::
string
&
name_
)
{
set_has_name
();
OperatorDef
::
name_
=
name_
;
}
bool
OperatorDef
::
has_name
()
const
{
return
(
has_bits_
&
0x00000001u
)
!=
0
;
}
void
OperatorDef
::
set_has_name
()
{
has_bits_
|=
0x00000001u
;
}
const
std
::
string
&
OperatorDef
::
type
()
const
{
return
type_
;
}
void
OperatorDef
::
set_type
(
const
std
::
string
&
type_
)
{
set_has_type
();
OperatorDef
::
type_
=
type_
;
}
bool
OperatorDef
::
has_type
()
const
{
return
(
has_bits_
&
0x00000002u
)
!=
0
;
}
void
OperatorDef
::
set_has_type
()
{
has_bits_
|=
0x00000002u
;
}
int
OperatorDef
::
mem_id
()
const
{
return
mem_id_
;
}
void
OperatorDef
::
set_mem_id
(
const
int
mem_id
)
{
set_has_mem_id
();
mem_id_
=
mem_id
;
}
bool
OperatorDef
::
has_mem_id
()
const
{
return
(
has_bits_
&
0x00000004u
)
!=
0
;
}
void
OperatorDef
::
set_has_mem_id
()
{
has_bits_
|=
0x00000004u
;
}
uint32_t
OperatorDef
::
node_id
()
const
{
return
node_id_
;
}
uint32_t
OperatorDef
::
op_id
()
const
{
return
op_id_
;
}
uint32_t
OperatorDef
::
padding
()
const
{
return
padding_
;
}
const
std
::
vector
<
NodeInput
>
&
OperatorDef
::
node_input
()
const
{
return
node_input_
;
}
const
std
::
vector
<
int
>
&
OperatorDef
::
out_max_byte_size
()
const
{
return
out_max_byte_size_
;
}
const
std
::
vector
<
std
::
string
>
&
OperatorDef
::
input
()
const
{
return
input_
;
}
const
std
::
string
&
OperatorDef
::
input
(
int
index
)
const
{
MACE_CHECK
(
0
<=
index
&&
index
<=
input_
.
size
());
return
input_
[
index
];
}
std
::
string
*
OperatorDef
::
add_input
()
{
input_
.
push_back
(
""
);
return
&
input_
.
back
();
}
void
OperatorDef
::
add_input
(
const
::
std
::
string
&
value
)
{
input_
.
push_back
(
value
);
}
void
OperatorDef
::
add_input
(
::
std
::
string
&&
value
)
{
input_
.
push_back
(
value
);
}
void
OperatorDef
::
set_input
(
const
std
::
vector
<
std
::
string
>
&
value
)
{
input_
.
reserve
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
input_
.
begin
());
}
const
std
::
vector
<
std
::
string
>
&
OperatorDef
::
output
()
const
{
return
output_
;
}
const
std
::
string
&
OperatorDef
::
output
(
int
index
)
const
{
MACE_CHECK
(
0
<=
index
&&
index
<=
output_
.
size
());
return
output_
[
index
];
}
std
::
string
*
OperatorDef
::
add_output
()
{
output_
.
push_back
(
""
);
return
&
output_
.
back
();
}
void
OperatorDef
::
add_output
(
const
::
std
::
string
&
value
)
{
output_
.
push_back
(
value
);
}
void
OperatorDef
::
add_output
(
::
std
::
string
&&
value
)
{
output_
.
push_back
(
value
);
}
void
OperatorDef
::
set_output
(
const
std
::
vector
<
std
::
string
>
&
value
)
{
output_
.
reserve
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
output_
.
begin
());
}
const
std
::
vector
<
Argument
>
&
OperatorDef
::
arg
()
const
{
return
arg_
;
}
Argument
*
OperatorDef
::
add_arg
()
{
arg_
.
emplace_back
(
Argument
());
return
&
arg_
.
back
();
}
const
std
::
vector
<
OutputShape
>
&
OperatorDef
::
output_shape
()
const
{
return
output_shape_
;
}
void
OperatorDef
::
set_output_shape
(
const
std
::
vector
<
OutputShape
>
&
value
)
{
output_shape_
.
reserve
(
value
.
size
());
for
(
int
i
=
0
;
i
<
value
.
size
();
++
i
)
{
output_shape_
[
i
].
CopyFrom
(
value
[
i
]);
}
}
const
std
::
vector
<
DataType
>
&
OperatorDef
::
output_type
()
const
{
return
output_type_
;
}
void
OperatorDef
::
set_output_type
(
const
std
::
vector
<
DataType
>
&
value
)
{
output_type_
.
resize
(
value
.
size
());
std
::
copy
(
value
.
begin
(),
value
.
end
(),
output_type_
.
begin
());
}
MemoryBlock
::
MemoryBlock
(
int
mem_id
,
uint32_t
x
,
uint32_t
y
)
:
mem_id_
(
mem_id
),
x_
(
x
),
y_
(
y
)
{}
int
MemoryBlock
::
mem_id
()
const
{
return
mem_id_
;
}
uint32_t
MemoryBlock
::
x
()
const
{
return
x_
;
}
uint32_t
MemoryBlock
::
y
()
const
{
return
y_
;
}
NetDef
::
NetDef
()
:
has_bits_
(
0
)
{}
const
std
::
string
&
NetDef
::
name
()
const
{
return
name_
;
}
void
NetDef
::
set_name
(
const
std
::
string
&
value
)
{
set_has_name
();
name_
=
value
;
}
bool
NetDef
::
has_name
()
const
{
return
(
has_bits_
&
0x00000001u
)
!=
0
;
}
void
NetDef
::
set_has_name
()
{
has_bits_
|=
0x00000001u
;
}
const
std
::
string
&
NetDef
::
version
()
const
{
return
version_
;
}
void
NetDef
::
set_version
(
const
std
::
string
&
value
)
{
set_has_version
();
version_
=
value
;
}
bool
NetDef
::
has_version
()
const
{
return
(
has_bits_
&
0x00000002u
)
!=
0
;
}
void
NetDef
::
set_has_version
()
{
has_bits_
|=
0x00000002u
;
}
const
std
::
vector
<
OperatorDef
>
&
NetDef
::
op
()
const
{
return
op_
;
}
OperatorDef
*
NetDef
::
add_op
()
{
op_
.
emplace_back
(
OperatorDef
());
return
&
op_
.
back
();
}
std
::
vector
<
OperatorDef
>
&
NetDef
::
mutable_op
()
{
return
op_
;
}
const
std
::
vector
<
Argument
>
&
NetDef
::
arg
()
const
{
return
arg_
;
}
Argument
*
NetDef
::
add_arg
()
{
arg_
.
emplace_back
(
Argument
());
return
&
arg_
.
back
();
}
std
::
vector
<
Argument
>
&
NetDef
::
mutable_arg
()
{
return
arg_
;
}
const
std
::
vector
<
TensorProto
>
&
NetDef
::
tensors
()
const
{
return
tensors_
;
}
std
::
vector
<
TensorProto
>
&
NetDef
::
mutable_tensors
()
{
return
tensors_
;
}
const
MemoryArena
&
NetDef
::
mem_arena
()
const
{
return
mem_arena_
;
}
MemoryArena
&
NetDef
::
mutable_mem_arena
()
{
set_has_mem_arena
();
return
mem_arena_
;
}
bool
NetDef
::
has_mem_arena
()
const
{
return
(
has_bits_
&
0x00000004u
)
!=
0
;
}
void
NetDef
::
set_has_mem_arena
()
{
has_bits_
|=
0x00000004u
;
}
const
std
::
vector
<
InputInfo
>
&
NetDef
::
input_info
()
const
{
return
input_info_
;
}
const
std
::
vector
<
OutputInfo
>
&
NetDef
::
output_info
()
const
{
return
output_info_
;
}
int
NetDef
::
op_size
()
const
{
return
op_
.
size
();
}
const
OperatorDef
&
NetDef
::
op
(
const
int
idx
)
const
{
MACE_CHECK
(
0
<=
idx
&&
idx
<
op_size
());
return
op_
[
idx
];
}
}
// namespace mace
mace/core/mace.h
0 → 100644
浏览文件 @
f05649a4
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_MACE_H_
#define MACE_CORE_MACE_H_
#include <cstdint>
#include <vector>
#include <string>
#include "mace/core/logging.h"
namespace
mace
{
enum
NetMode
{
INIT
=
0
,
NORMAL
=
1
};
enum
DeviceType
{
CPU
=
0
,
NEON
=
1
,
OPENCL
=
2
};
enum
DataType
{
DT_INVALID
=
0
,
DT_FLOAT
=
1
,
DT_DOUBLE
=
2
,
DT_INT32
=
3
,
DT_UINT8
=
4
,
DT_INT16
=
5
,
DT_INT8
=
6
,
DT_STRING
=
7
,
DT_INT64
=
8
,
DT_UINT16
=
9
,
DT_BOOL
=
10
,
DT_HALF
=
19
,
DT_UINT32
=
22
};
class
TensorProto
{
public:
TensorProto
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
DataType
data_type
=
DT_FLOAT
,
uint32_t
node_id
=
0
);
TensorProto
(
const
std
::
string
&
name
,
unsigned
char
*
data
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
int
data_type
,
uint32_t
node_id
=
0
);
const
std
::
string
&
name
()
const
;
unsigned
char
*
data
()
const
;
const
int
data_size
()
const
;
const
std
::
vector
<
int64_t
>
&
dims
()
const
;
DataType
data_type
()
const
;
uint32_t
node_id
()
const
;
private:
std
::
string
name_
;
unsigned
char
*
data_
;
int
data_size_
;
std
::
vector
<
int64_t
>
dims_
;
DataType
data_type_
;
uint32_t
node_id_
;
};
class
Argument
{
public:
Argument
();
void
CopyFrom
(
const
Argument
&
from
)
;
public:
const
std
::
string
&
name
()
const
;
void
set_name
(
const
std
::
string
&
value
);
bool
has_f
()
const
;
float
f
()
const
;
void
set_f
(
float
value
)
;
bool
has_i
()
const
;
int64_t
i
()
const
;
void
set_i
(
int64_t
value
);
bool
has_s
()
const
;
std
::
string
s
()
const
;
void
set_s
(
const
std
::
string
&
value
)
;
const
std
::
vector
<
float
>
&
floats
()
const
;
void
add_floats
(
float
value
)
;
void
set_floats
(
const
std
::
vector
<
float
>
&
value
);
const
std
::
vector
<
int64_t
>
&
ints
()
const
;
void
add_ints
(
int64_t
value
)
;
void
set_ints
(
const
std
::
vector
<
int64_t
>
&
value
);
const
std
::
vector
<
std
::
string
>
&
strings
()
const
;
void
add_strings
(
const
::
std
::
string
&
value
)
;
void
set_strings
(
const
std
::
vector
<
std
::
string
>
&
value
);
private:
void
set_has_f
()
;
void
set_has_i
()
;
void
set_has_s
()
;
private:
std
::
string
name_
;
float
f_
;
int64_t
i_
;
std
::
string
s_
;
std
::
vector
<
float
>
floats_
;
std
::
vector
<
int64_t
>
ints_
;
std
::
vector
<
std
::
string
>
strings_
;
uint32_t
has_bits_
;
};
class
NodeInput
{
public:
void
CopyFrom
(
const
NodeInput
&
from
)
{
node_id_
=
from
.
node_id
();
output_port_
=
from
.
output_port
();
}
public:
int
node_id
()
const
{
return
node_id_
;
}
int
output_port
()
const
{
return
output_port_
;
}
private:
int
node_id_
;
int
output_port_
;
};
class
OutputShape
{
public:
void
CopyFrom
(
const
OutputShape
&
from
)
{
auto
from_dims
=
from
.
dims
();
dims_
.
resize
(
from_dims
.
size
());
std
::
copy
(
from_dims
.
begin
(),
from_dims
.
end
(),
dims_
.
begin
());
}
public:
const
std
::
vector
<
int64_t
>
&
dims
()
const
{
return
dims_
;
}
private:
std
::
vector
<
int64_t
>
dims_
;
};
class
OperatorDef
{
public:
void
CopyFrom
(
const
OperatorDef
&
from
);
public:
const
std
::
string
&
name
()
const
;
void
set_name
(
const
std
::
string
&
name_
);
bool
has_name
()
const
;
const
std
::
string
&
type
()
const
;
void
set_type
(
const
std
::
string
&
type_
);
bool
has_type
()
const
;
int
mem_id
()
const
;
void
set_mem_id
(
const
int
mem_id
);
bool
has_mem_id
()
const
;
uint32_t
node_id
()
const
;
uint32_t
op_id
()
const
;
uint32_t
padding
()
const
;
const
std
::
vector
<
NodeInput
>
&
node_input
()
const
;
const
std
::
vector
<
int
>
&
out_max_byte_size
()
const
;
const
std
::
vector
<
std
::
string
>
&
input
()
const
;
const
std
::
string
&
input
(
int
index
)
const
;
std
::
string
*
add_input
();
void
add_input
(
const
::
std
::
string
&
value
);
void
add_input
(
::
std
::
string
&&
value
);
void
set_input
(
const
std
::
vector
<
std
::
string
>
&
value
);
const
std
::
vector
<
std
::
string
>
&
output
()
const
;
const
std
::
string
&
output
(
int
index
)
const
;
std
::
string
*
add_output
();
void
add_output
(
const
::
std
::
string
&
value
);
void
add_output
(
::
std
::
string
&&
value
);
void
set_output
(
const
std
::
vector
<
std
::
string
>
&
value
);
const
std
::
vector
<
Argument
>
&
arg
()
const
;
Argument
*
add_arg
();
const
std
::
vector
<
OutputShape
>
&
output_shape
()
const
;
void
set_output_shape
(
const
std
::
vector
<
OutputShape
>
&
value
);
const
std
::
vector
<
DataType
>
&
output_type
()
const
;
void
set_output_type
(
const
std
::
vector
<
DataType
>
&
value
);
private:
void
set_has_name
();
void
set_has_type
();
void
set_has_mem_id
();
private:
std
::
string
name_
;
std
::
string
type_
;
std
::
vector
<
std
::
string
>
input_
;
std
::
vector
<
std
::
string
>
output_
;
std
::
vector
<
Argument
>
arg_
;
std
::
vector
<
OutputShape
>
output_shape_
;
std
::
vector
<
DataType
>
output_type_
;
int
mem_id_
;
// nnlib
uint32_t
node_id_
;
uint32_t
op_id_
;
uint32_t
padding_
;
std
::
vector
<
NodeInput
>
node_input_
;
std
::
vector
<
int
>
out_max_byte_size_
;
uint32_t
has_bits_
;
};
class
MemoryBlock
{
public:
MemoryBlock
(
int
mem_id
,
uint32_t
x
,
uint32_t
y
);
public:
int
mem_id
()
const
;
uint32_t
x
()
const
;
uint32_t
y
()
const
;
private:
int
mem_id_
;
uint32_t
x_
;
uint32_t
y_
;
};
class
MemoryArena
{
public:
inline
const
std
::
vector
<
MemoryBlock
>
&
mem_block
()
const
{
return
mem_block_
;
}
inline
std
::
vector
<
MemoryBlock
>
&
mutable_mem_block
()
{
return
mem_block_
;
}
inline
int
mem_block_size
()
const
{
return
mem_block_
.
size
();
}
private:
std
::
vector
<
MemoryBlock
>
mem_block_
;
};
// for hexagon mace-nnlib
class
InputInfo
{
public:
const
std
::
string
&
name
()
const
{
return
name_
;
}
int32_t
node_id
()
const
{
return
node_id_
;
}
int32_t
max_byte_size
()
const
{
return
max_byte_size_
;
}
DataType
data_type
()
const
{
return
data_type_
;
}
const
std
::
vector
<
int32_t
>
&
dims
()
const
{
return
dims_
;
}
private:
std
::
string
name_
;
int32_t
node_id_
;
int32_t
max_byte_size_
;
// only support 32-bit len
DataType
data_type_
;
std
::
vector
<
int32_t
>
dims_
;
};
class
OutputInfo
{
public:
const
std
::
string
&
name
()
const
{
return
name_
;
}
int32_t
node_id
()
const
{
return
node_id_
;
}
int32_t
max_byte_size
()
const
{
return
max_byte_size_
;
}
DataType
data_type
()
const
{
return
data_type_
;
}
const
std
::
vector
<
int32_t
>
&
dims
()
const
{
return
dims_
;
}
private:
std
::
string
name_
;
int32_t
node_id_
;
int32_t
max_byte_size_
;
// only support 32-bit len
DataType
data_type_
;
std
::
vector
<
int32_t
>
dims_
;
};
class
NetDef
{
public:
NetDef
();
int
op_size
()
const
;
const
OperatorDef
&
op
(
const
int
idx
)
const
;
public:
const
std
::
string
&
name
()
const
;
bool
has_name
()
const
;
void
set_name
(
const
std
::
string
&
value
);
const
std
::
string
&
version
()
const
;
bool
has_version
()
const
;
void
set_version
(
const
std
::
string
&
value
);
const
std
::
vector
<
OperatorDef
>
&
op
()
const
;
OperatorDef
*
add_op
();
std
::
vector
<
OperatorDef
>
&
mutable_op
();
const
std
::
vector
<
Argument
>
&
arg
()
const
;
Argument
*
add_arg
();
std
::
vector
<
Argument
>
&
mutable_arg
();
const
std
::
vector
<
TensorProto
>
&
tensors
()
const
;
std
::
vector
<
TensorProto
>
&
mutable_tensors
();
const
MemoryArena
&
mem_arena
()
const
;
bool
has_mem_arena
()
const
;
MemoryArena
&
mutable_mem_arena
();
const
std
::
vector
<
InputInfo
>
&
input_info
()
const
;
const
std
::
vector
<
OutputInfo
>
&
output_info
()
const
;
private:
void
set_has_name
();
void
set_has_version
();
void
set_has_mem_arena
();
private:
std
::
string
name_
;
std
::
string
version_
;
std
::
vector
<
OperatorDef
>
op_
;
std
::
vector
<
Argument
>
arg_
;
std
::
vector
<
TensorProto
>
tensors_
;
// for mem optimization
MemoryArena
mem_arena_
;
// for hexagon mace-nnlib
std
::
vector
<
InputInfo
>
input_info_
;
std
::
vector
<
OutputInfo
>
output_info_
;
uint32_t
has_bits_
;
};
}
// namespace mace
#endif // MACE_CORE_MACE_H_
mace/core/net.cc
浏览文件 @
f05649a4
...
...
@@ -50,7 +50,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
}
}
if
(
!
op
->
Run
())
{
LOG
(
ERROR
)
<<
"Operator failed: "
<<
ProtoDebugString
(
op
->
debug_def
()
);
LOG
(
ERROR
)
<<
"Operator failed: "
<<
op
->
debug_def
().
name
(
);
return
false
;
}
...
...
mace/core/net.h
浏览文件 @
f05649a4
...
...
@@ -8,7 +8,7 @@
#include "mace/core/common.h"
#include "mace/core/operator.h"
#include "mace/core/workspace.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
#include "mace/proto/stats.pb.h"
namespace
mace
{
...
...
mace/core/operator.h
浏览文件 @
f05649a4
...
...
@@ -10,7 +10,7 @@
#include "mace/core/registry.h"
#include "mace/core/tensor.h"
#include "mace/core/workspace.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
...
...
mace/core/proto_utils.cc
浏览文件 @
f05649a4
...
...
@@ -4,163 +4,12 @@
#include "mace/core/proto_utils.h"
#include <fcntl.h>
#include <unistd.h>
#include <cerrno>
#include <fstream>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#endif // !MACE_USE_LITE_PROTO
namespace
mace
{
bool
ReadStringFromFile
(
const
char
*
filename
,
string
*
str
)
{
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
in
);
if
(
!
ifs
)
{
VLOG
(
1
)
<<
"File cannot be opened: "
<<
filename
<<
" error: "
<<
ifs
.
rdstate
();
return
false
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
size_t
n
=
ifs
.
tellg
();
str
->
resize
(
n
);
ifs
.
seekg
(
0
);
ifs
.
read
(
&
(
*
str
)[
0
],
n
);
return
true
;
}
bool
WriteStringToFile
(
const
string
&
str
,
const
char
*
filename
)
{
std
::
ofstream
ofs
(
filename
,
std
::
ios
::
out
|
std
::
ios
::
trunc
);
if
(
!
ofs
.
is_open
())
{
VLOG
(
1
)
<<
"File cannot be created: "
<<
filename
<<
" error: "
<<
ofs
.
rdstate
();
return
false
;
}
ofs
<<
str
;
return
true
;
}
// IO-specific proto functions: we will deal with the protocol buffer lite and
// full versions differently.
#ifdef MACE_USE_LITE_PROTO
// Lite runtime.
namespace
{
class
IfstreamInputStream
:
public
::
google
::
protobuf
::
io
::
CopyingInputStream
{
public:
explicit
IfstreamInputStream
(
const
string
&
filename
)
:
ifs_
(
filename
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
)
{}
~
IfstreamInputStream
()
{
ifs_
.
close
();
}
int
Read
(
void
*
buffer
,
int
size
)
{
if
(
!
ifs_
)
{
return
-
1
;
}
ifs_
.
read
(
static_cast
<
char
*>
(
buffer
),
size
);
return
ifs_
.
gcount
();
}
private:
std
::
ifstream
ifs_
;
};
}
// namespace
bool
ReadProtoFromBinaryFile
(
const
char
*
filename
,
MessageLite
*
proto
)
{
::
google
::
protobuf
::
io
::
CopyingInputStreamAdaptor
stream
(
new
IfstreamInputStream
(
filename
));
stream
.
SetOwnsCopyingStream
(
true
);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
// respectively.
::
google
::
protobuf
::
io
::
CodedInputStream
coded_stream
(
&
stream
);
coded_stream
.
SetTotalBytesLimit
(
1024LL
<<
20
,
512LL
<<
20
);
return
proto
->
ParseFromCodedStream
(
&
coded_stream
);
}
void
WriteProtoToBinaryFile
(
const
MessageLite
&
/*proto*/
,
const
char
*
/*filename*/
)
{
LOG
(
FATAL
)
<<
"Not implemented yet."
;
}
#else // MACE_USE_LITE_PROTO
// Full protocol buffer.
using
::
google
::
protobuf
::
io
::
FileInputStream
;
using
::
google
::
protobuf
::
io
::
FileOutputStream
;
using
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
;
using
::
google
::
protobuf
::
io
::
CodedInputStream
;
using
::
google
::
protobuf
::
io
::
ZeroCopyOutputStream
;
using
::
google
::
protobuf
::
io
::
CodedOutputStream
;
bool
ReadProtoFromTextFile
(
const
char
*
filename
,
Message
*
proto
)
{
int
fd
=
open
(
filename
,
O_RDONLY
);
MACE_CHECK
(
fd
!=
-
1
,
"File not found: "
,
filename
);
FileInputStream
*
input
=
new
FileInputStream
(
fd
);
bool
success
=
google
::
protobuf
::
TextFormat
::
Parse
(
input
,
proto
);
delete
input
;
close
(
fd
);
return
success
;
}
void
WriteProtoToTextFile
(
const
Message
&
proto
,
const
char
*
filename
)
{
int
fd
=
open
(
filename
,
O_WRONLY
|
O_CREAT
|
O_TRUNC
,
0644
);
FileOutputStream
*
output
=
new
FileOutputStream
(
fd
);
MACE_CHECK
(
google
::
protobuf
::
TextFormat
::
Print
(
proto
,
output
));
delete
output
;
close
(
fd
);
}
bool
ReadProtoFromBinaryFile
(
const
char
*
filename
,
MessageLite
*
proto
)
{
#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
int
fd
=
open
(
filename
,
O_RDONLY
|
O_BINARY
);
#else
int
fd
=
open
(
filename
,
O_RDONLY
);
#endif
MACE_CHECK
(
fd
!=
-
1
,
"File not found: "
,
filename
);
std
::
unique_ptr
<
ZeroCopyInputStream
>
raw_input
(
new
FileInputStream
(
fd
));
std
::
unique_ptr
<
CodedInputStream
>
coded_input
(
new
CodedInputStream
(
raw_input
.
get
()));
// A hack to manually allow using very large protocol buffers.
coded_input
->
SetTotalBytesLimit
(
1073741824
,
536870912
);
bool
success
=
proto
->
ParseFromCodedStream
(
coded_input
.
get
());
coded_input
.
reset
();
raw_input
.
reset
();
close
(
fd
);
return
success
;
}
void
WriteProtoToBinaryFile
(
const
MessageLite
&
proto
,
const
char
*
filename
)
{
int
fd
=
open
(
filename
,
O_WRONLY
|
O_CREAT
|
O_TRUNC
,
0644
);
MACE_CHECK
(
fd
!=
-
1
,
"File cannot be created: "
,
filename
,
" error number: "
,
errno
);
std
::
unique_ptr
<
ZeroCopyOutputStream
>
raw_output
(
new
FileOutputStream
(
fd
));
std
::
unique_ptr
<
CodedOutputStream
>
coded_output
(
new
CodedOutputStream
(
raw_output
.
get
()));
MACE_CHECK
(
proto
.
SerializeToCodedStream
(
coded_output
.
get
()));
coded_output
.
reset
();
raw_output
.
reset
();
close
(
fd
);
}
#endif // MACE_USE_LITE_PROTO
ArgumentHelper
::
ArgumentHelper
(
const
OperatorDef
&
def
)
{
for
(
auto
&
arg
:
def
.
arg
())
{
if
(
arg_map_
.
find
(
arg
.
name
())
!=
arg_map_
.
end
())
{
MACE_CHECK
(
arg
.
SerializeAsString
()
==
arg_map_
[
arg
.
name
()].
SerializeAsString
(),
"Found argument of the same name '"
,
arg
.
name
(),
"' but with different contents: "
,
ProtoDebugString
(
def
));
LOG
(
WARNING
)
<<
"Duplicated argument name found in operator def: "
<<
ProtoDebugString
(
def
)
<<
", arg: "
<<
ProtoDebugString
(
arg
);
LOG
(
WARNING
)
<<
"Duplicated argument name found in operator def."
;
}
arg_map_
[
arg
.
name
()]
=
arg
;
...
...
@@ -170,8 +19,7 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
ArgumentHelper
::
ArgumentHelper
(
const
NetDef
&
netdef
)
{
for
(
auto
&
arg
:
netdef
.
arg
())
{
MACE_CHECK
(
arg_map_
.
count
(
arg
.
name
())
==
0
,
"Duplicated argument name found in net def: "
,
ProtoDebugString
(
netdef
));
"Duplicated argument name found in net def."
);
arg_map_
[
arg
.
name
()]
=
arg
;
}
}
...
...
@@ -265,88 +113,4 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT
(
string
,
strings
,
false
)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string &name, const T &value) { \
Argument arg; \
arg.set_name(name); \
arg.set_##fieldname(value); \
return arg; \
}
MACE_MAKE_SINGULAR_ARGUMENT
(
bool
,
i
)
MACE_MAKE_SINGULAR_ARGUMENT
(
float
,
f
)
MACE_MAKE_SINGULAR_ARGUMENT
(
int
,
i
)
MACE_MAKE_SINGULAR_ARGUMENT
(
int64_t
,
i
)
MACE_MAKE_SINGULAR_ARGUMENT
(
string
,
s
)
#undef MACE_MAKE_SINGULAR_ARGUMENT
template
<
>
Argument
MakeArgument
(
const
string
&
name
,
const
MessageLite
&
value
)
{
Argument
arg
;
arg
.
set_name
(
name
);
arg
.
set_s
(
value
.
SerializeAsString
());
return
arg
;
}
#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string &name, const vector<T> &value) { \
Argument arg; \
arg.set_name(name); \
for (const auto &v : value) { \
arg.add_##fieldname(v); \
} \
return arg; \
}
MACE_MAKE_REPEATED_ARGUMENT
(
float
,
floats
)
MACE_MAKE_REPEATED_ARGUMENT
(
int
,
ints
)
MACE_MAKE_REPEATED_ARGUMENT
(
int64_t
,
ints
)
MACE_MAKE_REPEATED_ARGUMENT
(
string
,
strings
)
#undef MACE_MAKE_REPEATED_ARGUMENT
const
Argument
&
GetArgument
(
const
OperatorDef
&
def
,
const
string
&
name
)
{
for
(
const
Argument
&
arg
:
def
.
arg
())
{
if
(
arg
.
name
()
==
name
)
{
return
arg
;
}
}
MACE_CHECK
(
false
,
"Argument named "
,
name
,
"does not exist in operator "
,
ProtoDebugString
(
def
));
// should not reach here, just make compiler happy
return
std
::
move
(
Argument
());
}
bool
GetFlagArgument
(
const
OperatorDef
&
def
,
const
string
&
name
,
bool
def_value
)
{
for
(
const
Argument
&
arg
:
def
.
arg
())
{
if
(
arg
.
name
()
==
name
)
{
MACE_CHECK
(
arg
.
has_i
(),
"Can't parse argument as bool: "
,
ProtoDebugString
(
arg
));
return
arg
.
i
();
}
}
return
def_value
;
}
Argument
*
GetMutableArgument
(
const
string
&
name
,
const
bool
create_if_missing
,
OperatorDef
*
def
)
{
for
(
int
i
=
0
;
i
<
def
->
arg_size
();
++
i
)
{
if
(
def
->
arg
(
i
).
name
()
==
name
)
{
return
def
->
mutable_arg
(
i
);
}
}
// If no argument of the right name is found...
if
(
create_if_missing
)
{
Argument
*
arg
=
def
->
add_arg
();
arg
->
set_name
(
name
);
return
arg
;
}
else
{
return
nullptr
;
}
}
}
// namespace mace
mace/core/proto_utils.h
浏览文件 @
f05649a4
...
...
@@ -7,137 +7,12 @@
#include <map>
#include "google/protobuf/message_lite.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/message.h"
#endif // !MACE_USE_LITE_PROTO
#include "mace/core/common.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
using
std
::
string
;
using
::
google
::
protobuf
::
MessageLite
;
// Common interfaces that reads file contents into a string.
bool
ReadStringFromFile
(
const
char
*
filename
,
string
*
str
);
bool
WriteStringToFile
(
const
string
&
str
,
const
char
*
filename
);
// Common interfaces that are supported by both lite and full protobuf.
bool
ReadProtoFromBinaryFile
(
const
char
*
filename
,
MessageLite
*
proto
);
inline
bool
ReadProtoFromBinaryFile
(
const
string
filename
,
MessageLite
*
proto
)
{
return
ReadProtoFromBinaryFile
(
filename
.
c_str
(),
proto
);
}
void
WriteProtoToBinaryFile
(
const
MessageLite
&
proto
,
const
char
*
filename
);
inline
void
WriteProtoToBinaryFile
(
const
MessageLite
&
proto
,
const
string
&
filename
)
{
return
WriteProtoToBinaryFile
(
proto
,
filename
.
c_str
());
}
#ifdef MACE_USE_LITE_PROTO
inline
string
ProtoDebugString
(
const
MessageLite
&
proto
)
{
return
proto
.
SerializeAsString
();
}
// Text format MessageLite wrappers: these functions do nothing but just
// allowing things to compile. It will produce a runtime error if you are using
// MessageLite but still want text support.
inline
bool
ReadProtoFromTextFile
(
const
char
*
/*filename*/
,
MessageLite
*
/*proto*/
)
{
LOG
(
FATAL
)
<<
"If you are running lite version, you should not be "
<<
"calling any text-format protobuffers."
;
return
false
;
// Just to suppress compiler warning.
}
inline
bool
ReadProtoFromTextFile
(
const
string
filename
,
MessageLite
*
proto
)
{
return
ReadProtoFromTextFile
(
filename
.
c_str
(),
proto
);
}
inline
void
WriteProtoToTextFile
(
const
MessageLite
&
/*proto*/
,
const
char
*
/*filename*/
)
{
LOG
(
FATAL
)
<<
"If you are running lite version, you should not be "
<<
"calling any text-format protobuffers."
;
}
inline
void
WriteProtoToTextFile
(
const
MessageLite
&
proto
,
const
string
&
filename
)
{
return
WriteProtoToTextFile
(
proto
,
filename
.
c_str
());
}
inline
bool
ReadProtoFromFile
(
const
char
*
filename
,
MessageLite
*
proto
)
{
return
(
ReadProtoFromBinaryFile
(
filename
,
proto
)
||
ReadProtoFromTextFile
(
filename
,
proto
));
}
inline
bool
ReadProtoFromFile
(
const
string
&
filename
,
MessageLite
*
proto
)
{
return
ReadProtoFromFile
(
filename
.
c_str
(),
proto
);
}
#else // MACE_USE_LITE_PROTO
using
::
google
::
protobuf
::
Message
;
inline
string
ProtoDebugString
(
const
Message
&
proto
)
{
return
proto
.
ShortDebugString
();
}
bool
ReadProtoFromTextFile
(
const
char
*
filename
,
Message
*
proto
);
inline
bool
ReadProtoFromTextFile
(
const
string
filename
,
Message
*
proto
)
{
return
ReadProtoFromTextFile
(
filename
.
c_str
(),
proto
);
}
void
WriteProtoToTextFile
(
const
Message
&
proto
,
const
char
*
filename
);
inline
void
WriteProtoToTextFile
(
const
Message
&
proto
,
const
string
&
filename
)
{
return
WriteProtoToTextFile
(
proto
,
filename
.
c_str
());
}
// Read Proto from a file, letting the code figure out if it is text or binary.
inline
bool
ReadProtoFromFile
(
const
char
*
filename
,
Message
*
proto
)
{
return
(
ReadProtoFromBinaryFile
(
filename
,
proto
)
||
ReadProtoFromTextFile
(
filename
,
proto
));
}
inline
bool
ReadProtoFromFile
(
const
string
&
filename
,
Message
*
proto
)
{
return
ReadProtoFromFile
(
filename
.
c_str
(),
proto
);
}
#endif // MACE_USE_LITE_PROTO
template
<
class
IterableInputs
=
std
::
initializer_list
<
string
>,
class
IterableOutputs
=
std
::
initializer_list
<
string
>
,
class
IterableArgs
=
std
::
initializer_list
<
Argument
>>
OperatorDef
CreateOperatorDef
(
const
string
&
type
,
const
string
&
name
,
const
IterableInputs
&
inputs
,
const
IterableOutputs
&
outputs
,
const
IterableArgs
&
args
)
{
OperatorDef
def
;
def
.
set_type
(
type
);
def
.
set_name
(
name
);
for
(
const
string
&
in
:
inputs
)
{
def
.
add_input
(
in
);
}
for
(
const
string
&
out
:
outputs
)
{
def
.
add_output
(
out
);
}
for
(
const
Argument
&
arg
:
args
)
{
def
.
add_arg
()
->
CopyFrom
(
arg
);
}
return
def
;
}
// A simplified version compared to the full CreateOperator, if you do not need
// to specify args.
template
<
class
IterableInputs
=
std
::
initializer_list
<
string
>,
class
IterableOutputs
=
std
::
initializer_list
<
string
>>
inline
OperatorDef
CreateOperatorDef
(
const
string
&
type
,
const
string
&
name
,
const
IterableInputs
&
inputs
,
const
IterableOutputs
&
outputs
)
{
return
CreateOperatorDef
(
type
,
name
,
inputs
,
outputs
,
std
::
vector
<
Argument
>
());
}
/**
* @brief A helper class to index into arguments.
...
...
@@ -174,17 +49,6 @@ class ArgumentHelper {
return
ArgumentHelper
(
def
).
GetRepeatedArgument
<
T
>
(
name
,
default_value
);
}
template
<
typename
Def
,
typename
MessageType
>
static
MessageType
GetMessageArgument
(
const
Def
&
def
,
const
string
&
name
)
{
return
ArgumentHelper
(
def
).
GetMessageArgument
<
MessageType
>
(
name
);
}
template
<
typename
Def
,
typename
MessageType
>
static
vector
<
MessageType
>
GetRepeatedMessageArgument
(
const
Def
&
def
,
const
string
&
name
)
{
return
ArgumentHelper
(
def
).
GetRepeatedMessageArgument
<
MessageType
>
(
name
);
}
explicit
ArgumentHelper
(
const
OperatorDef
&
def
);
explicit
ArgumentHelper
(
const
NetDef
&
netdef
);
bool
HasArgument
(
const
string
&
name
)
const
;
...
...
@@ -198,51 +62,10 @@ class ArgumentHelper {
const
string
&
name
,
const
std
::
vector
<
T
>
&
default_value
=
std
::
vector
<
T
>
())
const
;
template
<
typename
MessageType
>
MessageType
GetMessageArgument
(
const
string
&
name
)
const
{
MACE_CHECK
(
arg_map_
.
count
(
name
),
"Cannot find parameter named "
+
name
);
MessageType
message
;
if
(
arg_map_
.
at
(
name
).
has_s
())
{
MACE_CHECK
(
message
.
ParseFromString
(
arg_map_
.
at
(
name
).
s
()),
"Faild to parse content from the string"
);
}
else
{
VLOG
(
1
)
<<
"Return empty message for parameter "
<<
name
;
}
return
message
;
}
template
<
typename
MessageType
>
vector
<
MessageType
>
GetRepeatedMessageArgument
(
const
string
&
name
)
const
{
MACE_CHECK
(
arg_map_
.
count
(
name
),
"Cannot find parameter named "
+
name
);
vector
<
MessageType
>
messages
(
arg_map_
.
at
(
name
).
strings_size
());
for
(
int
i
=
0
;
i
<
messages
.
size
();
++
i
)
{
MACE_CHECK
(
messages
[
i
].
ParseFromString
(
arg_map_
.
at
(
name
).
strings
(
i
)),
"Faild to parse content from the string"
);
}
return
messages
;
}
private:
std
::
map
<
string
,
Argument
>
arg_map_
;
};
const
Argument
&
GetArgument
(
const
OperatorDef
&
def
,
const
string
&
name
);
bool
GetFlagArgument
(
const
OperatorDef
&
def
,
const
string
&
name
,
bool
def_value
=
false
);
Argument
*
GetMutableArgument
(
const
string
&
name
,
const
bool
create_if_missing
,
OperatorDef
*
def
);
template
<
typename
T
>
Argument
MakeArgument
(
const
string
&
name
,
const
T
&
value
);
template
<
typename
T
>
inline
void
AddArgument
(
const
string
&
name
,
const
T
&
value
,
OperatorDef
*
def
)
{
GetMutableArgument
(
name
,
true
,
def
)
->
CopyFrom
(
MakeArgument
(
name
,
value
));
}
}
// namespace mace
#endif // MACE_CORE_PROTO_UTILS_H_
mace/core/serializer.cc
浏览文件 @
f05649a4
...
...
@@ -24,46 +24,32 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
switch
(
proto
.
data_type
())
{
case
DT_FLOAT
:
tensor
->
Copy
<
float
>
(
proto
.
float_data
().
data
(),
proto
.
float_data
().
size
());
tensor
->
Copy
<
float
>
(
reinterpret_cast
<
float
*>
(
proto
.
data
()),
proto
.
data_
size
());
break
;
case
DT_DOUBLE
:
tensor
->
Copy
<
double
>
(
proto
.
double_data
().
data
(),
proto
.
double_data
().
size
());
tensor
->
Copy
<
double
>
(
reinterpret_cast
<
double
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT32
:
tensor
->
template
Copy
<
int32_t
>(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
tensor
->
Copy
<
int32_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT64
:
tensor
->
Copy
<
int64_t
>
(
reinterpret_cast
<
int64_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_UINT8
:
tensor
->
CopyWithCast
<
int32_t
,
uint8_t
>
(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
tensor
->
CopyWithCast
<
int32_t
,
uint8_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT16
:
tensor
->
CopyWithCast
<
int32_t
,
int16_t
>
(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
tensor
->
CopyWithCast
<
int32_t
,
uint16_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_INT8
:
tensor
->
CopyWithCast
<
int32_t
,
int8_t
>
(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
break
;
case
DT_INT64
:
tensor
->
Copy
<
int64_t
>
(
proto
.
int64_data
().
data
(),
proto
.
int64_data
().
size
());
tensor
->
CopyWithCast
<
int32_t
,
int8_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_UINT16
:
tensor
->
CopyWithCast
<
int32_t
,
uint16_t
>
(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
tensor
->
CopyWithCast
<
int32_t
,
int16_t
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_BOOL
:
tensor
->
CopyWithCast
<
int32_t
,
bool
>
(
proto
.
int32_data
().
data
(),
proto
.
int32_data
().
size
());
tensor
->
CopyWithCast
<
int32_t
,
bool
>
(
reinterpret_cast
<
int32_t
*>
(
proto
.
data
()),
proto
.
data_size
());
break
;
case
DT_STRING
:
{
string
*
content
=
tensor
->
mutable_data
<
string
>
();
for
(
int
i
=
0
;
i
<
proto
.
string_data
().
size
();
++
i
)
{
content
[
i
]
=
proto
.
string_data
(
i
);
}
}
break
;
default:
MACE_NOT_IMPLEMENTED
;
break
;
...
...
mace/core/serializer.h
浏览文件 @
f05649a4
...
...
@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
...
...
mace/core/tensor.h
浏览文件 @
f05649a4
...
...
@@ -9,7 +9,7 @@
#include "mace/core/common.h"
#include "mace/core/logging.h"
#include "mace/core/types.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
...
...
mace/core/types.h
浏览文件 @
f05649a4
...
...
@@ -6,7 +6,7 @@
#define MACE_CORE_TYPES_H_
#include "mace/core/common.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
#include "mace/core/half.h"
...
...
mace/core/workspace.cc
浏览文件 @
f05649a4
...
...
@@ -69,7 +69,7 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
}
void
Workspace
::
CreateImageOutputTensor
(
const
NetDef
&
net_def
)
{
if
(
!
net_def
.
has_mem_arena
()
||
net_def
.
mem_arena
().
mem_block_size
()
==
0
)
{
if
(
net_def
.
has_mem_arena
()
||
net_def
.
mem_arena
().
mem_block_size
()
==
0
)
{
return
;
}
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
mem_tensor_map
;
...
...
mace/core/workspace.h
浏览文件 @
f05649a4
...
...
@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
...
...
mace/examples/mace_run.cc
浏览文件 @
f05649a4
...
...
@@ -20,6 +20,8 @@
using
namespace
std
;
using
namespace
mace
;
extern
NetDef
CreateNet
()
;
void
ParseShape
(
const
string
&
str
,
vector
<
index_t
>
*
shape
)
{
string
tmp
=
str
;
while
(
!
tmp
.
empty
())
{
...
...
@@ -34,6 +36,18 @@ void ParseShape(const string &str, vector<index_t> *shape) {
}
}
DeviceType
ParseDeviceType
(
const
string
&
device_str
)
{
if
(
device_str
.
compare
(
"CPU"
)
==
0
)
{
return
DeviceType
::
CPU
;
}
else
if
(
device_str
.
compare
(
"NEON"
)
==
0
)
{
return
DeviceType
::
NEON
;
}
else
if
(
device_str
.
compare
(
"OPENCL"
)
==
0
)
{
return
DeviceType
::
OPENCL
;
}
else
{
return
DeviceType
::
CPU
;
}
}
int
main
(
int
argc
,
char
**
argv
)
{
string
model_file
;
string
input_node
;
...
...
@@ -76,13 +90,13 @@ int main(int argc, char **argv) {
ParseShape
(
input_shape
,
&
shape
);
// load model
ifstream
file_stream
(
model_file
,
ios
::
in
|
ios
::
binary
);
NetDef
net_def
;
net_def
.
ParseFromIstream
(
&
file_stream
);
file_stream
.
close
();
// ifstream file_stream(model_file, ios::in | ios::binary);
// NetDef net_def;
// net_def.ParseFromIstream(&file_stream);
// file_stream.close();
NetDef
net_def
=
CreateNet
();
DeviceType
device_type
;
DeviceType_Parse
(
device
,
&
device_type
);
DeviceType
device_type
=
ParseDeviceType
(
device
);
VLOG
(
0
)
<<
device_type
;
Workspace
ws
;
ws
.
LoadModelTensor
(
net_def
,
device_type
);
...
...
mace/kernels/batch_norm.h
浏览文件 @
f05649a4
...
...
@@ -6,7 +6,7 @@
#define MACE_KERNELS_BATCH_NORM_H_
#include "mace/core/tensor.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
namespace
kernels
{
...
...
mace/kernels/bias_add.h
浏览文件 @
f05649a4
...
...
@@ -6,7 +6,7 @@
#define MACE_KERNELS_BIAS_ADD_H_
#include "mace/core/tensor.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
namespace
kernels
{
...
...
mace/kernels/concat.h
浏览文件 @
f05649a4
...
...
@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/types.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
#include "mace/core/tensor.h"
namespace
mace
{
...
...
mace/kernels/depthwise_conv2d.h
浏览文件 @
f05649a4
...
...
@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
namespace
kernels
{
...
...
mace/kernels/space_to_batch.h
浏览文件 @
f05649a4
...
...
@@ -6,7 +6,7 @@
#define MACE_KERNELS_CONV_2D_H_
#include "mace/core/tensor.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
namespace
kernels
{
...
...
mace/ops/BUILD
浏览文件 @
f05649a4
...
...
@@ -40,7 +40,6 @@ cc_library(
],
deps
=
[
"//mace/kernels"
,
"//mace/proto:cc_proto"
,
],
alwayslink
=
1
,
)
...
...
mace/ops/concat.h
浏览文件 @
f05649a4
...
...
@@ -7,7 +7,7 @@
#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
#include "mace/
proto/mace.pb
.h"
#include "mace/
core/mace
.h"
namespace
mace
{
template
<
DeviceType
D
,
typename
T
>
...
...
mace/ops/ops_test_util.h
浏览文件 @
f05649a4
...
...
@@ -159,7 +159,6 @@ class OpsTestNet {
for
(
auto
&
op_def_
:
op_defs_
)
{
net_def
.
add_op
()
->
CopyFrom
(
op_def_
);
}
VLOG
(
3
)
<<
net_def
.
DebugString
();
net_
=
CreateNet
(
net_def
,
&
ws_
,
device
);
device_
=
device
;
return
net_
->
Run
();
...
...
mace/python/tools/model.template
0 → 100644
浏览文件 @
f05649a4
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <vector>
#include <string>
#include "mace/core/mace.h"
namespace mace {
{% for tensor in tensors %}
static unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = {
{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%}
};
{% endfor %}
static void CreateNetArg(NetDef &net_def) {
net_def.mutable_arg().reserve({{ net.arg|length }});
Argument *arg = nullptr;
{% for arg in net.arg %}
arg = net_def.add_arg();
arg->set_name({{ arg.name|tojson }});
{% if arg.has_f %}
arg->set_f({{ arg.f }});
{% endif %}
{% if arg.has_i %}
arg->set_i({{ arg.i }});
{% endif %}
{% if arg.has_s %}
arg->set_s({{ arg.s|tojson }});
{% endif %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
arg->set_ints({ {{ arg.ints|join(', ') }} });
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endfor %}
}
static void UpdateOp(OperatorDef &op,
const std::string &name,
const std::string &type,
const int mem_id,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const std::vector<OutputShape> &output_shapes,
const std::vector<DataType> &output_types) {
op.set_name(name);
op.set_type(type);
op.set_input(inputs);
op.set_output(outputs);
op.set_mem_id(mem_id);
op.set_output_shape(output_shapes);
op.set_output_type(output_types);
}
static void CreateOperators(std::vector<OperatorDef> &ops) {
ops.resize({{ net.op|length }});
Argument *arg = nullptr;
{% for i in range(net.op|length) %}
{% for arg in net.op[i].arg %}
arg = ops[{{i}}].add_arg();
arg->set_name({{ arg.name|tojson }});
{%- if arg.HasField('f') %}
arg->set_f({{ arg.f }});
{%- endif %}
{%- if arg.HasField('i') %}
arg->set_i({{ arg.i }});
{%- endif %}
{%- if arg.HasField('s') %}
arg->set_s({{ arg.s|tojson }});
{%- endif %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
arg->set_ints({ {{ arg.ints|join(', ') }} });
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endfor %}
UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }},
{ {{ net.op[i].input|stringfy }} },
{ {{ net.op[i].output|stringfy }} },
{ {{ net.op[i].output_shape.dims|join(', ') }} },
{ {{ net.op[i].output_type|join(', ') }} });
{% endfor %}
}
static void CreateTensors(std::vector<TensorProto> &tensors) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
tensors.emplace_back(TensorProto(
{{ tensor.name|tojson }}, {{ "_" + tensor.name[:-2].replace("/", "_") }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }},
{{ tensor.node_id }}
));
{% endfor %}
}
static void CreateMemoryArena(MemoryArena &mem_arena) {
auto mem_block = mem_arena.mutable_mem_block();
mem_block.reserve({{ net.mem_arena.mem_block|length }});
{% for mem_blk in net.mem_arena.mem_block %}
mem_block.emplace_back(MemoryBlock({{ mem_blk.mem_id }},
{{mem_blk.x}},
{{mem_blk.y}}));
{% endfor %}
}
NetDef CreateNet() {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
CreateNetArg(net_def);
CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors());
CreateMemoryArena(net_def.mutable_mem_arena());
return net_def;
}
} // namespace mace
mace/python/tools/tf_converter.py
浏览文件 @
f05649a4
...
...
@@ -2,13 +2,45 @@ import argparse
import
sys
import
tensorflow
as
tf
from
tensorflow
import
gfile
from
mace.proto
import
mace_pb2
from
mace.python.tools
import
tf_converter_lib
from
mace.python.tools
import
tf_dsp_converter_lib
import
struct
from
jinja2
import
Environment
,
FileSystemLoader
import
os
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS
=
None
class
TensorInfo
:
def
__init__
(
self
,
t
):
self
.
name
=
t
.
name
if
t
.
data_type
==
mace_pb2
.
DT_FLOAT
:
self
.
data
=
bytearray
(
struct
.
pack
(
'%sf'
%
len
(
t
.
float_data
),
*
t
.
float_data
))
elif
t
.
data_type
==
mace_pb2
.
DT_INT32
:
self
.
data
=
bytearray
(
struct
.
pack
(
'%si'
%
len
(
t
.
int32_data
),
*
t
.
int32_data
))
def
stringfy
(
value
):
return
', '
.
join
(
'"{0}"'
.
format
(
w
)
for
w
in
value
)
def
convert_to_source
(
net_def
):
# Capture our current directory
template_dir
=
os
.
path
.
dirname
(
FLAGS
.
template
)
template_name
=
os
.
path
.
basename
(
FLAGS
.
template
)
print
template_dir
# Create the jinja2 environment.
# Notice the use of trim_blocks, which greatly helps control whitespace.
j2_env
=
Environment
(
loader
=
FileSystemLoader
(
template_dir
),
trim_blocks
=
True
)
j2_env
.
filters
[
'stringfy'
]
=
stringfy
tensors
=
[
TensorInfo
(
t
)
for
t
in
net_def
.
tensors
]
return
j2_env
.
get_template
(
template_name
).
render
(
tensors
=
tensors
,
net
=
net_def
)
def
main
(
unused_args
):
if
not
gfile
.
Exists
(
FLAGS
.
input
):
print
(
"Input graph file '"
+
FLAGS
.
input
+
"' does not exist!"
)
...
...
@@ -19,6 +51,7 @@ def main(unused_args):
data
=
f
.
read
()
input_graph_def
.
ParseFromString
(
data
)
print
'done'
if
FLAGS
.
runtime
==
'dsp'
:
output_graph_def
=
tf_dsp_converter_lib
.
convert_to_mace_pb
(
input_graph_def
,
FLAGS
.
input_node
,
FLAGS
.
output_node
,
FLAGS
.
prequantize
)
...
...
@@ -26,11 +59,16 @@ def main(unused_args):
output_graph_def
=
tf_converter_lib
.
convert_to_mace_pb
(
input_graph_def
,
FLAGS
.
input_node
,
FLAGS
.
output_node
,
FLAGS
.
data_type
,
FLAGS
.
runtime
)
with
gfile
.
GFile
(
FLAGS
.
output
,
"wb"
)
as
f
:
f
.
write
(
output_graph_def
.
SerializeToString
())
with
gfile
.
GFile
(
FLAGS
.
output
+
'_txt'
,
"wb"
)
as
f
:
# output_graph_def.ClearField('tensors')
f
.
write
(
str
(
output_graph_def
))
if
FLAGS
.
output_type
==
'source'
:
source
=
convert_to_source
(
output_graph_def
)
with
gfile
.
GFile
(
FLAGS
.
output
,
"wb"
)
as
f
:
f
.
write
(
source
)
else
:
with
gfile
.
GFile
(
FLAGS
.
output
,
"wb"
)
as
f
:
f
.
write
(
output_graph_def
.
SerializeToString
())
with
gfile
.
GFile
(
FLAGS
.
output
+
'_txt'
,
"wb"
)
as
f
:
# output_graph_def.ClearField('tensors')
f
.
write
(
str
(
output_graph_def
))
def
parse_args
():
...
...
@@ -51,7 +89,7 @@ def parse_args():
"--runtime"
,
type
=
str
,
default
=
"cpu"
,
help
=
"Runtime: cpu/gpu/dsp
.
"
)
help
=
"Runtime: cpu/gpu/dsp"
)
parser
.
add_argument
(
"--input_node"
,
type
=
str
,
...
...
@@ -72,6 +110,16 @@ def parse_args():
type
=
str
,
default
=
'DT_FLOAT'
,
help
=
"e.g., DT_HALF/DT_FLOAT"
)
parser
.
add_argument
(
"--output_type"
,
type
=
str
,
default
=
"source"
,
help
=
"output type: source/pb"
)
parser
.
add_argument
(
"--template"
,
type
=
str
,
default
=
""
,
help
=
"template path"
)
return
parser
.
parse_known_args
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录