Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
736ceb5f
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
736ceb5f
编写于
9月 21, 2017
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor conv related op
上级
f16cf77e
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
83 addition
and
110 deletion
+83
-110
mace/core/logging.h
mace/core/logging.h
+11
-1
mace/core/net.cc
mace/core/net.cc
+2
-0
mace/core/tensor.h
mace/core/tensor.h
+1
-1
mace/core/workspace.cc
mace/core/workspace.cc
+4
-0
mace/kernels/conv_2d.h
mace/kernels/conv_2d.h
+2
-14
mace/kernels/depthwise_conv2d.h
mace/kernels/depthwise_conv2d.h
+3
-14
mace/kernels/global_avg_pooling.h
mace/kernels/global_avg_pooling.h
+1
-4
mace/kernels/pooling.h
mace/kernels/pooling.h
+1
-3
mace/ops/conv_2d.cc
mace/ops/conv_2d.cc
+2
-2
mace/ops/conv_2d.h
mace/ops/conv_2d.h
+10
-7
mace/ops/conv_2d_benchmark.cc
mace/ops/conv_2d_benchmark.cc
+1
-1
mace/ops/conv_2d_test.cc
mace/ops/conv_2d_test.cc
+5
-5
mace/ops/conv_pool_2d_base.h
mace/ops/conv_pool_2d_base.h
+1
-42
mace/ops/depthwise_conv2d.h
mace/ops/depthwise_conv2d.h
+11
-5
mace/ops/depthwise_conv2d_test.cc
mace/ops/depthwise_conv2d_test.cc
+3
-3
mace/ops/depthwise_conv_2d_benchmark.cc
mace/ops/depthwise_conv_2d_benchmark.cc
+1
-1
mace/ops/pooling.h
mace/ops/pooling.h
+1
-0
mace/python/tools/tf_converter.py
mace/python/tools/tf_converter.py
+1
-0
mace/python/tools/tf_converter_lib.py
mace/python/tools/tf_converter_lib.py
+22
-7
未找到文件。
mace/core/logging.h
浏览文件 @
736ceb5f
...
...
@@ -8,6 +8,7 @@
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#undef ERROR
...
...
@@ -41,7 +42,16 @@ template <typename... Args>
string
MakeString
(
const
Args
&
...
args
)
{
std
::
stringstream
ss
;
MakeStringInternal
(
ss
,
args
...);
return
string
(
ss
.
str
());
return
ss
.
str
();
}
template
<
typename
T
>
string
MakeString
(
const
std
::
vector
<
T
>
&
args
)
{
std
::
stringstream
ss
;
for
(
const
T
&
arg
:
args
)
{
ss
<<
arg
<<
", "
;
}
return
ss
.
str
();
}
// Specializations for already-a-string types.
...
...
mace/core/net.cc
浏览文件 @
736ceb5f
...
...
@@ -35,6 +35,8 @@ bool SimpleNet::Run() {
LOG
(
ERROR
)
<<
"Operator failed: "
<<
ProtoDebugString
(
op
->
debug_def
());
return
false
;
}
VLOG
(
1
)
<<
"Op "
<<
op
->
debug_def
().
name
()
<<
" has shape: "
<<
internal
::
MakeString
(
op
->
Output
(
0
)
->
shape
());
}
return
true
;
}
...
...
mace/core/tensor.h
浏览文件 @
736ceb5f
...
...
@@ -137,7 +137,7 @@ class Tensor {
alloc_
->
CopyBytes
(
raw_mutable_data
(),
src
,
size
);
}
inline
void
DebugPrint
()
{
inline
void
DebugPrint
()
const
{
std
::
stringstream
os
;
for
(
int
i
:
shape_
)
{
os
<<
i
<<
", "
;
...
...
mace/core/workspace.cc
浏览文件 @
736ceb5f
...
...
@@ -53,6 +53,10 @@ Tensor* Workspace::GetTensor(const string& name) {
void
Workspace
::
LoadModelTensor
(
const
NetDef
&
net_def
,
DeviceType
type
)
{
Serializer
serializer
;
for
(
auto
&
tensor_proto
:
net_def
.
tensors
())
{
VLOG
(
1
)
<<
"Load tensor: "
<<
tensor_proto
.
name
()
<<
" has shape: "
<<
internal
::
MakeString
(
vector
<
index_t
>
(
tensor_proto
.
dims
().
begin
(),
tensor_proto
.
dims
().
end
()));
tensor_map_
[
tensor_proto
.
name
()]
=
serializer
.
Deserialize
(
tensor_proto
,
type
);
}
...
...
mace/kernels/conv_2d.h
浏览文件 @
736ceb5f
...
...
@@ -12,19 +12,8 @@ namespace mace {
namespace
kernels
{
template
<
DeviceType
D
,
typename
T
>
class
Conv2dFunctor
{
public:
Conv2dFunctor
(
const
index_t
*
input_shape
,
const
index_t
*
filter_shape
,
const
int
*
strides
,
const
Padding
padding
,
const
int
*
dilations
)
:
strides_
(
strides
),
paddings_
(
2
,
0
),
dilations_
(
dilations
)
{
CalPaddingSize
(
input_shape
,
filter_shape
,
dilations_
,
strides_
,
padding
,
paddings_
.
data
());
}
struct
Conv2dFunctor
{
Conv2dFunctor
()
{}
Conv2dFunctor
(
const
int
*
strides
,
const
std
::
vector
<
int
>
&
paddings
,
const
int
*
dilations
)
:
...
...
@@ -112,7 +101,6 @@ class Conv2dFunctor {
}
}
private:
const
int
*
strides_
;
// [stride_h, stride_w]
std
::
vector
<
int
>
paddings_
;
// [padding_h, padding_w]
const
int
*
dilations_
;
// [dilation_h, dilation_w]
...
...
mace/kernels/depthwise_conv2d.h
浏览文件 @
736ceb5f
...
...
@@ -13,18 +13,8 @@ namespace mace {
namespace
kernels
{
template
<
DeviceType
D
,
typename
T
>
class
DepthwiseConv2dFunctor
{
public:
DepthwiseConv2dFunctor
(
const
index_t
*
input_shape
,
const
index_t
*
filter_shape
,
const
int
*
strides
,
const
Padding
padding
,
const
int
*
dilations
)
:
strides_
(
strides
),
paddings_
(
2
,
0
),
dilations_
(
dilations
)
{
CalPaddingSize
(
input_shape
,
filter_shape
,
dilations_
,
strides_
,
padding
,
paddings_
.
data
());
}
struct
DepthwiseConv2dFunctor
{
DepthwiseConv2dFunctor
()
{}
DepthwiseConv2dFunctor
(
const
int
*
strides
,
const
std
::
vector
<
int
>
&
paddings
,
const
int
*
dilations
)
:
...
...
@@ -39,7 +29,6 @@ class DepthwiseConv2dFunctor {
const
T
*
bias
,
// c_out
T
*
output
,
// NCHW
const
index_t
*
output_shape
)
{
MACE_CHECK_NOTNULL
(
output
);
index_t
batch
=
output_shape
[
0
];
...
...
@@ -111,7 +100,7 @@ class DepthwiseConv2dFunctor {
}
}
}
private:
const
int
*
strides_
;
// [stride_h, stride_w]
std
::
vector
<
int
>
paddings_
;
// [padding_h, padding_w]
const
int
*
dilations_
;
// [dilation_h, dilation_w]
...
...
mace/kernels/global_avg_pooling.h
浏览文件 @
736ceb5f
...
...
@@ -11,10 +11,7 @@ namespace mace {
namespace
kernels
{
template
<
DeviceType
D
,
typename
T
>
class
GlobalAvgPoolingFunctor
{
public:
GlobalAvgPoolingFunctor
()
{}
struct
GlobalAvgPoolingFunctor
{
void
operator
()(
const
T
*
input
,
const
index_t
*
input_shape
,
T
*
output
)
{
index_t
batch
=
input_shape
[
0
];
index_t
channels
=
input_shape
[
1
];
...
...
mace/kernels/pooling.h
浏览文件 @
736ceb5f
...
...
@@ -18,8 +18,7 @@ enum PoolingType {
namespace
kernels
{
template
<
DeviceType
D
,
typename
T
>
class
PoolingFunctor
{
public:
struct
PoolingFunctor
{
PoolingFunctor
(
const
PoolingType
pooling_type
,
const
int
*
kernels
,
const
int
*
strides
,
...
...
@@ -114,7 +113,6 @@ class PoolingFunctor {
}
}
private:
const
PoolingType
pooling_type_
;
const
int
*
kernels_
;
const
int
*
strides_
;
...
...
mace/ops/conv_2d.cc
浏览文件 @
736ceb5f
...
...
@@ -6,10 +6,10 @@
namespace
mace
{
REGISTER_CPU_OPERATOR
(
Conv2
d
,
Conv2dOp
<
DeviceType
::
CPU
,
float
>
);
REGISTER_CPU_OPERATOR
(
Conv2
D
,
Conv2dOp
<
DeviceType
::
CPU
,
float
>
);
#if __ARM_NEON
REGISTER_NEON_OPERATOR
(
Conv2
d
,
Conv2dOp
<
DeviceType
::
NEON
,
float
>
);
REGISTER_NEON_OPERATOR
(
Conv2
D
,
Conv2dOp
<
DeviceType
::
NEON
,
float
>
);
#endif // __ARM_NEON
}
// namespace mace
mace/ops/conv_2d.h
浏览文件 @
736ceb5f
...
...
@@ -17,12 +17,10 @@ template<DeviceType D, typename T>
class
Conv2dOp
:
public
ConvPool2dOpBase
<
D
,
T
>
{
public:
Conv2dOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
ConvPool2dOpBase
<
D
,
T
>
(
op_def
,
ws
),
functor_
(
this
->
Input
(
INPUT
)
->
shape
().
data
(),
this
->
Input
(
FILTER
)
->
shape
().
data
(),
this
->
strides_
.
data
(),
this
->
padding_
,
this
->
dilations_
.
data
())
{}
:
ConvPool2dOpBase
<
D
,
T
>
(
op_def
,
ws
)
{
functor_
.
strides_
=
this
->
strides_
.
data
();
functor_
.
dilations_
=
this
->
dilations_
.
data
();
}
bool
Run
()
override
{
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
...
...
@@ -37,8 +35,13 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
std
::
vector
<
index_t
>
output_shape
(
4
);
std
::
vector
<
int
>
paddings
(
2
);
this
->
CalOutputSize
(
input
->
shape
().
data
(),
filter
->
shape
().
data
(),
output_shape
.
data
());
kernels
::
CalcPaddingAndOutputSize
(
input
->
shape
().
data
(),
filter
->
shape
().
data
(),
this
->
dilations_
.
data
(),
this
->
strides_
.
data
(),
this
->
padding_
,
output_shape
.
data
(),
paddings
.
data
());
output
->
Resize
(
output_shape
);
functor_
.
paddings_
=
paddings
;
functor_
(
input
->
data
<
T
>
(),
input
->
shape
().
data
(),
filter
->
data
<
T
>
(),
filter
->
shape
().
data
(),
bias_data
,
output
->
mutable_data
<
T
>
(),
...
...
mace/ops/conv_2d_benchmark.cc
浏览文件 @
736ceb5f
...
...
@@ -25,7 +25,7 @@ static void Conv2d(int iters,
mace
::
testing
::
StopTiming
();
OpsTestNet
net
;
OpDefBuilder
(
"Conv2
d
"
,
"Conv2dTest"
)
OpDefBuilder
(
"Conv2
D
"
,
"Conv2dTest"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
mace/ops/conv_2d_test.cc
浏览文件 @
736ceb5f
...
...
@@ -13,7 +13,7 @@ class Conv2dOpTest : public OpsTestBase {};
TEST_F
(
Conv2dOpTest
,
Simple_VALID
)
{
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"Conv2
d
"
,
"Conv2dTest"
)
OpDefBuilder
(
"Conv2
D
"
,
"Conv2dTest"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
@@ -47,7 +47,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
TEST_F
(
Conv2dOpTest
,
Simple_SAME
)
{
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"Conv2
d
"
,
"Conv2dTest"
)
OpDefBuilder
(
"Conv2
D
"
,
"Conv2dTest"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
@@ -83,7 +83,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
TEST_F
(
Conv2dOpTest
,
Combined
)
{
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"Conv2
d"
,
"Conv2d
Test"
)
OpDefBuilder
(
"Conv2
D"
,
"Conv2D
Test"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
@@ -121,7 +121,7 @@ TEST_F(Conv2dOpTest, Combined) {
TEST_F
(
Conv2dOpTest
,
Conv1x1
)
{
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"Conv2
d"
,
"Conv2d
Test"
)
OpDefBuilder
(
"Conv2
D"
,
"Conv2D
Test"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
@@ -179,7 +179,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
index_t
output_channels
=
1
+
rand
()
%
10
;
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"Conv2
d
"
,
"Conv2dTest"
)
OpDefBuilder
(
"Conv2
D
"
,
"Conv2dTest"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
mace/ops/conv_pool_2d_base.h
浏览文件 @
736ceb5f
...
...
@@ -19,48 +19,7 @@ class ConvPool2dOpBase : public Operator<D, T> {
padding_
(
static_cast
<
Padding
>
(
OperatorBase
::
GetSingleArgument
<
int
>
(
"padding"
,
static_cast
<
int
>
(
SAME
)))),
dilations_
(
OperatorBase
::
GetRepeatedArgument
<
int
>
(
"dilations"
,
{
1
,
1
}))
{}
void
CalOutputSize
(
const
index_t
*
input_shape
,
// NCHW
const
index_t
*
filter_shape
,
// OIHW
index_t
*
output_shape
)
{
MACE_CHECK
(
dilations_
[
0
]
>
0
&&
dilations_
[
1
]
>
0
,
"Invalid dilations, must >= 1"
);
MACE_CHECK
((
dilations_
[
0
]
==
1
||
strides_
[
0
]
==
1
)
&&
(
dilations_
[
1
]
==
1
||
strides_
[
1
]
==
1
),
"If dilations > 1, strides should be 1"
);
MACE_CHECK_NOTNULL
(
output_shape
);
/*
* Convlution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
index_t
output_height
=
0
,
output_width
=
0
;
switch
(
padding_
)
{
case
VALID
:
output_height
=
(
input_shape
[
2
]
-
(
filter_shape
[
2
]
-
1
)
*
dilations_
[
0
]
-
1
)
/
strides_
[
0
]
+
1
;
output_width
=
(
input_shape
[
3
]
-
(
filter_shape
[
3
]
-
1
)
*
dilations_
[
1
]
-
1
)
/
strides_
[
1
]
+
1
;
break
;
case
SAME
:
output_height
=
(
input_shape
[
2
]
-
1
)
/
strides_
[
0
]
+
1
;
output_width
=
(
input_shape
[
3
]
-
1
)
/
strides_
[
1
]
+
1
;
break
;
case
FULL
:
output_height
=
(
input_shape
[
2
]
+
(
filter_shape
[
2
]
-
1
)
*
dilations_
[
0
]
-
1
)
/
strides_
[
0
]
+
1
;
output_width
=
(
input_shape
[
3
]
+
(
filter_shape
[
3
]
-
1
)
*
dilations_
[
1
]
-
1
)
/
strides_
[
1
]
+
1
;
break
;
default:
MACE_CHECK
(
false
,
"Unsupported padding type: "
,
padding_
);
}
output_shape
[
0
]
=
input_shape
[
0
];
output_shape
[
1
]
=
filter_shape
[
0
];
output_shape
[
2
]
=
output_height
;
output_shape
[
3
]
=
output_width
;
}
protected:
std
::
vector
<
int
>
strides_
;
Padding
padding_
;
...
...
mace/ops/depthwise_conv2d.h
浏览文件 @
736ceb5f
...
...
@@ -18,10 +18,10 @@ template<DeviceType D, typename T>
class
DepthwiseConv2dOp
:
public
ConvPool2dOpBase
<
D
,
T
>
{
public:
DepthwiseConv2dOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
ConvPool2dOpBase
<
D
,
T
>
(
op_def
,
ws
)
,
functor_
(
this
->
Input
(
INPUT
)
->
shape
().
data
(),
this
->
Input
(
FILTER
)
->
shape
().
data
(),
this
->
strides_
.
data
(),
this
->
padding_
,
this
->
dilations_
.
data
())
{};
:
ConvPool2dOpBase
<
D
,
T
>
(
op_def
,
ws
)
{
functor_
.
strides_
=
this
->
strides_
.
data
();
functor_
.
dilations_
=
this
->
dilations_
.
data
();
}
bool
Run
()
override
{
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
...
...
@@ -38,8 +38,14 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
filter_shape
[
0
]
*=
filter_shape
[
1
];
filter_shape
[
1
]
=
1
;
std
::
vector
<
index_t
>
output_shape
(
4
);
this
->
CalOutputSize
(
input
->
shape
().
data
(),
filter_shape
.
data
(),
output_shape
.
data
());
std
::
vector
<
int
>
paddings
(
2
);
kernels
::
CalcPaddingAndOutputSize
(
input
->
shape
().
data
(),
filter_shape
.
data
(),
this
->
dilations_
.
data
(),
this
->
strides_
.
data
(),
this
->
padding_
,
output_shape
.
data
(),
paddings
.
data
());
output
->
Resize
(
output_shape
);
functor_
.
paddings_
=
paddings
;
functor_
(
input
->
data
<
T
>
(),
input
->
shape
().
data
(),
filter
->
data
<
T
>
(),
filter_shape
.
data
(),
bias_data
,
output
->
mutable_data
<
T
>
(),
...
...
mace/ops/depthwise_conv2d_test.cc
浏览文件 @
736ceb5f
...
...
@@ -10,9 +10,10 @@ using namespace mace;
class
DepthwiseConv2dOpTest
:
public
OpsTestBase
{};
TEST_F
(
DepthwiseConv2dOpTest
,
Simple_VALID
)
{
testing
::
internal
::
LogToStderr
();
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"DepthwiseConv2d"
,
"DepthwiseConv2
d
Test"
)
OpDefBuilder
(
"DepthwiseConv2d"
,
"DepthwiseConv2
D
Test"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
@@ -35,7 +36,6 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
3.0
f
,
7.0
f
,
11.0
f
,
15.0
f
,
4.0
f
,
8.0
f
,
12.0
f
,
16.0
f
});
net
.
AddInputFromArray
<
float
>
(
"Bias"
,
{
4
},
{
.1
f
,
.2
f
,
.3
f
,
.4
f
});
// Run
net
.
RunOp
();
...
...
@@ -61,7 +61,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
index_t
multiplier
=
3
+
rand
()
%
10
;
// Construct graph
auto
&
net
=
test_net
();
OpDefBuilder
(
"DepthwiseConv2d"
,
"DepthwiseConv2
d
Test"
)
OpDefBuilder
(
"DepthwiseConv2d"
,
"DepthwiseConv2
D
Test"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
mace/ops/depthwise_conv_2d_benchmark.cc
浏览文件 @
736ceb5f
...
...
@@ -25,7 +25,7 @@ static void DepthwiseConv2d(int iters,
mace
::
testing
::
StopTiming
();
OpsTestNet
net
;
OpDefBuilder
(
"DepthwiseConv2
d"
,
"DepthwiseConv2d
Test"
)
OpDefBuilder
(
"DepthwiseConv2
D"
,
"DepthwiseConv2D
Test"
)
.
Input
(
"Input"
)
.
Input
(
"Filter"
)
.
Input
(
"Bias"
)
...
...
mace/ops/pooling.h
浏览文件 @
736ceb5f
...
...
@@ -28,6 +28,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
std
::
vector
<
index_t
>
output_shape
(
4
);
std
::
vector
<
int
>
paddings
(
2
);
std
::
vector
<
index_t
>
filter_shape
(
4
);
// TODO(chenghui): is it kind of a hack?
filter_shape
[
0
]
=
input
->
shape
()[
1
];
filter_shape
[
1
]
=
input
->
shape
()[
0
];
filter_shape
[
2
]
=
kernels_
[
0
];
...
...
mace/python/tools/tf_converter.py
浏览文件 @
736ceb5f
...
...
@@ -23,6 +23,7 @@ def main(unused_args):
with
gfile
.
GFile
(
FLAGS
.
output
,
"wb"
)
as
f
:
f
.
write
(
output_graph_def
.
SerializeToString
())
with
gfile
.
GFile
(
FLAGS
.
output
+
'_txt'
,
"wb"
)
as
f
:
output_graph_def
.
ClearField
(
'tensors'
)
f
.
write
(
str
(
output_graph_def
))
...
...
mace/python/tools/tf_converter_lib.py
浏览文件 @
736ceb5f
from
mace.proto
import
mace_pb2
import
tensorflow
as
tf
import
numpy
as
np
padding_mode
=
{
'VALID'
:
0
,
...
...
@@ -24,11 +25,20 @@ def convert_ops(unresolved_ops, net_def):
tf_tensor
=
first_op
.
outputs
[
0
].
eval
()
tensor
=
net_def
.
tensors
.
add
()
tensor
.
name
=
first_op
.
outputs
[
0
].
name
tensor
.
dims
.
extend
(
tf_tensor
.
shape
)
# TODO: support other type than float
tensor
.
data_type
=
mace_pb2
.
DT_FLOAT
shape
=
list
(
tf_tensor
.
shape
)
if
(
first_op
.
name
.
find
(
'pointwise_kernel'
)
!=
-
1
or
first_op
.
name
.
find
(
'depthwise_kernel'
)
!=
-
1
or
first_op
.
name
.
endswith
(
'weights'
)
or
first_op
.
name
.
endswith
(
'kernel'
))
\
and
first_op
.
outputs
[
0
].
consumers
()[
0
].
type
.
find
(
'Conv'
)
!=
-
1
:
tf_tensor
=
np
.
transpose
(
tf_tensor
,
axes
=
(
3
,
2
,
0
,
1
))
shape
=
[
shape
[
3
],
shape
[
2
],
shape
[
0
],
shape
[
1
]]
# print (tensor.name, shape)
tensor
.
dims
.
extend
(
shape
)
tensor
.
float_data
.
extend
(
tf_tensor
.
astype
(
float
).
flat
)
# net_def.tensors.extend([tensor])
elif
first_op
.
type
==
'Conv2D'
or
first_op
.
type
==
'DepthwiseConv2dNative'
:
op_def
=
net_def
.
op
.
add
()
op_def
.
name
=
first_op
.
name
...
...
@@ -43,10 +53,12 @@ def convert_ops(unresolved_ops, net_def):
padding_arg
.
i
=
padding_mode
[
first_op
.
get_attr
(
'padding'
)]
strides_arg
=
op_def
.
arg
.
add
()
strides_arg
.
name
=
'strides'
strides_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'strides'
))
strides_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'strides'
)
[
2
:]
)
data_format_arg
=
op_def
.
arg
.
add
()
data_format_arg
.
name
=
'data_format'
data_format_arg
.
s
=
first_op
.
get_attr
(
'data_format'
)
if
first_op
.
get_attr
(
'data_format'
)
!=
'NCHW'
:
raise
Exception
(
'only support NCHW now'
)
if
ops_count
>=
2
and
unresolved_ops
[
1
].
type
==
'BiasAdd'
:
bias_add_op
=
unresolved_ops
[
1
]
...
...
@@ -93,7 +105,7 @@ def convert_ops(unresolved_ops, net_def):
op_def
.
type
=
first_op
.
type
op_def
.
input
.
extend
([
input
.
name
for
input
in
first_op
.
inputs
])
op_def
.
output
.
extend
([
output
.
name
for
output
in
first_op
.
outputs
])
elif
first_op
.
type
==
'AvgPool'
:
elif
first_op
.
type
==
'AvgPool'
or
first_op
.
type
==
'MaxPool'
:
op_def
=
net_def
.
op
.
add
()
op_def
.
name
=
first_op
.
name
op_def
.
type
=
'Pooling'
...
...
@@ -107,12 +119,15 @@ def convert_ops(unresolved_ops, net_def):
padding_arg
.
i
=
padding_mode
[
first_op
.
get_attr
(
'padding'
)]
strides_arg
=
op_def
.
arg
.
add
()
strides_arg
.
name
=
'strides'
strides_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'strides'
)[
1
:
-
1
])
strides_arg
.
name
=
'kernels'
strides_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'ksize'
)[
1
:
-
1
])
strides_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'strides'
)[
2
:])
kernels_arg
=
op_def
.
arg
.
add
()
kernels_arg
.
name
=
'kernels'
kernels_arg
.
ints
.
extend
(
first_op
.
get_attr
(
'ksize'
)[
2
:])
data_format_arg
=
op_def
.
arg
.
add
()
data_format_arg
.
name
=
'data_format'
data_format_arg
.
s
=
first_op
.
get_attr
(
'data_format'
)
if
first_op
.
get_attr
(
'data_format'
)
!=
'NCHW'
:
raise
Exception
(
'only support NCHW now'
)
else
:
raise
Exception
(
'Unknown Op: '
+
first_op
.
name
)
pass
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录