Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d9787144
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d9787144
编写于
8月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4713 fix anf exporter
Merge pull request !4713 from wangchangkai/master
上级
465390e5
5ca7be57
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
5 deletion
+37
-5
mindspore/lite/tools/anf_exporter/anf_exporter.cc
mindspore/lite/tools/anf_exporter/anf_exporter.cc
+2
-1
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc
...te/tools/anf_importer/anf_populater/anf_conv_populater.cc
+24
-2
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h
...ite/tools/anf_importer/anf_populater/anf_conv_populater.h
+1
-1
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
+6
-1
mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
...ter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
+4
-0
未找到文件。
mindspore/lite/tools/anf_exporter/anf_exporter.cc
浏览文件 @
d9787144
...
...
@@ -387,7 +387,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
meta_graphT
->
allTensors
.
emplace_back
(
msTensor
);
if
(
IsPrimitiveCNode
(
cnode
,
schema
::
PrimitiveType_Conv2D
)
||
IsPrimitiveCNode
(
cnode
,
schema
::
PrimitiveType_DepthwiseConv2D
))
{
||
IsPrimitiveCNode
(
cnode
,
schema
::
PrimitiveType_DepthwiseConv2D
)
||
IsPrimitiveCNode
(
cnode
,
schema
::
PrimitiveType_FusedBatchNorm
))
{
break
;
}
}
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc
浏览文件 @
d9787144
...
...
@@ -29,7 +29,7 @@
namespace
mindspore
::
lite
{
void
AnfConvPopulater
::
PopulaterConv2DMultiGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
)
{
const
int
&
group
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
auto
attr
=
std
::
make_unique
<
schema
::
DepthwiseConv2DT
>
();
auto
format
=
GetValue
<
std
::
string
>
(
prim
->
GetAttr
(
"data_format"
));
if
(
format
==
"NCHW"
)
{
...
...
@@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
attr
->
padMode
=
schema
::
PadMode_NOTSET
;
}
int
channel_mutiplier
=
1
;
if
(
prim
->
GetAttr
(
"channel_mutiplier"
)
!=
nullptr
)
{
channel_mutiplier
=
GetValue
<
int
>
(
prim
->
GetAttr
(
"channel_multiplier"
));
}
attr
->
channelMultiplier
=
channel_mutiplier
;
MS_ASSERT
(
inputs
.
size
()
==
kAnfPopulaterTwo
);
auto
inputNode
=
inputs
[
kAnfPopulaterOne
];
MS_ASSERT
(
inputNode
!=
nullptr
);
if
(
inputNode
->
isa
<
Parameter
>
())
{
auto
paramNode
=
inputNode
->
cast
<
ParameterPtr
>
();
auto
abstractBase
=
paramNode
->
abstract
();
MS_ASSERT
(
abstractBase
!=
nullptr
);
if
(
utils
::
isa
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
))
{
auto
abstractTensor
=
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
);
MS_ASSERT
(
abstractTensor
!=
nullptr
);
if
(
abstractTensor
->
format
()
==
schema
::
Format_NCHW
)
{
abstractTensor
->
set_format
(
schema
::
Format_KCHW
);
}
}
}
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
primitive
->
value
.
value
=
attr
.
release
();
}
...
...
@@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
int
group
=
GetValue
<
int
>
(
prim
->
GetAttr
(
"group"
));
if
(
group
>
1
)
{
PopulaterConv2DMultiGroup
(
prim
,
primitive
,
group
);
PopulaterConv2DMultiGroup
(
prim
,
primitive
,
group
,
inputs
);
}
else
{
PopulaterConv2DSingleGroup
(
prim
,
primitive
,
group
);
}
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h
浏览文件 @
d9787144
...
...
@@ -35,7 +35,7 @@ class AnfConvPopulater : public AnfNodePopulater {
private:
void
PopulaterConv2DMultiGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
);
void
PopulaterConv2DSingleGroup
(
const
PrimitivePtr
&
prim
,
const
std
::
unique_ptr
<
schema
::
PrimitiveT
>
&
primitive
,
const
int
&
group
);
...
...
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
浏览文件 @
d9787144
...
...
@@ -1129,7 +1129,12 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
output_shape
);
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimReturn
));
auto
primReturn
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
MS_ASSERT
(
primReturn
!=
nullptr
);
primReturn
->
value
.
type
=
schema
::
PrimitiveType_Return
;
std
::
shared_ptr
<
PrimitiveTValue
>
primitiveTReturnValuePtr
=
std
::
make_shared
<
PrimitiveTValue
>
(
primReturn
.
release
());
MS_ASSERT
(
primitiveTReturnValuePtr
!=
nullptr
);
inputs
.
push_back
(
NewValueNode
(
primitiveTReturnValuePtr
));
inputs
.
push_back
(
cnode_ptr
);
auto
return_node
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
return_node
);
...
...
mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
浏览文件 @
d9787144
...
...
@@ -18,6 +18,7 @@
#include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/node_util.h"
namespace
mindspore
{
namespace
lite
{
...
...
@@ -166,6 +167,9 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
if
(
opType
==
PrimitiveType_Conv2D
)
{
weightTensor
->
format
=
Format_KCHW
;
}
else
if
(
opType
==
PrimitiveType_DepthwiseConv2D
)
{
if
(
weightTensor
->
format
==
Format_KCHW
)
{
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2CKHW
);
}
weightTensor
->
format
=
Format_CKHW
;
}
else
{
MS_LOG
(
ERROR
)
<<
"Unsupported opType: "
<<
EnumNamePrimitiveType
(
opType
)
<<
", node: "
<<
node
->
name
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录