Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9eb0a769
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9eb0a769
编写于
8月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3713 detect model fix bug
Merge pull request !3713 from ghzl/deconv-adapter
上级
d7bc28dc
668db1dd
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
29 addition
and
10 deletion
+29
-10
mindspore/lite/src/model_impl.cc
mindspore/lite/src/model_impl.cc
+2
-0
mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
...ls/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
+9
-2
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
...lite/tools/converter/optimizer/node/weight_format_pass.cc
+9
-7
mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc
...e/lite/tools/converter/parser/tflite/tflite_add_parser.cc
+9
-1
未找到文件。
mindspore/lite/src/model_impl.cc
浏览文件 @
9eb0a769
...
...
@@ -80,6 +80,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return
new
lite
::
Activation
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_Conv2D
:
return
new
lite
::
Conv2D
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_DeConv2D
:
return
new
lite
::
DeConv2D
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_Reduce
:
return
new
lite
::
Reduce
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_Pooling
:
...
...
mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
浏览文件 @
9eb0a769
...
...
@@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
}
auto
baNodeBiasTensor
=
graph
->
allTensors
.
at
(
baNodeInputIndex
[
BIASADD_OP_CONST_TENSOR_INDEX
]).
get
();
MS_ASSERT
(
baNodeBiasTensor
!=
nullptr
);
if
(
baNodeBiasTensor
->
refCount
!=
schema
::
NodeType_ValueNode
)
{
if
(
baNodeBiasTensor
->
nodeType
!=
schema
::
NodeType_ValueNode
)
{
// dont fusion, return
return
RET_OK
;
}
...
...
@@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
<<
". or bias tensor is a scaler"
;
return
RET_ERROR
;
}
if
(
!
biasDims
.
empty
()
&&
biasDims
.
at
(
BIASADD_BIAS_DIM_INDEX
)
!=
kernelNum
)
{
bool
bias_const
=
!
biasDims
.
empty
()
&&
biasDims
.
size
()
==
1
&&
biasDims
[
0
]
==
1
;
if
(
!
biasDims
.
empty
()
&&
!
bias_const
&&
biasDims
.
at
(
BIASADD_BIAS_DIM_INDEX
)
!=
kernelNum
)
{
MS_LOG
(
ERROR
)
<<
"Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
<<
biasDims
.
at
(
BIASADD_BIAS_DIM_INDEX
)
<<
baNode
->
name
.
c_str
()
<<
kernelNum
;
return
RET_ERROR
;
...
...
@@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
MS_LOG
(
ERROR
)
<<
"memset_s newBiasData failed"
;
return
RET_ERROR
;
}
}
else
if
(
bias_const
)
{
auto
*
biasData
=
reinterpret_cast
<
float
*>
(
biasTensor
->
data
.
data
());
for
(
size_t
i
=
0
;
i
<
kernelNum
;
i
++
)
{
newBiasData
[
i
]
=
*
biasData
;
}
}
else
{
if
(
0
!=
memcpy_s
(
newBiasData
,
kernelNum
*
sizeof
(
float
),
biasTensor
->
data
.
data
(),
kernelNum
*
sizeof
(
float
)))
{
MS_LOG
(
ERROR
)
<<
"memcpy_s newBiasData failed"
;
...
...
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
浏览文件 @
9eb0a769
...
...
@@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
weightTensor
->
format
=
schema
::
Format_KHWC
;
}
else
if
(
opType
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
weightTensor
->
format
=
schema
::
Format_CHWK
;
}
else
if
(
opType
==
schema
::
PrimitiveType_DeConv2D
)
{
weightTensor
->
format
=
schema
::
Format_KHWC
;
}
else
{
MS_LOG
(
ERROR
)
<<
"unsupport format"
;
return
-
1
;
...
...
@@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
MS_LOG
(
WARNING
)
<<
"TransFilter HWCKToCKHW failed, node : "
<<
node
->
name
.
c_str
();
// todo(00445839): consider varible weight condition
}
}
else
if
(
opType
==
schema
::
PrimitiveType_DeConv2D
)
{
// weight should be K
CHW
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe or onnx
return
0
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_
HWK
C
)
{
// from tf
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kHWKC2KCHW
)
;
}
else
if
(
opType
==
schema
::
PrimitiveType_DeConv2D
)
{
// weight should be K
HWC
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe or onnx or ms
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2KHWC
)
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_
KHW
C
)
{
// from tf
status
=
RET_OK
;
}
else
{
MS_LOG
(
ERROR
)
<<
"Unsupported weightTensor format: "
<<
weightTensor
->
format
;
return
-
1
;
}
if
(
status
==
0
)
{
node
->
primitive
->
value
.
AsDe
pthwise
Conv2D
()
->
format
=
schema
::
Format_NCHW
;
weightTensor
->
format
=
schema
::
Format_K
CHW
;
node
->
primitive
->
value
.
AsDeConv2D
()
->
format
=
schema
::
Format_NCHW
;
weightTensor
->
format
=
schema
::
Format_K
HWC
;
}
else
{
MS_LOG
(
WARNING
)
<<
"TransFilter HWKCToKCHW failed, node : "
<<
node
->
name
.
c_str
();
// todo(00445839): consider varible weight condition
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc
浏览文件 @
9eb0a769
...
...
@@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
schema
::
CNodeT
*
op
,
TensorCache
*
tensor_cache
,
bool
quantizedModel
)
{
// MS_LOGD("parse TfliteAddParser")
;
MS_LOG
(
DEBUG
)
<<
"parse TfliteAddParser"
;
std
::
unique_ptr
<
schema
::
AddT
>
attr
(
new
schema
::
AddT
());
auto
weight_index
=
tfliteOp
->
inputs
[
1
];
const
auto
&
weight_tensor
=
tfliteTensors
[
weight_index
];
std
::
vector
<
tflite
::
TensorT
*>
weight_tensors
{
weight_tensor
.
get
()};
if
(
RET_OK
!=
ParseWeight
(
weight_tensors
,
tfliteModelBuffer
,
tensor_cache
,
schema
::
Format_KHWC
))
{
return
RET_ERROR
;
}
if
(
op
!=
nullptr
)
{
op
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
op
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Add
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录