Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
bed056aa
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bed056aa
编写于
8月 10, 2020
作者:
Y
yeyunpeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix DeDepthwiseConv2D problem
上级
0ae5eeb3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
13 addition
and
12 deletion
+13
-12
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
...ols/converter/legacy_optimizer/node/weight_format_pass.cc
+12
-10
mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
...ools/converter/parser/caffe/caffe_deconvolution_parser.cc
+1
-2
未找到文件。
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
浏览文件 @
bed056aa
...
...
@@ -79,7 +79,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
switch
(
node
->
quantType
)
{
case
QuantType_QUANT_NONE
:
{
if
(
opType
==
schema
::
PrimitiveType_Conv2D
||
opType
==
schema
::
PrimitiveType_DepthwiseConv2D
||
opType
==
schema
::
PrimitiveType_DeConv2D
)
{
opType
==
schema
::
PrimitiveType_DeConv2D
||
opType
==
schema
::
PrimitiveType_DeDepthwiseConv2D
)
{
weightTensor
->
format
=
schema
::
Format_KCHW
;
}
else
{
MS_LOG
(
ERROR
)
<<
"Invalid opType: "
<<
schema
::
EnumNamePrimitiveType
(
opType
)
...
...
@@ -240,11 +240,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
return
RET_OK
;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWCK
)
{
// from tf
return
0
;
}
else
{
...
...
@@ -275,7 +275,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWCK
)
{
// from tf
return
0
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_CHWK
)
{
// from onnx
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
int8_t
>
(
weightTensor
.
get
(),
kCHWK2KHWC
);
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
...
...
@@ -383,9 +383,11 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
MS_LOG
(
WARNING
)
<<
"TransFilter HWKCToKCHW failed, node : "
<<
node
->
name
.
c_str
();
// todo(00445839): consider varible weight condition
}
}
else
if
(
opType
==
schema
::
PrimitiveType_DeDepthwiseConv2D
)
{
// weight should be
CKHW
if
(
weightTensor
->
format
==
schema
::
Format_
CKHW
)
{
// from caffe
}
else
if
(
opType
==
schema
::
PrimitiveType_DeDepthwiseConv2D
)
{
// weight should be
KHWC
if
(
weightTensor
->
format
==
schema
::
Format_
KHWC
)
{
return
0
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2KHWC
);
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWKC
)
{
// from tf or onnx
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kHWKC2CKHW
);
}
else
{
...
...
@@ -393,7 +395,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
return
-
1
;
}
if
(
status
==
0
)
{
node
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
format
=
schema
::
Format_NHWC
;
node
->
primitive
->
value
.
AsDe
De
pthwiseConv2D
()
->
format
=
schema
::
Format_NHWC
;
weightTensor
->
format
=
schema
::
Format_CKHW
;
}
else
{
MS_LOG
(
WARNING
)
<<
"TransFilter HWKCToCKHW failed, node : "
<<
node
->
name
.
c_str
();
...
...
mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
浏览文件 @
bed056aa
...
...
@@ -46,14 +46,13 @@ void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schem
deDepthwiseConv2DParam
->
hasBias
=
attr
->
hasBias
;
deDepthwiseConv2DParam
->
activationType
=
attr
->
activationType
;
delete
attr
;
op
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
op
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DeDepthwiseConv2D
;
op
->
primitive
->
value
.
value
=
deDepthwiseConv2DParam
.
release
();
}
STATUS
CaffeDeconvolutionParser
::
Parse
(
const
caffe
::
LayerParameter
&
proto
,
const
caffe
::
LayerParameter
&
weight
,
schema
::
CNodeT
*
op
,
std
::
vector
<
schema
::
TensorT
*>
*
weightVec
)
{
op
->
name
=
proto
.
name
();
schema
::
DeConv2DT
*
attr
=
new
schema
::
DeConv2DT
();
auto
*
attr
=
new
schema
::
DeConv2DT
();
attr
->
format
=
schema
::
Format_NCHW
;
const
caffe
::
ConvolutionParameter
convParam
=
proto
.
convolution_param
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录