Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
04ecb90e
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
04ecb90e
编写于
12月 04, 2019
作者:
L
luxuhui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support BatchMatMulV2 & Select ops for tensorflow
N/A Signed-off-by:
N
Luxuhui
<
luxuhui@xiaomi.com
>
上级
0f37ee96
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
384 addition
and
9 deletion
+384
-9
mace/core/operator.cc
mace/core/operator.cc
+1
-1
mace/ops/registry/ops_registry.cc
mace/ops/registry/ops_registry.cc
+2
-0
mace/ops/select.cc
mace/ops/select.cc
+213
-0
test/ccunit/mace/ops/select_test.cc
test/ccunit/mace/ops/select_test.cc
+151
-0
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+4
-3
tools/python/transform/tensorflow_converter.py
tools/python/transform/tensorflow_converter.py
+9
-1
tools/python/transform/transformer.py
tools/python/transform/transformer.py
+4
-4
未找到文件。
mace/core/operator.cc
浏览文件 @
04ecb90e
...
...
@@ -289,7 +289,7 @@ void OpRegistryBase::GetInOutMemoryTypes(
const
std
::
string
&
op_type
,
OpConditionContext
*
context
)
const
{
MACE_CHECK
(
registry_
.
count
(
op_type
)
!=
0
,
op_type
,
" operation is not registered.
"
);
op_type
,
" operation is not registered.
op_type="
,
op_type
);
return
registry_
.
at
(
op_type
)
->
memory_type_setter
(
context
);
}
...
...
mace/ops/registry/ops_registry.cc
浏览文件 @
04ecb90e
...
...
@@ -64,6 +64,7 @@ extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
extern
void
RegisterResizeNearestNeighbor
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterReverse
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterScalarMath
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterSelect
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterShape
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterSlice
(
OpRegistryBase
*
op_registry
);
extern
void
RegisterSoftmax
(
OpRegistryBase
*
op_registry
);
...
...
@@ -143,6 +144,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops
::
RegisterResizeNearestNeighbor
(
this
);
ops
::
RegisterReverse
(
this
);
ops
::
RegisterScalarMath
(
this
);
ops
::
RegisterSelect
(
this
);
ops
::
RegisterShape
(
this
);
ops
::
RegisterSlice
(
this
);
ops
::
RegisterSoftmax
(
this
);
...
...
mace/ops/select.cc
0 → 100644
浏览文件 @
04ecb90e
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
namespace
mace
{
namespace
ops
{
template
<
DeviceType
D
,
typename
T
>
class
SelectOp
;
template
<
>
class
SelectOp
<
DeviceType
::
CPU
,
float
>
:
public
Operation
{
public:
explicit
SelectOp
(
OpConstructContext
*
context
)
:
Operation
(
context
)
{}
MaceStatus
Run
(
OpContext
*
context
)
override
{
if
(
this
->
InputSize
()
==
1
)
{
return
RunWithNoData
(
context
);
}
else
{
return
RunWithData
(
context
);
}
}
MaceStatus
RunWithNoData
(
OpContext
*
context
)
{
const
Tensor
*
condition
=
this
->
Input
(
CONDITION
);
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
const
index_t
condition_rank
=
condition
->
dim_size
();
MACE_RETURN_IF_ERROR
(
output
->
Resize
({
condition
->
size
(),
condition_rank
}));
float
*
output_data
=
output
->
mutable_data
<
float
>
();
const
bool
*
condition_data
=
condition
->
data
<
bool
>
();
index_t
i
=
0
;
if
(
condition_rank
==
1
)
{
const
index_t
channel
=
condition
->
dim
(
0
);
for
(
index_t
c
=
0
;
c
<
channel
;
++
c
)
{
if
(
condition_data
[
c
])
{
output_data
[
i
++
]
=
c
;
}
}
}
else
if
(
condition_rank
==
2
)
{
const
index_t
width
=
condition
->
dim
(
0
);
const
index_t
channel
=
condition
->
dim
(
1
);
for
(
index_t
w
=
0
;
w
<
width
;
++
w
)
{
index_t
w_base
=
w
*
channel
;
for
(
index_t
c
=
0
;
c
<
channel
;
++
c
)
{
if
(
condition_data
[
w_base
+
c
])
{
output_data
[
i
++
]
=
w
;
output_data
[
i
++
]
=
c
;
}
}
}
}
else
if
(
condition_rank
==
3
)
{
const
index_t
height
=
condition
->
dim
(
0
);
const
index_t
width
=
condition
->
dim
(
1
);
const
index_t
channel
=
condition
->
dim
(
2
);
for
(
index_t
h
=
0
;
h
<
height
;
++
h
)
{
index_t
h_base
=
h
*
width
;
for
(
index_t
w
=
0
;
w
<
width
;
++
w
)
{
index_t
w_base
=
(
w
+
h_base
)
*
channel
;
for
(
index_t
c
=
0
;
c
<
channel
;
++
c
)
{
if
(
condition_data
[
w_base
+
c
])
{
output_data
[
i
++
]
=
h
;
output_data
[
i
++
]
=
w
;
output_data
[
i
++
]
=
c
;
}
}
}
}
}
else
if
(
condition_rank
==
4
)
{
const
index_t
batch
=
condition
->
dim
(
0
);
const
index_t
height
=
condition
->
dim
(
1
);
const
index_t
width
=
condition
->
dim
(
2
);
const
index_t
channel
=
condition
->
dim
(
3
);
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
index_t
b_base
=
b
*
height
;
for
(
index_t
h
=
0
;
h
<
height
;
++
h
)
{
index_t
h_base
=
(
b_base
+
h
)
*
width
;
for
(
index_t
w
=
0
;
w
<
width
;
++
w
)
{
index_t
w_base
=
(
w
+
h_base
)
*
channel
;
for
(
index_t
c
=
0
;
c
<
channel
;
++
c
)
{
if
(
condition_data
[
w_base
+
c
])
{
output_data
[
i
++
]
=
b
;
output_data
[
i
++
]
=
h
;
output_data
[
i
++
]
=
w
;
output_data
[
i
++
]
=
c
;
}
}
}
}
}
}
else
{
const
index_t
condition_size
=
condition
->
size
();
const
index_t
condition_rank
=
condition
->
dim_size
();
auto
div_buffer
=
context
->
device
()
->
scratch_buffer
();
div_buffer
->
Rewind
();
MACE_RETURN_IF_ERROR
(
div_buffer
->
GrowSize
(
condition_rank
*
sizeof
(
index_t
)));
index_t
*
div_ptr
=
div_buffer
->
mutable_data
<
index_t
>
();
div_ptr
[
condition_rank
-
1
]
=
1
;
for
(
index_t
dim
=
condition_rank
-
1
;
dim
>
0
;
--
dim
)
{
div_ptr
[
dim
-
1
]
=
div_ptr
[
dim
]
*
condition
->
dim
(
dim
);
}
for
(
index_t
c
=
0
;
c
<
condition_size
;
++
c
)
{
if
(
condition_data
[
c
])
{
auto
remainder
=
c
;
for
(
index_t
dim
=
0
;
dim
<
condition_rank
;
++
dim
)
{
output_data
[
i
++
]
=
remainder
/
div_ptr
[
dim
];
remainder
=
remainder
%
div_ptr
[
dim
];
}
}
}
}
MACE_RETURN_IF_ERROR
(
output
->
Resize
({
i
/
condition_rank
,
condition_rank
}));
return
MaceStatus
::
MACE_SUCCESS
;
}
bool
CheckDataValid
(
const
Tensor
*
condition
,
const
Tensor
*
x
,
const
Tensor
*
y
)
{
const
index_t
x_rank
=
x
->
dim_size
();
const
index_t
y_rank
=
y
->
dim_size
();
const
index_t
condition_rank
=
condition
->
dim_size
();
MACE_CHECK
(
condition_rank
<=
x_rank
&&
x_rank
==
y_rank
);
for
(
index_t
i
=
0
;
i
<
condition_rank
;
++
i
)
{
MACE_CHECK
(
condition
->
dim
(
i
)
==
x
->
dim
(
i
),
"dimensions are not equal: "
,
MakeString
(
condition
->
shape
()),
" vs. "
,
MakeString
(
x
->
shape
()));
}
for
(
index_t
i
=
0
;
i
<
x_rank
;
++
i
)
{
MACE_CHECK
(
y
->
dim
(
i
)
==
x
->
dim
(
i
),
"dimensions are not equal: "
,
MakeString
(
y
->
shape
()),
" vs. "
,
MakeString
(
x
->
shape
()));
}
return
true
;
}
MaceStatus
RunWithData
(
OpContext
*
context
)
{
const
Tensor
*
condition
=
this
->
Input
(
CONDITION
);
const
Tensor
*
x
=
this
->
Input
(
X
);
const
Tensor
*
y
=
this
->
Input
(
Y
);
MACE_ASSERT
(
CheckDataValid
(
condition
,
x
,
y
));
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
MACE_RETURN_IF_ERROR
(
output
->
Resize
(
x
->
shape
()));
float
*
output_data
=
output
->
mutable_data
<
float
>
();
const
bool
*
condition_data
=
condition
->
data
<
bool
>
();
const
float
*
x_data
=
x
->
data
<
float
>
();
const
float
*
y_data
=
y
->
data
<
float
>
();
const
index_t
condition_size
=
condition
->
size
();
const
index_t
x_size
=
x
->
size
();
utils
::
ThreadPool
&
thread_pool
=
context
->
device
()
->
cpu_runtime
()
->
thread_pool
();
if
(
condition_size
==
x_size
)
{
thread_pool
.
Compute1D
([
=
](
index_t
start
,
index_t
end
,
index_t
step
)
{
for
(
index_t
k
=
start
;
k
<
end
;
k
+=
step
)
{
// LOG(INFO) << "condition_data[" << k << "] = " << condition_data[k];
output_data
[
k
]
=
condition_data
[
k
]
?
x_data
[
k
]
:
y_data
[
k
];
}
},
0
,
x_size
,
1
);
}
else
if
(
x_size
>
condition_size
)
{
// broadcast
const
auto
block_size
=
x_size
/
condition_size
;
MACE_ASSERT
(
block_size
>
1
&&
x_size
%
condition_size
==
0
,
"x_size should be a multiple of condition_size and greater than 1"
);
const
auto
raw_block_size
=
block_size
*
sizeof
(
float
);
thread_pool
.
Compute1D
([
=
](
index_t
start
,
index_t
end
,
index_t
step
)
{
for
(
index_t
k
=
start
;
k
<
end
;
k
+=
step
)
{
auto
offset
=
block_size
*
k
;
if
(
condition_data
[
k
])
{
memcpy
(
output_data
+
offset
,
x_data
+
offset
,
raw_block_size
);
}
else
{
memcpy
(
output_data
+
offset
,
y_data
+
offset
,
raw_block_size
);
}
}
},
0
,
condition_size
,
1
);
}
else
{
MACE_CHECK
(
false
,
"x_size should be bigger than condition_size"
);
}
return
MaceStatus
::
MACE_SUCCESS
;
}
private:
MACE_OP_INPUT_TAGS
(
CONDITION
,
X
,
Y
);
MACE_OP_OUTPUT_TAGS
(
OUTPUT
);
};
void
RegisterSelect
(
OpRegistryBase
*
op_registry
)
{
MACE_REGISTER_OP
(
op_registry
,
"Select"
,
SelectOp
,
DeviceType
::
CPU
,
float
);
}
}
// namespace ops
}
// namespace mace
test/ccunit/mace/ops/select_test.cc
0 → 100644
浏览文件 @
04ecb90e
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace
mace
{
namespace
ops
{
namespace
test
{
class
SelectOpTest
:
public
OpsTestBase
{};
namespace
{
template
<
DeviceType
D
,
typename
T
>
void
TestSelect
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
uint8_t
>
&
input
,
const
std
::
vector
<
index_t
>
&
x_shape
,
const
std
::
vector
<
T
>
&
x
,
const
std
::
vector
<
index_t
>
&
y_shape
,
const
std
::
vector
<
T
>
&
y
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
T
>
&
output
)
{
// Construct graph
OpsTestNet
net
;
OpDefBuilder
builder
(
"Select"
,
"SelectTest"
);
builder
.
Input
(
"Input"
);
if
(
x
.
size
()
>
0
)
{
builder
.
Input
(
"X"
).
Input
(
"Y"
);
}
builder
.
Output
(
"Output"
).
Finalize
(
net
.
NewOperatorDef
());
net
.
AddInputFromArray
<
D
,
uint8_t
>
(
MakeString
(
"Input"
),
input_shape
,
input
);
if
(
x
.
size
()
>
0
)
{
net
.
AddInputFromArray
<
D
,
T
>
(
MakeString
(
"X"
),
x_shape
,
x
);
net
.
AddInputFromArray
<
D
,
T
>
(
MakeString
(
"Y"
),
y_shape
,
y
);
}
// Run
net
.
RunOp
();
net
.
AddInputFromArray
<
D
,
T
>
(
"ExpectedOutput"
,
output_shape
,
output
);
ExpectTensorNear
<
T
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
*
net
.
GetOutput
(
"Output"
));
}
}
// namespace
TEST_F
(
SelectOpTest
,
SimpleTestWithData
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
,
3
},
{
true
,
false
,
false
,
false
,
true
,
true
},
{
2
,
3
},
{
3.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
},
{
2
,
3
},
{
3.0
,
-
1.0
,
-
2.0
,
-
3.0
,
8.0
,
9.0
},
{
2
,
3
},
{
3.0
,
-
1.0
,
-
2.0
,
-
3.0
,
5.0
,
6.0
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithDataBroadcast
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
},
{
true
,
false
},
{
2
,
3
},
{
3.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
},
{
2
,
3
},
{
3.0
,
-
1.0
,
-
2.0
,
-
3.0
,
8.0
,
9.0
},
{
2
,
3
},
{
3
,
2
,
3
,
-
3
,
8
,
9
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithNoDataBroadcast1
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
},
{
true
,
false
},
{},
{},
{},
{},
{
1
,
1
},
{
0
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithNoDataBroadcast2
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
,
3
},
{
true
,
false
,
false
,
false
,
true
,
true
},
{},
{},
{},
{},
{
3
,
2
},
{
0
,
0
,
1
,
1
,
1
,
2
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithNoDataBroadcast3
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
,
2
,
3
},
{
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
},
{},
{},
{},
{},
{
6
,
3
},
{
0
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
2
,
1
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
2
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithNoDataBroadcast4
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
,
2
,
2
,
3
},
{
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
},
{},
{},
{},
{},
{
12
,
4
},
{
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
2
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
1
,
1
,
2
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
1
,
2
,
1
,
1
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
2
});
}
TEST_F
(
SelectOpTest
,
SimpleTestWithNoDataBroadcast5
)
{
TestSelect
<
DeviceType
::
CPU
,
float
>
(
{
2
,
2
,
2
,
2
,
3
},
{
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
,
true
,
false
,
false
,
false
,
true
,
true
},
{},
{},
{},
{},
{
24
,
5
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
0
,
1
,
2
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
1
,
1
,
2
,
0
,
1
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
2
,
0
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
1
,
1
,
2
,
1
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
1
,
2
,
1
,
0
,
1
,
0
,
0
,
1
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
1
,
2
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
2
,
1
,
1
,
1
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
2
});
}
}
// namespace test
}
// namespace ops
}
// namespace mace
tools/python/transform/base_converter.py
浏览文件 @
04ecb90e
...
...
@@ -136,6 +136,7 @@ MaceSupportedOps = [
'ResizeNearestNeighbor'
,
'Reverse'
,
'ScalarMath'
,
'Select'
,
'Slice'
,
'Splice'
,
'Split'
,
...
...
@@ -280,7 +281,7 @@ class MaceKeyword(object):
class
TransformerRule
(
Enum
):
REMOVE_
IDENTITY
_OP
=
1
REMOVE_
USELESS
_OP
=
1
TRANSFORM_GLOBAL_POOLING
=
2
FOLD_RESHAPE
=
3
TRANSFORM_MATMUL_TO_FC
=
4
...
...
@@ -526,9 +527,9 @@ class ConverterOption(object):
else
:
self
.
_transformer_option
=
[
# Model structure related transformation
TransformerRule
.
REMOVE_
IDENTITY
_OP
,
TransformerRule
.
REMOVE_
USELESS
_OP
,
TransformerRule
.
TRANSFORM_FAKE_QUANTIZE
,
TransformerRule
.
REMOVE_
IDENTITY
_OP
,
TransformerRule
.
REMOVE_
USELESS
_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
TransformerRule
.
TRANSFORM_LSTMCELL_ZEROSTATE
,
TransformerRule
.
TRANSFORM_BASIC_LSTMCELL
,
...
...
tools/python/transform/tensorflow_converter.py
浏览文件 @
04ecb90e
...
...
@@ -57,6 +57,7 @@ TFSupportedOps = [
'ArgMax'
,
'AvgPool'
,
'BatchMatMul'
,
'BatchMatMulV2'
,
'BatchToSpaceND'
,
'BiasAdd'
,
'Cast'
,
...
...
@@ -105,6 +106,7 @@ TFSupportedOps = [
'ResizeNearestNeighbor'
,
'ReverseV2'
,
'Rsqrt'
,
'Select'
,
'Shape'
,
'Sigmoid'
,
'Sign'
,
...
...
@@ -134,7 +136,7 @@ TFSupportedOps = [six.b(op) for op in TFSupportedOps]
TFTransformGraphOptions
=
[
'strip_unused_nodes'
,
'remove_nodes(op=Identity, op=CheckNumerics)'
,
'remove_nodes(op=Identity, op=CheckNumerics
, op=StopGradient
)'
,
'fold_constants(ignore_errors=true)'
,
'fold_batch_norms'
,
'fold_old_batch_norms'
,
...
...
@@ -211,6 +213,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType
.
ArgMax
.
name
:
self
.
convert_argmax
,
TFOpType
.
AvgPool
.
name
:
self
.
convert_pooling
,
TFOpType
.
BatchMatMul
.
name
:
self
.
convert_matmul
,
TFOpType
.
BatchMatMulV2
.
name
:
self
.
convert_matmul
,
TFOpType
.
BatchToSpaceND
.
name
:
self
.
convert_space_batch
,
TFOpType
.
BiasAdd
.
name
:
self
.
convert_biasadd
,
TFOpType
.
Cast
.
name
:
self
.
convert_cast
,
...
...
@@ -263,6 +266,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType
.
ResizeBilinear
.
name
:
self
.
convert_resize_bilinear
,
TFOpType
.
ResizeNearestNeighbor
.
name
:
self
.
convert_resize_nearest_neighbor
,
# noqa
TFOpType
.
ReverseV2
.
name
:
self
.
convert_reverse
,
TFOpType
.
Select
.
name
:
self
.
convert_select
,
TFOpType
.
Shape
.
name
:
self
.
convert_shape
,
TFOpType
.
Sigmoid
.
name
:
self
.
convert_activation
,
TFOpType
.
Sign
.
name
:
self
.
convert_elementwise
,
...
...
@@ -993,6 +997,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
Reverse
.
name
def
convert_select
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
Select
.
name
def
convert_stack
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
Stack
.
name
...
...
tools/python/transform/transformer.py
浏览文件 @
04ecb90e
...
...
@@ -48,7 +48,7 @@ class Transformer(base_converter.ConverterInterface):
self
.
_registered_transformers
=
{
TransformerRule
.
TRANSFORM_FAKE_QUANTIZE
:
self
.
transform_fake_quantize
,
TransformerRule
.
REMOVE_
IDENTITY_OP
:
self
.
remove_identity
_op
,
TransformerRule
.
REMOVE_
USELESS_OP
:
self
.
remove_useless
_op
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
:
self
.
transform_global_pooling
,
TransformerRule
.
TRANSFORM_LSTMCELL_ZEROSTATE
:
...
...
@@ -347,15 +347,15 @@ class Transformer(base_converter.ConverterInterface):
return
False
def
remove_
identity
_op
(
self
):
def
remove_
useless
_op
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
'Identity'
:
print
(
"Remove
identity
: %s(%s)"
%
(
op
.
name
,
op
.
type
))
print
(
"Remove
useless op
: %s(%s)"
%
(
op
.
name
,
op
.
type
))
self
.
safe_remove_node
(
op
,
self
.
_producer
.
get
(
op
.
input
[
0
],
None
))
return
True
if
op
.
type
==
'Reshape'
and
\
el
if
op
.
type
==
'Reshape'
and
\
op
.
output_shape
[
0
].
dims
==
\
self
.
get_tensor_shape
(
op
.
input
[
0
]):
print
(
"Remove useless reshape: %s(%s)"
%
(
op
.
name
,
op
.
type
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录