Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
d2c2897a
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d2c2897a
编写于
2月 27, 2019
作者:
L
liyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Quantize matmul only, add gather u8 test.
上级
0c3cc381
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
213 addition
and
42 deletion
+213
-42
mace/ops/eltwise.cc
mace/ops/eltwise.cc
+3
-1
mace/ops/gather_test.cc
mace/ops/gather_test.cc
+23
-9
mace/ops/matmul.cc
mace/ops/matmul.cc
+18
-16
mace/ops/matmul_test.cc
mace/ops/matmul_test.cc
+2
-1
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+2
-0
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+12
-3
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+153
-12
未找到文件。
mace/ops/eltwise.cc
浏览文件 @
d2c2897a
...
@@ -926,7 +926,9 @@ class EltwiseOp : public Operation {
...
@@ -926,7 +926,9 @@ class EltwiseOp : public Operation {
const
Tensor
*
input1
,
const
Tensor
*
input1
,
Tensor
*
output
)
{
Tensor
*
output
)
{
bool
swapped
=
false
;
bool
swapped
=
false
;
if
(
input0
->
size
()
<
input1
->
size
())
{
if
(
input0
->
dim_size
()
<
input1
->
dim_size
()
||
(
input0
->
dim_size
()
==
input1
->
dim_size
()
&&
input0
->
size
()
<
input1
->
size
()))
{
std
::
swap
(
input0
,
input1
);
std
::
swap
(
input0
,
input1
);
swapped
=
true
;
swapped
=
true
;
}
}
...
...
mace/ops/gather_test.cc
浏览文件 @
d2c2897a
...
@@ -23,53 +23,67 @@ namespace test {
...
@@ -23,53 +23,67 @@ namespace test {
class
GatherOpTest
:
public
OpsTestBase
{};
class
GatherOpTest
:
public
OpsTestBase
{};
namespace
{
namespace
{
template
<
typename
T
>
void
TestGather
(
const
std
::
vector
<
index_t
>
&
weight_shape
,
void
TestGather
(
const
std
::
vector
<
index_t
>
&
weight_shape
,
const
std
::
vector
<
float
>
&
weight
,
const
std
::
vector
<
T
>
&
weight
,
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
int32_t
>
&
input
,
const
std
::
vector
<
int32_t
>
&
input
,
const
int
axis
,
const
int
axis
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
const
std
::
vector
<
T
>
&
output
)
{
OpsTestNet
net
;
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Params"
,
weight_shape
,
weight
);
net
.
AddInputFromArray
<
CPU
,
T
>
(
"Params"
,
weight_shape
,
weight
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Indices"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Indices"
,
input_shape
,
input
);
OpDefBuilder
(
"Gather"
,
"GatherTest"
)
OpDefBuilder
(
"Gather"
,
"GatherTest"
)
.
Input
(
"Params"
)
.
Input
(
"Params"
)
.
Input
(
"Indices"
)
.
Input
(
"Indices"
)
.
AddIntArg
(
"T"
,
DataTypeToEnum
<
T
>::
v
())
.
AddIntArg
(
"axis"
,
axis
)
.
AddIntArg
(
"axis"
,
axis
)
.
Output
(
"Output"
)
.
Output
(
"Output"
)
.
Finalize
(
net
.
NewOperatorDef
());
.
Finalize
(
net
.
NewOperatorDef
());
// Run
// Run
net
.
RunOp
(
CPU
);
net
.
RunOp
(
CPU
);
auto
expected
=
net
.
CreateTensor
<
float
>
(
output_shape
,
output
);
auto
expected
=
net
.
CreateTensor
<
T
>
(
output_shape
,
output
);
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-5
);
ExpectTensorNear
<
T
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-5
);
}
}
}
// namespace
}
// namespace
TEST_F
(
GatherOpTest
,
CPUScalarIndex
)
{
TEST_F
(
GatherOpTest
,
CPUScalarIndex
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
TestGather
<
float
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{},
{
5
},
0
,
{
2
},
{
10
,
11
});
TestGather
<
uint8_t
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{},
{
5
},
0
,
{
2
},
{
10
,
11
});
{},
{
5
},
0
,
{
2
},
{
10
,
11
});
}
}
TEST_F
(
GatherOpTest
,
CPURank1Index
)
{
TEST_F
(
GatherOpTest
,
CPURank1Index
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
TestGather
<
float
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
3
},
{
2
,
4
,
6
},
0
,
{
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
TestGather
<
uint8_t
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
3
},
{
2
,
4
,
6
},
0
,
{
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
{
3
},
{
2
,
4
,
6
},
0
,
{
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
}
}
TEST_F
(
GatherOpTest
,
CPURank1IndexWithAxis1
)
{
TEST_F
(
GatherOpTest
,
CPURank1IndexWithAxis1
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
TestGather
<
float
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
},
{
1
},
1
,
{
10
,
1
},
{
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
,
17
,
19
});
TestGather
<
uint8_t
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
},
{
1
},
1
,
{
10
,
1
},
{
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
,
17
,
19
});
{
1
},
{
1
},
1
,
{
10
,
1
},
{
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
,
17
,
19
});
}
}
TEST_F
(
GatherOpTest
,
CPURankHighIndex
)
{
TEST_F
(
GatherOpTest
,
CPURankHighIndex
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
TestGather
<
float
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
,
3
},
{
2
,
4
,
6
},
0
,
{
1
,
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
TestGather
<
uint8_t
>
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
,
3
},
{
2
,
4
,
6
},
0
,
{
1
,
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
{
1
,
3
},
{
2
,
4
,
6
},
0
,
{
1
,
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
}
}
...
...
mace/ops/matmul.cc
浏览文件 @
d2c2897a
...
@@ -233,8 +233,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
...
@@ -233,8 +233,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
const
index_t
height
,
const
index_t
height
,
const
index_t
K
,
const
index_t
K
,
const
index_t
width
,
const
index_t
width
,
const
bool
lhs_bached
,
const
bool
lhs_ba
t
ched
,
const
bool
rhs_bached
,
const
bool
rhs_ba
t
ched
,
Tensor
*
C
)
{
Tensor
*
C
)
{
#if defined(MACE_ENABLE_NEON)
#if defined(MACE_ENABLE_NEON)
if
(
width
==
1
&&
AOrder
==
gemmlowp
::
MapOrder
::
RowMajor
)
{
if
(
width
==
1
&&
AOrder
==
gemmlowp
::
MapOrder
::
RowMajor
)
{
...
@@ -245,8 +245,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
...
@@ -245,8 +245,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
batch
,
batch
,
height
,
height
,
K
,
K
,
true
,
lhs_batched
,
true
,
rhs_batched
,
C
);
C
);
}
else
if
(
height
==
1
&&
BOrder
==
gemmlowp
::
MapOrder
::
ColMajor
)
{
}
else
if
(
height
==
1
&&
BOrder
==
gemmlowp
::
MapOrder
::
ColMajor
)
{
gemv_kernel_
.
Compute
(
context
,
gemv_kernel_
.
Compute
(
context
,
...
@@ -256,8 +256,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
...
@@ -256,8 +256,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
batch
,
batch
,
width
,
width
,
K
,
K
,
true
,
lhs_batched
,
true
,
rhs_batched
,
C
);
C
);
}
else
{
}
else
{
#endif // MACE_ENABLE_NEON
#endif // MACE_ENABLE_NEON
...
@@ -281,11 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
...
@@ -281,11 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
a_matrix
(
a_ptr_base
+
static_cast
<
index_t
>
(
lhs_bached
)
*
i
*
a_size
,
a_matrix
(
a_ptr_base
+
static_cast
<
index_t
>
(
lhs_batched
)
*
i
*
a_size
,
height
,
height
,
K
);
K
);
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
b_matrix
(
b_ptr_base
+
static_cast
<
index_t
>
(
rhs_bached
)
*
i
*
b_size
,
b_matrix
(
b_ptr_base
+
static_cast
<
index_t
>
(
rhs_batched
)
*
i
*
b_size
,
K
,
K
,
width
);
width
);
gemmlowp
::
MatrixMap
<
uint8_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
gemmlowp
::
MatrixMap
<
uint8_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
...
@@ -315,8 +317,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
...
@@ -315,8 +317,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
const
index_t
height
,
const
index_t
height
,
const
index_t
K
,
const
index_t
K
,
const
index_t
width
,
const
index_t
width
,
const
bool
lhs_bached
,
const
bool
lhs_ba
t
ched
,
const
bool
rhs_bached
,
const
bool
rhs_ba
t
ched
,
Tensor
*
C
)
{
Tensor
*
C
)
{
C
->
SetScale
(
A
->
scale
()
*
B
->
scale
());
C
->
SetScale
(
A
->
scale
()
*
B
->
scale
());
C
->
SetZeroPoint
(
0
);
C
->
SetZeroPoint
(
0
);
...
@@ -330,8 +332,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
...
@@ -330,8 +332,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
batch
,
batch
,
height
,
height
,
K
,
K
,
lhs_bached
,
lhs_ba
t
ched
,
rhs_bached
,
rhs_ba
t
ched
,
C
);
C
);
}
else
if
(
height
==
1
&&
BOrder
==
gemmlowp
::
MapOrder
::
ColMajor
)
{
}
else
if
(
height
==
1
&&
BOrder
==
gemmlowp
::
MapOrder
::
ColMajor
)
{
gemv_kernel_
.
Compute
(
context
,
gemv_kernel_
.
Compute
(
context
,
...
@@ -341,8 +343,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
...
@@ -341,8 +343,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
batch
,
batch
,
width
,
width
,
K
,
K
,
lhs_bached
,
lhs_ba
t
ched
,
rhs_bached
,
rhs_ba
t
ched
,
C
);
C
);
}
else
{
}
else
{
#endif // MACE_ENABLE_NEON
#endif // MACE_ENABLE_NEON
...
@@ -366,12 +368,12 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
...
@@ -366,12 +368,12 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
a_matrix
a_matrix
(
a_ptr_base
+
static_cast
<
index_t
>
(
lhs_bached
)
*
i
*
a_size
,
(
a_ptr_base
+
static_cast
<
index_t
>
(
lhs_ba
t
ched
)
*
i
*
a_size
,
height
,
height
,
K
);
K
);
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
b_matrix
b_matrix
(
b_ptr_base
+
static_cast
<
index_t
>
(
rhs_bached
)
*
i
*
b_size
,
(
b_ptr_base
+
static_cast
<
index_t
>
(
rhs_ba
t
ched
)
*
i
*
b_size
,
K
,
K
,
width
);
width
);
gemmlowp
::
MatrixMap
<
int32_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
gemmlowp
::
MatrixMap
<
int32_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
...
...
mace/ops/matmul_test.cc
浏览文件 @
d2c2897a
...
@@ -135,7 +135,8 @@ void Complex(const std::vector<index_t> &batch,
...
@@ -135,7 +135,8 @@ void Complex(const std::vector<index_t> &batch,
rhs_batched
,
rhs_batched
,
&
expected_output_tensor
);
&
expected_output_tensor
);
ExpectTensorNear
<
float
>
(
expected_output_tensor
,
*
net
.
GetTensor
(
"Output"
));
ExpectTensorNear
<
float
>
(
expected_output_tensor
,
*
net
.
GetTensor
(
"Output"
),
1e-4
,
1e-2
);
}
}
}
// namespace
}
// namespace
...
...
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
d2c2897a
...
@@ -236,6 +236,7 @@ class MaceKeyword(object):
...
@@ -236,6 +236,7 @@ class MaceKeyword(object):
mace_step_h_str
=
'step_h'
mace_step_h_str
=
'step_h'
mace_step_w_str
=
'step_w'
mace_step_w_str
=
'step_w'
mace_find_range_every_time
=
'find_range_every_time'
mace_find_range_every_time
=
'find_range_every_time'
mace_non_zero
=
'non_zero'
mace_pad_type_str
=
'pad_type'
mace_pad_type_str
=
'pad_type'
...
@@ -279,6 +280,7 @@ class TransformerRule(Enum):
...
@@ -279,6 +280,7 @@ class TransformerRule(Enum):
FOLD_FC_RESHAPE
=
37
FOLD_FC_RESHAPE
=
37
TRANSFORM_CHANNEL_SHUFFLE
=
38
TRANSFORM_CHANNEL_SHUFFLE
=
38
UPDATE_DATA_FORMAT
=
39
UPDATE_DATA_FORMAT
=
39
QUANTIZE_MATMUL_ONLY
=
40
class
ConverterInterface
(
object
):
class
ConverterInterface
(
object
):
...
...
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
d2c2897a
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
six
import
six
...
@@ -288,12 +288,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
...
@@ -288,12 +288,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def
.
ParseFromString
(
f
.
read
())
tf_graph_def
.
ParseFromString
(
f
.
read
())
self
.
_placeholders
=
{}
self
.
_placeholders
=
{}
self
.
add_shape_info
(
tf_graph_def
)
print
(
"Run transform_graph: %s"
%
TFTransformGraphOptions
[
print
(
"Run transform_graph: %s"
%
TFTransformGraphOptions
[
option
.
device
])
option
.
device
])
try
:
try
:
print
(
"output keys: "
,
option
.
output_nodes
.
keys
())
print
(
"output keys: "
,
option
.
output_nodes
.
keys
())
transformed_graph_def
=
TransformGraph
(
tf_graph_def
,
transformed_graph_def
=
TransformGraph
(
tf_graph_def
,
option
.
input_nodes
.
keys
(),
option
.
input_nodes
.
keys
(),
option
.
output_nodes
.
keys
(),
option
.
output_nodes
.
keys
(),
...
@@ -303,6 +302,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
...
@@ -303,6 +302,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
print
(
"Failed to transform graph using tf tool: %s"
%
ex
)
print
(
"Failed to transform graph using tf tool: %s"
%
ex
)
transformed_graph_def
=
tf_graph_def
transformed_graph_def
=
tf_graph_def
# To check optimized model, uncomment following code.
# tf.io.write_graph(
# transformed_graph_def,
# ".",
# os.path.basename(src_model_file)[:-3] + "_opt.pb",
# as_text=False
# )
self
.
add_shape_info
(
transformed_graph_def
)
with
tf
.
Session
()
as
session
:
with
tf
.
Session
()
as
session
:
with
session
.
graph
.
as_default
()
as
graph
:
with
session
.
graph
.
as_default
()
as
graph
:
tf
.
import_graph_def
(
transformed_graph_def
,
name
=
''
)
tf
.
import_graph_def
(
transformed_graph_def
,
name
=
''
)
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
d2c2897a
...
@@ -103,6 +103,8 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -103,6 +103,8 @@ class Transformer(base_converter.ConverterInterface):
self
.
transform_caffe_reshape_and_flatten
,
self
.
transform_caffe_reshape_and_flatten
,
TransformerRule
.
TRANSFORM_CHANNEL_SHUFFLE
:
TransformerRule
.
TRANSFORM_CHANNEL_SHUFFLE
:
self
.
transform_channel_shuffle
,
self
.
transform_channel_shuffle
,
TransformerRule
.
QUANTIZE_MATMUL_ONLY
:
self
.
quantize_matmul_only
,
}
}
self
.
_option
=
option
self
.
_option
=
option
...
@@ -191,16 +193,23 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -191,16 +193,23 @@ class Transformer(base_converter.ConverterInterface):
op
=
mace_pb2
.
OperatorDef
()
op
=
mace_pb2
.
OperatorDef
()
op
.
name
=
self
.
normalize_op_name
(
input_node
.
name
)
op
.
name
=
self
.
normalize_op_name
(
input_node
.
name
)
op
.
type
=
"Input"
op
.
type
=
"Input"
data_type_arg
=
op
.
arg
.
add
()
data_type_arg
.
name
=
MaceKeyword
.
mace_op_data_type_str
data_type_arg
.
i
=
mace_pb2
.
DT_FLOAT
op
.
output
.
extend
([
input_node
.
name
])
op
.
output
.
extend
([
input_node
.
name
])
output_shape
=
op
.
output_shape
.
add
()
output_shape
=
op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
input_node
.
shape
)
output_shape
.
dims
.
extend
(
input_node
.
shape
)
if
ConverterUtil
.
data_format
(
if
input_node
in
self
.
_consumers
:
self
.
_consumers
[
input_node
.
name
][
0
])
\
if
ConverterUtil
.
data_format
(
==
DataFormat
.
NCHW
:
self
.
_consumers
[
input_node
.
name
][
0
])
\
self
.
transpose_shape
(
output_shape
.
dims
,
[
0
,
3
,
1
,
2
])
==
DataFormat
.
NCHW
:
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NCHW
)
self
.
transpose_shape
(
output_shape
.
dims
,
else
:
[
0
,
3
,
1
,
2
])
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NCHW
)
else
:
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
self
.
_producer
[
op
.
output
[
0
]]
=
op
self
.
_producer
[
op
.
output
[
0
]]
=
op
@
staticmethod
@
staticmethod
...
@@ -221,10 +230,32 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -221,10 +230,32 @@ class Transformer(base_converter.ConverterInterface):
return
name
.
replace
(
':'
,
'_'
)
return
name
.
replace
(
':'
,
'_'
)
def
get_tensor_shape
(
self
,
tensor
):
def
get_tensor_shape
(
self
,
tensor
):
producer
=
self
.
_producer
[
tensor
]
if
tensor
in
self
.
_consts
:
for
i
in
six
.
moves
.
range
(
len
(
producer
.
output
)):
return
list
(
self
.
_consts
[
tensor
].
dims
)
if
producer
.
output
[
i
]
==
tensor
:
elif
tensor
in
self
.
_producer
:
return
list
(
producer
.
output_shape
[
i
].
dims
)
producer
=
self
.
_producer
[
tensor
]
for
i
in
six
.
moves
.
range
(
len
(
producer
.
output
)):
if
producer
.
output
[
i
]
==
tensor
:
return
list
(
producer
.
output_shape
[
i
].
dims
)
else
:
return
None
def
get_tensor_data_type
(
self
,
tensor
):
if
tensor
in
self
.
_consts
:
return
self
.
_consts
[
tensor
].
data_type
elif
tensor
in
self
.
_producer
:
producer
=
self
.
_producer
[
tensor
]
for
i
in
six
.
moves
.
range
(
len
(
producer
.
output
)):
if
producer
.
output
[
i
]
==
tensor
:
if
i
<
len
(
producer
.
output_type
):
return
producer
.
output_type
[
i
]
elif
ConverterUtil
.
get_arg
(
producer
,
"T"
)
is
not
None
:
return
ConverterUtil
.
get_arg
(
producer
,
"T"
).
i
else
:
print
(
"No data type filled: "
,
producer
)
return
None
else
:
return
None
def
consumer_count
(
self
,
tensor_name
):
def
consumer_count
(
self
,
tensor_name
):
return
len
(
self
.
_consumers
.
get
(
tensor_name
,
[]))
return
len
(
self
.
_consumers
.
get
(
tensor_name
,
[]))
...
@@ -1374,6 +1405,7 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1374,6 +1405,7 @@ class Transformer(base_converter.ConverterInterface):
return
False
return
False
def
update_data_format
(
self
):
def
update_data_format
(
self
):
print
(
"update data format"
)
data_format_flag
=
DataFormat
.
NHWC
.
value
data_format_flag
=
DataFormat
.
NHWC
.
value
for
input_node
in
self
.
_option
.
input_nodes
.
values
():
for
input_node
in
self
.
_option
.
input_nodes
.
values
():
if
input_node
.
data_format
.
value
==
DataFormat
.
DF_NONE
.
value
:
if
input_node
.
data_format
.
value
==
DataFormat
.
DF_NONE
.
value
:
...
@@ -1672,7 +1704,8 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1672,7 +1704,8 @@ class Transformer(base_converter.ConverterInterface):
quantize_util
.
adjust_range
(
input_node
.
range
[
0
],
quantize_util
.
adjust_range
(
input_node
.
range
[
0
],
input_node
.
range
[
1
],
input_node
.
range
[
1
],
non_zero
=
False
)
non_zero
=
False
)
quantize_info
=
mace_pb2
.
QuantizeActivationInfo
()
quantize_info
=
\
mace_pb2
.
QuantizeActivationInfo
()
quantize_info
.
minval
=
minval
quantize_info
.
minval
=
minval
quantize_info
.
maxval
=
maxval
quantize_info
.
maxval
=
maxval
quantize_info
.
scale
=
scale
quantize_info
.
scale
=
scale
...
@@ -1893,3 +1926,111 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1893,3 +1926,111 @@ class Transformer(base_converter.ConverterInterface):
producer_op
.
output_shape
[
0
].
dims
[:]
=
output_shape
producer_op
.
output_shape
[
0
].
dims
[:]
=
output_shape
return
True
return
True
def
quantize_matmul_only
(
self
):
"""
This transform rule is only used internally, we are not gonna make
things too complex for users
"""
to_quantize_ops
=
[
MaceOp
.
MatMul
.
name
]
for
op
in
self
.
_model
.
op
:
if
(
op
.
type
not
in
to_quantize_ops
or
len
(
op
.
output
)
>
1
or
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_op_data_type_str
).
i
!=
mace_pb2
.
DT_FLOAT
):
# noqa
# only support single output
continue
quantized_inputs_names
=
[]
should_quantize
=
True
for
idx
,
input_tensor
in
enumerate
(
op
.
input
):
if
self
.
get_tensor_data_type
(
input_tensor
)
\
!=
mace_pb2
.
DT_FLOAT
:
should_quantize
=
False
break
if
not
should_quantize
:
continue
non_zero
=
self
.
_option
.
device
==
DeviceType
.
CPU
.
value
for
idx
,
input_tensor
in
enumerate
(
op
.
input
):
quantized_inputs_names
.
append
(
input_tensor
)
if
input_tensor
in
self
.
_consts
:
const_tensor
=
self
.
_consts
[
input_tensor
]
quantized_tensor
=
quantize_util
.
quantize
(
const_tensor
.
float_data
,
non_zero
)
del
const_tensor
.
float_data
[:]
const_tensor
.
int32_data
.
extend
(
quantized_tensor
.
data
)
const_tensor
.
data_type
=
mace_pb2
.
DT_UINT8
const_tensor
.
scale
=
quantized_tensor
.
scale
const_tensor
.
zero_point
=
quantized_tensor
.
zero
const_tensor
.
minval
=
quantized_tensor
.
minval
const_tensor
.
maxval
=
quantized_tensor
.
maxval
const_tensor
.
quantized
=
True
else
:
input_shape
=
self
.
get_tensor_shape
(
input_tensor
)
quantize_op
=
self
.
_model
.
op
.
add
()
quantize_op
.
name
=
self
.
normalize_op_name
(
input_tensor
)
+
"_quant"
quantize_op
.
type
=
MaceOp
.
Quantize
.
name
quantize_op
.
input
.
extend
([
input_tensor
])
quantize_output_name
=
quantize_op
.
name
+
'_0'
quantize_op
.
output
.
extend
([
quantize_output_name
])
output_shape
=
quantize_op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
input_shape
)
quantize_op
.
output_type
.
extend
([
mace_pb2
.
DT_UINT8
])
data_type_arg
=
quantize_op
.
arg
.
add
()
data_type_arg
.
name
=
MaceKeyword
.
mace_op_data_type_str
data_type_arg
.
i
=
mace_pb2
.
DT_UINT8
data_type_arg
=
quantize_op
.
arg
.
add
()
data_type_arg
.
name
=
MaceKeyword
.
mace_non_zero
if
non_zero
:
data_type_arg
.
i
=
1
else
:
data_type_arg
.
i
=
0
find_range_arg
=
quantize_op
.
arg
.
add
()
find_range_arg
.
name
=
\
MaceKeyword
.
mace_find_range_every_time
find_range_arg
.
i
=
1
quantized_inputs_names
[
-
1
]
=
quantize_output_name
non_zero
=
False
del
op
.
input
[:]
op
.
input
.
extend
(
quantized_inputs_names
)
orginal_output_name
=
op
.
output
[
0
]
op
.
output
[
0
]
=
orginal_output_name
+
"_quant"
op
.
output_type
.
extend
([
mace_pb2
.
DT_INT32
])
data_type_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_op_data_type_str
)
# noqa
if
data_type_arg
is
None
:
data_type_arg
=
op
.
arg
.
add
()
data_type_arg
.
name
=
MaceKeyword
.
mace_op_data_type_str
data_type_arg
.
i
=
mace_pb2
.
DT_UINT8
dequantize_op
=
self
.
_model
.
op
.
add
()
dequantize_op
.
name
=
op
.
name
+
"_dequant"
dequantize_op
.
type
=
MaceOp
.
Dequantize
.
name
dequantize_op
.
input
.
extend
([
op
.
output
[
0
]])
dequantize_op
.
output
.
extend
([
orginal_output_name
])
dequantize_op
.
output_shape
.
extend
(
op
.
output_shape
)
dequantize_op
.
output_type
.
extend
([
mace_pb2
.
DT_FLOAT
])
data_type_arg
=
dequantize_op
.
arg
.
add
()
data_type_arg
.
name
=
MaceKeyword
.
mace_op_data_type_str
data_type_arg
.
i
=
mace_pb2
.
DT_INT32
quantize_flag_arg
=
ConverterUtil
.
get_arg
(
self
.
_model
,
MaceKeyword
.
mace_quantize_flag_arg_str
)
# noqa
if
quantize_flag_arg
is
None
:
quantize_flag_arg
=
self
.
_model
.
arg
.
add
()
quantize_flag_arg
.
name
=
MaceKeyword
.
mace_quantize_flag_arg_str
quantize_flag_arg
.
i
=
1
return
True
return
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录