Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a63ee29d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a63ee29d
编写于
8月 17, 2020
作者:
L
lyvette
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug when op is custom
fix deconv bug
上级
a974a354
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
170 addition
and
9 deletion
+170
-9
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/l2norm.tflite
.../ut/tools/converter/parser/tflite/test_data/l2norm.tflite
+0
-0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_l2norm_parser_test.cc
...ools/converter/parser/tflite/tflite_l2norm_parser_test.cc
+41
-0
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
...ite/tools/converter/parser/tflite/tflite_deconv_parser.cc
+0
-2
mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc
...ite/tools/converter/parser/tflite/tflite_l2norm_parser.cc
+75
-0
mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h
...lite/tools/converter/parser/tflite/tflite_l2norm_parser.h
+43
-0
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
...lite/tools/converter/parser/tflite/tflite_model_parser.cc
+10
-6
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
.../lite/tools/converter/parser/tflite/tflite_model_parser.h
+1
-1
未找到文件。
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/l2norm.tflite
0 → 100644
浏览文件 @
a63ee29d
文件已添加
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_l2norm_parser_test.cc
0 → 100644
浏览文件 @
a63ee29d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace
mindspore
{
class
TestTfliteParserL2Norm
:
public
TestTfliteParser
{
public:
TestTfliteParserL2Norm
()
=
default
;
void
SetUp
()
override
{
meta_graph
=
LoadAndConvert
(
"./l2norm.tflite"
,
""
);
}
};
TEST_F
(
TestTfliteParserL2Norm
,
OpType
)
{
ASSERT_NE
(
meta_graph
,
nullptr
);
ASSERT_GT
(
meta_graph
->
nodes
.
size
(),
0
);
ASSERT_EQ
(
meta_graph
->
nodes
.
front
()
->
primitive
->
value
.
type
,
schema
::
PrimitiveType_L2Norm
)
<<
"wrong Op Type"
;
}
TEST_F
(
TestTfliteParserL2Norm
,
AttrValue
)
{
ASSERT_NE
(
meta_graph
->
nodes
.
front
()
->
primitive
->
value
.
AsL2Norm
(),
nullptr
);
auto
val
=
meta_graph
->
nodes
.
front
()
->
primitive
->
value
.
AsL2Norm
();
ASSERT_EQ
(
val
->
epsilon
,
0.0
);
std
::
vector
<
int32_t
>
axis
=
{
0
,
1
,
2
,
3
};
ASSERT_EQ
(
val
->
axis
,
axis
);
}
}
// namespace mindspore
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
浏览文件 @
a63ee29d
...
@@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
...
@@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
tflite_op
->
inputs
[
2
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
tflite_op
->
inputs
[
2
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
AddOpInput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
AddOpInput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
tflite_op
->
inputs
[
1
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_KHWC
);
tflite_op
->
inputs
[
1
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_KHWC
);
AddOpInput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
tflite_op
->
inputs
[
0
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
AddOpOutput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
AddOpOutput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
tflite_op
->
outputs
[
0
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
tflite_op
->
outputs
[
0
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
return
RET_OK
;
return
RET_OK
;
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc
0 → 100644
浏览文件 @
a63ee29d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_l2norm_parser.h"
#include <vector>
#include <memory>
#include <map>
namespace
mindspore
{
namespace
lite
{
STATUS
TfliteL2NormParser
::
Parse
(
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
TensorT
>>
&
tflite_tensors
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
BufferT
>>
&
tflite_model_buffer
,
schema
::
CNodeT
*
op
,
std
::
vector
<
int32_t
>
*
tensors_id
,
std
::
vector
<
schema
::
Format
>
*
tensors_format
,
std
::
map
<
int
,
int
>
*
tensors_id_map
)
{
MS_LOG
(
DEBUG
)
<<
"parse TfliteL2NormParser"
;
// set attr
if
(
op
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"op is null"
;
return
RET_NULL_PTR
;
}
op
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
if
(
op
->
primitive
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"op->primitive is null"
;
return
RET_NULL_PTR
;
}
std
::
unique_ptr
<
schema
::
L2NormT
>
attr
(
new
schema
::
L2NormT
());
auto
data_index
=
tflite_op
->
inputs
[
0
];
const
auto
&
data_tensor
=
tflite_tensors
[
data_index
];
if
(
data_tensor
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"the input tensor is null"
;
return
RET_NULL_PTR
;
}
auto
ndim
=
data_tensor
->
shape
.
size
();
std
::
vector
<
int32_t
>
axis
;
axis
.
reserve
(
ndim
);
for
(
int
i
=
0
;
i
<
ndim
;
i
++
)
{
axis
.
emplace_back
(
i
);
}
attr
->
axis
=
axis
;
attr
->
epsilon
=
0.0
f
;
op
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_L2Norm
;
op
->
primitive
->
value
.
value
=
attr
.
release
();
// set input
AddOpInput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
tflite_op
->
inputs
[
0
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
AddOpOutput
(
op
,
tensors_id
,
tensors_format
,
tensors_id_map
,
tflite_op
->
outputs
[
0
],
tensors_id
->
size
(),
tflite_tensors
.
size
(),
schema
::
Format_NHWC
);
return
RET_OK
;
}
TfliteNodeRegister
g_tfliteL2NormParser
(
"L2_NORMALIZATION"
,
new
TfliteL2NormParser
());
}
// namespace lite
}
// namespace mindspore
mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h
0 → 100644
浏览文件 @
a63ee29d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace
mindspore
{
namespace
lite
{
class
TfliteL2NormParser
:
public
TfliteNodeParser
{
public:
TfliteL2NormParser
()
:
TfliteNodeParser
(
"L2_NORMALIZATION"
)
{}
STATUS
Parse
(
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
TensorT
>>
&
tflite_tensors
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
BufferT
>>
&
tflite_model_buffer
,
schema
::
CNodeT
*
op
,
std
::
vector
<
int32_t
>
*
tensors_id
,
std
::
vector
<
schema
::
Format
>
*
tensors_format
,
std
::
map
<
int
,
int
>
*
tensors_id_map
)
override
;
};
}
// namespace lite
}
// namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
浏览文件 @
a63ee29d
...
@@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
...
@@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
for
(
const
auto
&
tflite_op
:
tflite_subgraph
->
operators
)
{
for
(
const
auto
&
tflite_op
:
tflite_subgraph
->
operators
)
{
auto
tflite_op_type
=
(
tflite_model
->
operator_codes
[
tflite_op
->
opcode_index
])
->
builtin_code
;
auto
tflite_op_type
=
(
tflite_model
->
operator_codes
[
tflite_op
->
opcode_index
])
->
builtin_code
;
auto
op_type
=
GetMSOpType
(
tflite_op_type
);
auto
op_type
=
GetMSOpType
(
tflite_op_type
);
if
(
op_type
==
"CUSTOM"
)
{
auto
custom_type
=
(
tflite_model
->
operator_codes
[
tflite_op
->
opcode_index
])
->
custom_code
;
MS_LOG
(
ERROR
)
<<
"CUSTOM op is not supported, the type is "
<<
custom_type
;
return
RET_ERROR
;
}
std
::
unique_ptr
<
schema
::
CNodeT
>
op
(
new
schema
::
CNodeT
);
std
::
unique_ptr
<
schema
::
CNodeT
>
op
(
new
schema
::
CNodeT
);
op
->
name
=
op_type
+
"-"
+
std
::
to_string
(
idx
++
);
op
->
name
=
op_type
+
"-"
+
std
::
to_string
(
idx
++
);
...
@@ -216,7 +221,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
...
@@ -216,7 +221,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
return
RET_OK
;
return
RET_OK
;
}
}
STATUS
TfliteModelParser
::
UpdateOp
(
schema
::
MetaGraphT
*
sub_graph
)
{
STATUS
TfliteModelParser
::
ConvertGroupDepthwiseOp
(
schema
::
MetaGraphT
*
sub_graph
)
{
for
(
auto
&
op
:
sub_graph
->
nodes
)
{
for
(
auto
&
op
:
sub_graph
->
nodes
)
{
if
(
op
->
primitive
->
value
.
type
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
if
(
op
->
primitive
->
value
.
type
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
auto
attr
=
op
->
primitive
->
value
.
AsDepthwiseConv2D
();
auto
attr
=
op
->
primitive
->
value
.
AsDepthwiseConv2D
();
...
@@ -248,7 +253,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
...
@@ -248,7 +253,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
auto
weight_id
=
op
->
inputIndex
[
1
];
auto
weight_id
=
op
->
inputIndex
[
1
];
auto
&
weight_tensor
=
sub_graph
->
allTensors
.
at
(
weight_id
);
auto
&
weight_tensor
=
sub_graph
->
allTensors
.
at
(
weight_id
);
if
(
weight_tensor
->
dataType
==
TypeId
::
kNumberTypeUInt8
)
{
if
(
weight_tensor
->
dataType
==
TypeId
::
kNumberTypeUInt8
)
{
// convert weight format KHWC -> CHWK
auto
status
=
TransFilterFormat
<
uint8_t
>
(
weight_tensor
.
get
(),
kKHWC2CHWK
);
auto
status
=
TransFilterFormat
<
uint8_t
>
(
weight_tensor
.
get
(),
kKHWC2CHWK
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Trans depthwiseConv Filter Format failed."
;
MS_LOG
(
ERROR
)
<<
"Trans depthwiseConv Filter Format failed."
;
...
@@ -256,13 +260,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
...
@@ -256,13 +260,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
}
}
}
}
if
(
weight_tensor
->
dataType
==
kNumberTypeFloat32
||
weight_tensor
->
dataType
==
kNumberTypeFloat
)
{
if
(
weight_tensor
->
dataType
==
kNumberTypeFloat32
||
weight_tensor
->
dataType
==
kNumberTypeFloat
)
{
// convert weight format KHWC -> CHWK
auto
status
=
TransFilterFormat
<
float
>
(
weight_tensor
.
get
(),
kKHWC2CHWK
);
auto
status
=
TransFilterFormat
<
float
>
(
weight_tensor
.
get
(),
kKHWC2CHWK
);
if
(
status
!=
RET_OK
)
{
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Trans
depthwiseConv Filter F
ormat failed."
;
MS_LOG
(
ERROR
)
<<
"Trans
filter f
ormat failed."
;
return
RET_ERROR
;
return
RET_ERROR
;
}
}
}
}
weight_tensor
->
format
=
schema
::
Format_CHWK
;
}
}
}
}
}
}
...
@@ -303,8 +307,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s
...
@@ -303,8 +307,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s
}
}
// update for depthwiseConv
// update for depthwiseConv
if
(
Updat
eOp
(
sub_graph
.
get
())
!=
RET_OK
)
{
if
(
ConvertGroupDepthwis
eOp
(
sub_graph
.
get
())
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"
update
depthwise conv failed"
;
MS_LOG
(
ERROR
)
<<
"
convert group
depthwise conv failed"
;
return
nullptr
;
return
nullptr
;
}
}
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
浏览文件 @
a63ee29d
...
@@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser {
...
@@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser {
STATUS
GetGraphInfo
(
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
STATUS
GetGraphInfo
(
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
schema
::
MetaGraphT
*
sub_graph
);
schema
::
MetaGraphT
*
sub_graph
);
STATUS
Updat
eOp
(
schema
::
MetaGraphT
*
sub_graph
);
STATUS
ConvertGroupDepthwis
eOp
(
schema
::
MetaGraphT
*
sub_graph
);
private:
private:
std
::
vector
<
int32_t
>
tensorsId
;
std
::
vector
<
int32_t
>
tensorsId
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录