Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7d3f56c5
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7d3f56c5
编写于
8月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3688 Add deconv parser
Merge pull request !3688 from ghzl/add-deconv-parser
上级
9ba4e4d8
88ba8ee4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
124 addition
and
6 deletion
+124
-6
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
...ite/tools/converter/parser/tflite/tflite_deconv_parser.cc
+68
-0
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
...lite/tools/converter/parser/tflite/tflite_deconv_parser.h
+41
-0
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
...lite/tools/converter/parser/tflite/tflite_model_parser.cc
+12
-5
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
.../lite/tools/converter/parser/tflite/tflite_model_parser.h
+2
-1
mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
+1
-0
未找到文件。
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
0 → 100644
浏览文件 @
7d3f56c5
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
namespace
mindspore
{
namespace
lite
{
STATUS
TfliteDeConvParser
::
Parse
(
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
TensorT
>>
&
tflite_tensors
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
BufferT
>>
&
tflite_model_buffer
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
OperatorCodeT
>>
&
tflite_op_set
,
schema
::
CNodeT
*
op
,
TensorCache
*
tensor_cache
,
bool
quantized_model
)
{
MS_LOG
(
DEBUG
)
<<
"parse tflite Transpose_Conv parser"
;
std
::
unique_ptr
<
schema
::
DeConv2DT
>
attr
(
new
schema
::
DeConv2DT
());
const
auto
&
tflite_attr
=
tflite_op
->
builtin_options
.
AsTransposeConvOptions
();
if
(
tflite_attr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"get op: %s attr failed"
,
op
->
name
.
c_str
();
return
RET_NULL_PTR
;
}
attr
->
group
=
1
;
attr
->
strideW
=
tflite_attr
->
stride_w
;
attr
->
strideH
=
tflite_attr
->
stride_h
;
attr
->
dilateH
=
1
;
attr
->
dilateW
=
1
;
attr
->
padMode
=
GetPadMode
(
tflite_attr
->
padding
);
attr
->
format
=
schema
::
Format_NHWC
;
// get the conv op weight tensor
auto
weight_index
=
tflite_op
->
inputs
[
1
];
const
auto
&
weight_tensor
=
tflite_tensors
[
weight_index
];
std
::
vector
<
tflite
::
TensorT
*>
weight_tensors
{
weight_tensor
.
get
()};
if
(
RET_OK
!=
ParseWeight
(
weight_tensors
,
tflite_model_buffer
,
tensor_cache
,
schema
::
Format_KHWC
))
{
return
RET_ERROR
;
}
auto
weight_shape
=
weight_tensor
->
shape
;
attr
->
channelIn
=
weight_shape
[
KHWC_C
];
attr
->
channelOut
=
weight_shape
[
KHWC_K
];
attr
->
kernelW
=
weight_shape
[
KHWC_W
];
attr
->
kernelH
=
weight_shape
[
KHWC_H
];
if
(
op
!=
nullptr
)
{
op
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
op
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DeConv2D
;
op
->
primitive
->
value
.
value
=
attr
.
release
();
}
return
RET_OK
;
}
TfliteNodeRegister
g_tfliteDeConv2DParser
(
"DeConv2D"
,
new
TfliteDeConvParser
());
}
// namespace lite
}
// namespace mindspore
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
0 → 100644
浏览文件 @
7d3f56c5
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_DECONV_PARSER_H
#define PREDICT_TFLITE_DECONV_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace
mindspore
{
namespace
lite
{
class
TfliteDeConvParser
:
public
TfliteNodeParser
{
public:
TfliteDeConvParser
()
:
TfliteNodeParser
(
"DeConv2D"
)
{}
STATUS
Parse
(
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
TensorT
>>
&
tflite_tensors
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
BufferT
>>
&
tflite_model_buffer
,
const
std
::
vector
<
std
::
unique_ptr
<
tflite
::
OperatorCodeT
>>
&
tflite_op_set
,
schema
::
CNodeT
*
op
,
TensorCache
*
tensor_cache
,
bool
quantizedModel
)
override
;
};
}
// namespace lite
}
// namespace mindspore
#endif // PREDICT_TFLITE_DECONV_PARSER_H
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
浏览文件 @
7d3f56c5
...
...
@@ -112,10 +112,17 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
return
RET_OK
;
}
STATUS
TfliteModelParser
::
SetOpInputIdx
(
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
STATUS
TfliteModelParser
::
SetOpInputIdx
(
const
std
::
unique_ptr
<
tflite
::
ModelT
>
&
tflite_model
,
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
TensorCache
*
tensorCache
)
{
for
(
const
auto
&
tfliteIndex
:
tflite_op
->
inputs
)
{
const
auto
&
tflite_tensor
=
tflite_subgraph
->
tensors
[
tfliteIndex
];
auto
op_type
=
GetTfliteNodeType
(
tflite_op
,
tflite_model
);
std
::
vector
<
int32_t
>
op_inputs
(
tflite_op
->
inputs
);
if
(
op_type
==
"DeConv2D"
)
{
reverse
(
op_inputs
.
begin
(),
op_inputs
.
end
());
}
for
(
const
auto
&
tflite_index
:
op_inputs
)
{
const
auto
&
tflite_tensor
=
tflite_subgraph
->
tensors
[
tflite_index
];
auto
tensor_name
=
tflite_tensor
->
name
;
auto
op
=
tfliteOpMap
[
tflite_op
.
get
()];
unsigned
int
index
=
tensorCache
->
FindTensor
(
tensor_name
);
...
...
@@ -228,8 +235,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
}
for
(
const
auto
&
tflite_op
:
tflite_subgraph
->
operators
)
{
auto
status
Tmp
=
SetOpInputIdx
(
tflite_subgraph
,
tflite_op
,
&
tensorCache
);
if
(
status
T
mp
!=
RET_OK
)
{
auto
status
_tmp
=
SetOpInputIdx
(
tflite_model
,
tflite_subgraph
,
tflite_op
,
&
tensorCache
);
if
(
status
_t
mp
!=
RET_OK
)
{
// MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str());
}
}
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
浏览文件 @
7d3f56c5
...
...
@@ -73,7 +73,8 @@ class TfliteModelParser : public ModelParser {
schema
::
CNodeT
*
op
,
TensorCache
*
tensorCache
);
STATUS
SetOpInputIdx
(
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
STATUS
SetOpInputIdx
(
const
std
::
unique_ptr
<
tflite
::
ModelT
>
&
tflite_model
,
const
std
::
unique_ptr
<
tflite
::
SubGraphT
>
&
tflite_subgraph
,
const
std
::
unique_ptr
<
tflite
::
OperatorT
>
&
tflite_op
,
TensorCache
*
tensorCache
);
std
::
map
<
std
::
string
,
schema
::
CNodeT
*>
opMap
;
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
浏览文件 @
7d3f56c5
...
...
@@ -55,6 +55,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{
tflite
::
BuiltinOperator_ARG_MAX
,
"Argmax"
},
{
tflite
::
BuiltinOperator_SQUARED_DIFFERENCE
,
"SquaredDifference"
},
{
tflite
::
BuiltinOperator_FAKE_QUANT
,
"FakeQuant"
},
{
tflite
::
BuiltinOperator_TRANSPOSE_CONV
,
"DeConv2D"
},
};
std
::
string
GetMSOpType
(
tflite
::
BuiltinOperator
tfliteOpType
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录