Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6e57281c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e57281c
编写于
8月 06, 2020
作者:
Z
zhengjun10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable converter anf fusion pass and optimize code
上级
fe7141e9
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
1114 addition
and
1026 deletion
+1114
-1026
mindspore/lite/src/gllo/common/gllo_utils.cc
mindspore/lite/src/gllo/common/gllo_utils.cc
+130
-68
mindspore/lite/src/gllo/common/gllo_utils.h
mindspore/lite/src/gllo/common/gllo_utils.h
+9
-6
mindspore/lite/src/gllo/common/node_pass.cc
mindspore/lite/src/gllo/common/node_pass.cc
+5
-1
mindspore/lite/src/gllo/common/node_pass.h
mindspore/lite/src/gllo/common/node_pass.h
+0
-36
mindspore/lite/src/gllo/common/optimizer.cc
mindspore/lite/src/gllo/common/optimizer.cc
+1
-2
mindspore/lite/src/gllo/common/optimizer.h
mindspore/lite/src/gllo/common/optimizer.h
+3
-3
mindspore/lite/src/gllo/common/pass_manager.cc
mindspore/lite/src/gllo/common/pass_manager.cc
+1
-1
mindspore/lite/src/gllo/common/pass_manager.h
mindspore/lite/src/gllo/common/pass_manager.h
+0
-61
mindspore/lite/src/gllo/common/pattern_engine.cc
mindspore/lite/src/gllo/common/pattern_engine.cc
+0
-365
mindspore/lite/src/gllo/common/pattern_engine.h
mindspore/lite/src/gllo/common/pattern_engine.h
+0
-203
mindspore/lite/src/gllo/common/visit.cc
mindspore/lite/src/gllo/common/visit.cc
+0
-165
mindspore/lite/src/gllo/common/visit.h
mindspore/lite/src/gllo/common/visit.h
+0
-59
mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
+5
-5
mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
+1
-1
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
+7
-7
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
+1
-1
mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
+6
-6
mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
+1
-1
mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
+6
-6
mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
+1
-1
mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
+16
-6
mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
+1
-1
mindspore/lite/test/CMakeLists.txt
mindspore/lite/test/CMakeLists.txt
+7
-3
mindspore/lite/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
...te/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
+184
-0
mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
.../lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
+194
-0
mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
...spore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
+296
-0
mindspore/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
...re/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
+221
-0
mindspore/lite/tools/converter/CMakeLists.txt
mindspore/lite/tools/converter/CMakeLists.txt
+3
-3
mindspore/lite/tools/converter/converter.cc
mindspore/lite/tools/converter/converter.cc
+1
-1
mindspore/lite/tools/converter/graphdef_transform.cc
mindspore/lite/tools/converter/graphdef_transform.cc
+14
-14
未找到文件。
mindspore/lite/src/gllo/common/utils.cc
→
mindspore/lite/src/gllo/common/
gllo_
utils.cc
浏览文件 @
6e57281c
...
@@ -13,83 +13,46 @@
...
@@ -13,83 +13,46 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "src/gllo/common/gllo_utils.h"
#include <vector>
#include <vector>
#include <memory>
#include "src/gllo/common/utils.h"
#include "src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/ops.h"
using
PrimitiveTValuePtr
=
std
::
shared_ptr
<
mindspore
::
lite
::
PrimitiveTValue
>
;
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
bool
AnfEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
)
{
namespace
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
a
)
&&
utils
::
isa
<
AnfNodePtr
>
(
b
))
{
constexpr
auto
kAnfPrimitiveIndex
=
0
;
auto
a_node
=
utils
::
cast
<
AnfNodePtr
>
(
a
);
bool
CheckPrimitiveType
(
const
AnfNodePtr
&
node
,
const
PrimitivePtr
&
primitive_type
)
{
auto
b_node
=
utils
::
cast
<
AnfNodePtr
>
(
b
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
a_node
);
if
(
!
node
->
isa
<
CNode
>
())
{
MS_EXCEPTION_IF_NULL
(
b_node
);
return
false
;
if
(
IsValueNode
<
Primitive
>
(
a_node
)
&&
IsValueNode
<
Primitive
>
(
b_node
))
{
auto
a_value_node
=
a_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
a_value_node
);
auto
a_value
=
a_value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
a_value
);
auto
a_prim
=
a_value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
a_prim
);
auto
b_value_node
=
b_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
b_value_node
);
auto
b_value
=
b_value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
b_value
);
auto
b_prim
=
b_value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
b_prim
);
return
a_prim
->
name
()
==
b_prim
->
name
();
}
else
if
(
a_node
->
isa
<
ValueNode
>
()
&&
b_node
->
isa
<
ValueNode
>
())
{
auto
a_value_node_ptr
=
a_node
->
cast
<
ValueNodePtr
>
();
if
(
a_value_node_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"cast value node ptr fail"
;
}
auto
a_value_ptr
=
a_value_node_ptr
->
value
();
if
(
a_value_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"value ptr is nullptr"
;
}
auto
b_value_node_ptr
=
b_node
->
cast
<
ValueNodePtr
>
();
if
(
b_value_node_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"cast value node ptr fail"
;
}
auto
b_value_ptr
=
b_value_node_ptr
->
value
();
if
(
b_value_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"value ptr is nullptr"
;
}
if
(
utils
::
isa
<
lite
::
PrimitiveTValue
>
(
a_value_ptr
)
&&
utils
::
isa
<
lite
::
PrimitiveTValue
>
(
b_value_ptr
))
{
auto
a_obj
=
(
lite
::
PrimitiveTValue
*
)(
a_value_ptr
.
get
());
auto
b_obj
=
(
lite
::
PrimitiveTValue
*
)(
b_value_ptr
.
get
());
return
(
*
a_obj
)
==
(
*
b_obj
);
}
else
{
return
(
*
a_value_ptr
)
==
(
*
b_value_ptr
);
}
}
}
if
(
a
.
m_ptr
->
isa
<
lite
::
PrimitiveTValue
>
())
{
auto
a_value_node_ptr
=
a
.
m_ptr
->
cast
<
PrimitiveTValuePtr
>
();
auto
b_value_node_ptr
=
b
.
m_ptr
->
cast
<
PrimitiveTValuePtr
>
();
return
a_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
==
b_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
;
}
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
return
a
==
b
;
MS_EXCEPTION_IF_NULL
(
cnode
);
return
IsPrimitive
(
cnode
->
input
(
kAnfPrimitiveIndex
),
primitive_type
);
}
}
bool
CNodeTypeEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
)
{
bool
IsRealKernel
(
const
AnfNodePtr
&
node
)
{
// To matchCNode and Kernel's type
MS_EXCEPTION_IF_NULL
(
node
);
if
(
utils
::
isa
<
CNode
>
(
a
)
&&
utils
::
isa
<
CNode
>
(
b
))
{
// parameter and value node is not a real kernel too
if
(
!
node
->
isa
<
CNode
>
())
{
return
true
;
return
true
;
}
}
return
a
.
type
()
==
b
.
type
();
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal null input of cnode(%s)"
<<
node
->
DebugString
();
}
auto
input
=
cnode
->
inputs
()[
0
];
bool
is_virtual_node
=
IsPrimitive
(
input
,
prim
::
kPrimImageSummary
)
||
IsPrimitive
(
input
,
prim
::
kPrimScalarSummary
)
||
IsPrimitive
(
input
,
prim
::
kPrimTensorSummary
)
||
IsPrimitive
(
input
,
prim
::
kPrimHistogramSummary
)
||
IsPrimitive
(
input
,
prim
::
kPrimMakeTuple
)
||
IsPrimitive
(
input
,
prim
::
kPrimStateSetItem
)
||
IsPrimitive
(
input
,
prim
::
kPrimDepend
)
||
IsPrimitive
(
input
,
prim
::
kPrimTupleGetItem
)
||
IsPrimitive
(
input
,
prim
::
kPrimControlDepend
)
||
IsPrimitive
(
input
,
prim
::
kPrimReturn
)
||
IsPrimitive
(
input
,
prim
::
kPrimPartial
);
return
!
is_virtual_node
;
}
}
namespace
{
ValueNodePtr
CreateValueNodeWithSexp
(
const
BaseRef
&
sexp
)
{
ValueNodePtr
CreateValueNodeWithSexp
(
const
BaseRef
&
sexp
)
{
if
(
utils
::
isa
<
int
>
(
sexp
))
{
if
(
utils
::
isa
<
int
>
(
sexp
))
{
return
NewValueNode
(
utils
::
cast
<
int
>
(
sexp
));
return
NewValueNode
(
utils
::
cast
<
int
>
(
sexp
));
...
@@ -118,11 +81,11 @@ CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const
...
@@ -118,11 +81,11 @@ CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const
VarNodePtr
CreateVarNodeWithSexp
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
)
{
VarNodePtr
CreateVarNodeWithSexp
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
)
{
if
(
utils
::
isa
<
VarPtr
>
(
graph
))
{
if
(
utils
::
isa
<
VarPtr
>
(
graph
))
{
//
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
MS_LOG
(
DEBUG
)
<<
"make VarPtr "
+
graph
.
ToString
();
return
std
::
make_shared
<
VarNode
>
(
utils
::
cast
<
VarPtr
>
(
sexp
),
nullptr
);
return
std
::
make_shared
<
VarNode
>
(
utils
::
cast
<
VarPtr
>
(
sexp
),
nullptr
);
}
}
if
(
utils
::
isa
<
FuncGraphPtr
>
(
graph
))
{
if
(
utils
::
isa
<
FuncGraphPtr
>
(
graph
))
{
//
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
MS_LOG
(
DEBUG
)
<<
"VarNode, should input a Var in graph. It's GraphPtr: "
+
graph
.
ToString
();
return
std
::
make_shared
<
VarNode
>
(
utils
::
cast
<
VarPtr
>
(
sexp
),
utils
::
cast
<
FuncGraphPtr
>
(
graph
));
return
std
::
make_shared
<
VarNode
>
(
utils
::
cast
<
VarPtr
>
(
sexp
),
utils
::
cast
<
FuncGraphPtr
>
(
graph
));
}
}
MS_LOG
(
ERROR
)
<<
"VarNode, should input a Var in graph. It's "
+
graph
.
ToString
();
MS_LOG
(
ERROR
)
<<
"VarNode, should input a Var in graph. It's "
+
graph
.
ToString
();
...
@@ -131,7 +94,7 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
...
@@ -131,7 +94,7 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
)
{
bool
multigraph
)
{
//
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_LOG
(
DEBUG
)
<<
"HandleSexpVector sexp: "
+
sexp
.
ToString
()
+
", graph "
+
graph
.
ToString
();
std
::
vector
<
AnfNodePtr
>
input_nodes
;
std
::
vector
<
AnfNodePtr
>
input_nodes
;
const
auto
&
tuple
=
utils
::
cast
<
VectorRef
>
(
sexp
);
const
auto
&
tuple
=
utils
::
cast
<
VectorRef
>
(
sexp
);
if
(
multigraph
&&
utils
::
isa
<
VarPtr
>
(
graph
))
{
if
(
multigraph
&&
utils
::
isa
<
VarPtr
>
(
graph
))
{
...
@@ -151,8 +114,75 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
...
@@ -151,8 +114,75 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
}
}
}
// namespace
}
// namespace
bool
AnfEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
)
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
a
)
&&
utils
::
isa
<
AnfNodePtr
>
(
b
))
{
auto
a_node
=
utils
::
cast
<
AnfNodePtr
>
(
a
);
auto
b_node
=
utils
::
cast
<
AnfNodePtr
>
(
b
);
MS_EXCEPTION_IF_NULL
(
a_node
);
MS_EXCEPTION_IF_NULL
(
b_node
);
if
(
IsValueNode
<
Primitive
>
(
a_node
)
&&
IsValueNode
<
Primitive
>
(
b_node
))
{
auto
a_value_node
=
a_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
a_value_node
);
auto
a_value
=
a_value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
a_value
);
auto
a_prim
=
a_value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
a_prim
);
auto
b_value_node
=
b_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
b_value_node
);
auto
b_value
=
b_value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
b_value
);
auto
b_prim
=
b_value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
b_prim
);
return
a_prim
->
name
()
==
b_prim
->
name
();
}
else
if
(
a_node
->
isa
<
ValueNode
>
()
&&
b_node
->
isa
<
ValueNode
>
())
{
auto
a_value_node_ptr
=
a_node
->
cast
<
ValueNodePtr
>
();
if
(
a_value_node_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"cast value node ptr fail"
;
}
auto
a_value_ptr
=
a_value_node_ptr
->
value
();
if
(
a_value_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"value ptr is nullptr"
;
}
auto
b_value_node_ptr
=
b_node
->
cast
<
ValueNodePtr
>
();
if
(
b_value_node_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"cast value node ptr fail"
;
}
auto
b_value_ptr
=
b_value_node_ptr
->
value
();
if
(
b_value_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"value ptr is nullptr"
;
}
if
(
utils
::
isa
<
lite
::
PrimitiveTValue
>
(
a_value_ptr
)
&&
utils
::
isa
<
lite
::
PrimitiveTValue
>
(
b_value_ptr
))
{
auto
a_obj
=
(
lite
::
PrimitiveTValue
*
)
(
a_value_ptr
.
get
());
auto
b_obj
=
(
lite
::
PrimitiveTValue
*
)
(
b_value_ptr
.
get
());
return
(
*
a_obj
)
==
(
*
b_obj
);
}
else
{
return
(
*
a_value_ptr
)
==
(
*
b_value_ptr
);
}
}
}
if
(
a
.
m_ptr
->
isa
<
lite
::
PrimitiveTValue
>
()
&&
b
.
m_ptr
->
isa
<
lite
::
PrimitiveTValue
>
())
{
auto
a_value_node_ptr
=
a
.
m_ptr
->
cast
<
PrimitiveTValuePtr
>
();
auto
b_value_node_ptr
=
b
.
m_ptr
->
cast
<
PrimitiveTValuePtr
>
();
return
a_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
==
b_value_node_ptr
->
GetPrimitiveT
()
->
value
.
type
;
}
return
a
==
b
;
}
bool
CNodeTypeEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
)
{
// To matchCNode and Kernel's type
if
(
utils
::
isa
<
CNode
>
(
a
)
&&
utils
::
isa
<
CNode
>
(
b
))
{
return
true
;
}
return
a
.
type
()
==
b
.
type
();
}
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
)
{
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
)
{
//
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_LOG
(
DEBUG
)
<<
"SexpToNode sexp: "
+
sexp
.
ToString
()
+
", graph "
+
graph
.
ToString
();
MS_EXCEPTION_IF_NULL
(
primitive_vars
);
MS_EXCEPTION_IF_NULL
(
primitive_vars
);
if
(
utils
::
isa
<
VectorRef
>
(
sexp
))
{
if
(
utils
::
isa
<
VectorRef
>
(
sexp
))
{
return
HandleSexpVector
(
sexp
,
graph
,
primitive_vars
,
multigraph
);
return
HandleSexpVector
(
sexp
,
graph
,
primitive_vars
,
multigraph
);
...
@@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
...
@@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
return
value_node
;
return
value_node
;
}
}
bool
IsRealCNodeKernel
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
// parameter and value node is not a real cnode kernel
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
// return considered as a real node
if
(
CheckPrimitiveType
(
node
,
prim
::
kPrimReturn
))
{
return
true
;
}
return
IsRealKernel
(
node
);
}
bool
IsGraphKernel
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
// graph kernel should be a real cnode kernel.
if
(
!
IsRealCNodeKernel
(
node
))
{
return
false
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
// graph kernel should has func_graph as first input.
if
(
!
IsValueNode
<
FuncGraph
>
(
input
))
{
return
false
;
}
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
input
);
MS_EXCEPTION_IF_NULL
(
func_graph
);
return
func_graph
->
has_attr
(
FUNC_GRAPH_ATTR_GRAPH_KERNEL
);
}
void
CheckIfFuncGraphIsNull
(
const
FuncGraphPtr
&
graph
)
{
void
CheckIfFuncGraphIsNull
(
const
FuncGraphPtr
&
graph
)
{
if
(
graph
==
nullptr
)
{
if
(
graph
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"The graph is null."
;
MS_LOG
(
EXCEPTION
)
<<
"The graph is null."
;
...
...
mindspore/lite/src/gllo/common/utils.h
→
mindspore/lite/src/gllo/common/
gllo_
utils.h
浏览文件 @
6e57281c
...
@@ -14,22 +14,21 @@
...
@@ -14,22 +14,21 @@
* limitations under the License.
* limitations under the License.
*/
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_
GLLO_
UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_
GLLO_
UTILS_H_
#include <mindspore/lite/src/ir/primitive_t_value.h>
#include <memory>
#include <memory>
#include "src/ir/primitive_t_value.h"
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "src/common/utils.h"
#include "
src/gllo
/common/pattern_engine.h"
#include "
backend/optimizer
/common/pattern_engine.h"
#include "schema/inner/model_generated.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "src/param_value_lite.h"
using
PrimitiveTValuePtr
=
std
::
shared_ptr
<
mindspore
::
lite
::
PrimitiveTValue
>
;
using
PrimitiveTValuePtr
=
std
::
shared_ptr
<
mindspore
::
lite
::
PrimitiveTValue
>
;
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
bool
AnfEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
);
bool
AnfEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
);
bool
CNodeTypeEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
);
bool
CNodeTypeEqual
(
const
BaseRef
&
a
,
const
BaseRef
&
b
);
...
@@ -37,6 +36,10 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
...
@@ -37,6 +36,10 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
=
false
);
bool
multigraph
=
false
);
bool
IsRealCNodeKernel
(
const
AnfNodePtr
&
node
);
bool
IsGraphKernel
(
const
AnfNodePtr
&
node
);
void
CheckIfFuncGraphIsNull
(
const
FuncGraphPtr
&
graph
);
void
CheckIfFuncGraphIsNull
(
const
FuncGraphPtr
&
graph
);
void
CheckIfAnfNodeIsNull
(
const
AnfNodePtr
&
node
);
void
CheckIfAnfNodeIsNull
(
const
AnfNodePtr
&
node
);
...
@@ -61,4 +64,4 @@ bool IsParamNode(const BaseRef &n);
...
@@ -61,4 +64,4 @@ bool IsParamNode(const BaseRef &n);
bool
IsConvNode
(
const
BaseRef
&
n
);
bool
IsConvNode
(
const
BaseRef
&
n
);
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_
GLLO_
UTILS_H_
mindspore/lite/src/gllo/common/node_pass.cc
浏览文件 @
6e57281c
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "
src/gllo
/common/node_pass.h"
#include "
backend/optimizer
/common/node_pass.h"
#include <unordered_set>
#include <unordered_set>
#include <deque>
#include <deque>
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "ir/manager.h"
#include "src/gllo/common/gllo_utils.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
@@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
...
@@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL
(
const_func_graph
);
MS_EXCEPTION_IF_NULL
(
const_func_graph
);
todo
.
push_back
(
const_func_graph
->
output
());
todo
.
push_back
(
const_func_graph
->
output
());
}
else
if
(
new_node
&&
new_node
->
isa
<
CNode
>
())
{
}
else
if
(
new_node
&&
new_node
->
isa
<
CNode
>
())
{
if
(
IsGraphKernel
(
new_node
))
{
todo
.
push_back
(
new_node
);
}
auto
cnode
=
new_node
->
cast
<
CNodePtr
>
();
auto
cnode
=
new_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
inputs
=
cnode
->
inputs
();
auto
inputs
=
cnode
->
inputs
();
...
...
mindspore/lite/src/gllo/common/node_pass.h
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
#include <string>
#include <memory>
#include "src/gllo/common/pass.h"
namespace
mindspore
{
namespace
opt
{
// @brief ANF Node level optimization base pass
class
NodePass
:
public
Pass
{
public:
explicit
NodePass
(
const
std
::
string
&
name
)
:
Pass
(
name
)
{}
~
NodePass
()
override
=
default
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
)
final
;
virtual
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
=
0
;
};
using
NodePassPtr
=
std
::
shared_ptr
<
NodePass
>
;
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
mindspore/lite/src/gllo/common/optimizer.cc
浏览文件 @
6e57281c
...
@@ -23,8 +23,7 @@
...
@@ -23,8 +23,7 @@
#include <utility>
#include <utility>
#include <initializer_list>
#include <initializer_list>
#include "src/gllo/common/pass_manager.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/manager.h"
#include "ir/manager.h"
namespace
mindspore
{
namespace
mindspore
{
...
...
mindspore/lite/src/gllo/common/optimizer.h
浏览文件 @
6e57281c
...
@@ -26,9 +26,9 @@
...
@@ -26,9 +26,9 @@
#include "ir/graph_utils.h"
#include "ir/graph_utils.h"
#include "src/common/utils.h"
#include "src/common/utils.h"
#include "
src/gllo
/common/pass_manager.h"
#include "
backend/optimizer
/common/pass_manager.h"
#include "
src/gllo
/common/pattern_engine.h"
#include "
backend/optimizer
/common/pattern_engine.h"
#include "src/gllo/common/utils.h"
#include "src/gllo/common/
gllo_
utils.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
...
mindspore/lite/src/gllo/common/pass_manager.cc
浏览文件 @
6e57281c
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "
src/gllo
/common/pass_manager.h"
#include "
backend/optimizer
/common/pass_manager.h"
#include <sys/time.h>
#include <sys/time.h>
#include <unordered_set>
#include <unordered_set>
...
...
mindspore/lite/src/gllo/common/pass_manager.h
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include "src/gllo/common/pass.h"
#include "src/gllo/common/node_pass.h"
namespace
mindspore
{
namespace
opt
{
// @brief For optimization passes management
class
PassManager
{
public:
explicit
PassManager
(
const
std
::
string
&
name
=
"pm"
,
bool
run_only_once
=
true
)
:
name_
(
name
),
passes_
{},
run_only_once_
(
run_only_once
)
{}
virtual
~
PassManager
()
=
default
;
// Get all the passes added by AddPass
const
std
::
vector
<
PassPtr
>
&
Passes
()
const
;
// Add graph pass, the pass object will be freed when pass manager freed.
void
AddPass
(
const
PassPtr
&
pass
);
// Run passes added in pass manager on the input graph
// @param [inout] graph The graph to be optimized
// @return true, graph changed
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
)
const
;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [in] passes The given graph passes
// @return true, graph changed
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PassPtr
>
&
passes
)
const
;
std
::
string
name
()
const
{
return
name_
;
}
private:
const
std
::
string
name_
;
std
::
vector
<
PassPtr
>
passes_
;
bool
run_only_once_
;
};
using
PassManagerPtr
=
std
::
shared_ptr
<
PassManager
>
;
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
mindspore/lite/src/gllo/common/pattern_engine.cc
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/common/pattern_engine.h"
#include <exception>
#include <iostream>
#include <functional>
#include <iterator>
#include "ir/func_graph.h"
#include "mindspore/core/ir/primitive.h"
#include "utils/info.h"
#include "ir/anf.h"
#include "utils/convert_utils_base.h"
#include "utils/overload.h"
namespace
mindspore
{
static
int
GetNextTag
()
{
static
int
kID
=
0
;
return
kID
++
;
}
void
Var
::
EnsureTag
()
{
if
(
tag_
.
length
()
==
0
)
{
std
::
ostringstream
buffer
;
buffer
<<
"_"
<<
GetNextTag
();
tag_
=
buffer
.
str
();
}
}
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
if
(
lhs
->
isa
<
CondVar
>
()
&&
rhs
->
isa
<
CondVar
>
())
{
CondVarPtr
v1
=
dyn_cast
<
CondVar
>
(
lhs
);
CondVarPtr
v2
=
dyn_cast
<
CondVar
>
(
rhs
);
return
*
v1
==
*
v2
;
}
if
(
lhs
->
isa
<
SeqVar
>
()
&&
rhs
->
isa
<
SeqVar
>
())
{
SVarPtr
v1
=
dyn_cast
<
SeqVar
>
(
lhs
);
SVarPtr
v2
=
dyn_cast
<
SeqVar
>
(
rhs
);
return
*
v1
==
*
v2
;
}
return
(
*
lhs
==
*
rhs
);
}
std
::
string
SeqVar
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
"SeqVar("
<<
tag
()
<<
", "
<<
subvar_
->
ToString
()
<<
")"
;
return
buffer
.
str
();
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
)
{
if
(
var
==
nullptr
)
{
os
<<
""
;
}
else
{
os
<<
var
->
ToString
();
}
return
os
;
}
template
<
>
std
::
ostream
&
operator
<<<
VarPtr
,
BaseRef
>
(
std
::
ostream
&
os
,
const
Equiv
&
equiv
)
{
os
<<
"[Equiv]"
<<
"
\n
"
;
for
(
auto
&
equiv_item
:
equiv
)
{
auto
k
=
equiv_item
.
first
;
os
<<
k
<<
":"
;
BaseRef
x
=
equiv_item
.
second
;
if
(
utils
::
isa
<
AnfNodePtr
>
(
x
))
{
auto
node
=
utils
::
cast
<
AnfNodePtr
>
(
x
);
os
<<
"TypeString["
<<
node
->
type_name
()
<<
"]"
;
if
(
IsValueNode
<
FuncGraph
>
(
node
))
{
os
<<
"IsValueNodeGraph "
;
}
os
<<
"type "
<<
node
->
type_name
();
if
(
node
->
isa
<
ValueNode
>
())
{
os
<<
" value "
<<
GetValueNode
(
node
);
}
os
<<
" addr: "
<<
node
;
}
else
if
(
utils
::
isa
<
Named
>
(
x
))
{
os
<<
"Named "
<<
x
.
ToString
().
c_str
();
}
else
if
(
utils
::
isa
<
VarPtr
>
(
x
))
{
os
<<
"TypeString[Var]"
;
os
<<
utils
::
cast
<
VarPtr
>
(
x
);
}
else
if
(
utils
::
isa
<
FuncGraphPtr
>
(
x
))
{
os
<<
"TypeString[Graph]"
;
}
os
<<
"
\n
"
;
}
return
os
;
}
static
BaseRef
GetVar
(
const
BaseRef
&
x
)
{
// MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
if
(
utils
::
isa
<
AnfNodePtr
>
(
x
))
{
auto
node
=
utils
::
cast
<
AnfNodePtr
>
(
x
);
// MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
if
(
node
->
isa
<
VarNode
>
())
{
// MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
return
node
->
cast
<
VarNodePtr
>
()
->
var_
;
}
// if (node->isa<ValueNode>()) {
// MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
// } else {
// MS_LOG(DEBUG) << "type " + node->type_name();
// }
// } else if (utils::isa<Named>(x)) {
// MS_LOG(DEBUG) << "Named " + x.ToString();
// } else if (utils::isa<VectorRef>(x)) {
// MS_LOG(DEBUG) << "VectorRef";
// } else if (utils::isa<VarPtr>(x)) {
// MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
}
// MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
return
x
;
}
EquivPtr
MatchOnVar
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
EquivPtr
equiv
)
{
MS_LOG
(
DEBUG
)
<<
"MatchOnVar pattern "
+
pattern
.
ToString
()
+
" expr: "
+
expr
.
ToString
();
MS_EXCEPTION_IF_NULL
(
equiv
);
if
(
utils
::
isa
<
VarPtr
>
(
pattern
))
{
VarPtr
var
=
utils
::
cast
<
VarPtr
>
(
pattern
);
if
(
var
->
matches
(
expr
))
{
(
*
equiv
)[
var
]
=
expr
;
MS_LOG
(
DEBUG
)
<<
"pattern is var match: "
+
pattern
.
ToString
()
+
", "
+
expr
.
ToString
();
return
equiv
;
}
}
return
nullptr
;
}
bool
PatternEngine
::
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
MS_EXCEPTION_IF_NULL
(
values_expr
);
if
(
utils
::
isa
<
SeqPtr
>
(
pattern_ref
))
{
*
values_pattern
=
pattern_ref
;
*
values_expr
=
expr_ref
;
return
true
;
}
return
false
;
}
bool
PatternEngine
::
ToVector
(
const
BaseRef
&
pattern_ref
,
const
BaseRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
MS_EXCEPTION_IF_NULL
(
values_expr
);
// visitor to visite the list
auto
appender_pattern
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
values
.
push_back
(
GetVar
(
u
));
return
u
;
};
return
fn
;
};
visitor_
->
SetFn
(
appender_pattern
(
*
values_pattern
));
// MS_LOG(DEBUG) << "visit pattern_ref";
bool
success
=
visitor_
->
Visit
(
pattern_ref
,
nullptr
);
if
(
!
success
)
{
return
false
;
}
auto
appender_expr
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
values
.
push_back
(
u
);
return
u
;
};
return
fn
;
};
visitor_
->
SetFn
(
appender_expr
(
*
values_expr
));
// MS_LOG(DEBUG) << "visit expr_ref";
return
visitor_
->
Visit
(
expr_ref
,
nullptr
);
}
static
int
GetSVarStartIndex
(
const
VectorRef
&
values
)
{
int
index
=
-
1
;
int
count
=
0
;
for
(
auto
&
value
:
values
)
{
if
(
utils
::
isa
<
VarPtr
>
(
value
)
&&
utils
::
cast
<
VarPtr
>
(
value
)
->
isa
<
SeqVar
>
())
{
if
(
index
!=
-
1
)
{
// MS_LOG(DEBUG) << "Multiple SVars in sequence";
return
kInvalidVarIndex
;
}
index
=
count
;
}
count
++
;
}
return
index
;
}
void
UpdateEquivMap
(
const
VectorRef
&
values_pattern
,
const
BaseRef
&
expr_ref
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
{
if
(
equiv
==
nullptr
||
values_pattern
.
empty
()
||
!
utils
::
isa
<
AnfNodePtr
>
(
values_pattern
[
0
])
||
!
utils
::
isa
<
AnfNodePtr
>
(
expr_ref
))
{
return
;
}
auto
real_node
=
utils
::
cast
<
AnfNodePtr
>
(
expr_ref
);
MS_EXCEPTION_IF_NULL
(
real_node
);
if
(
!
real_node
->
isa
<
CNode
>
())
{
return
;
}
auto
prim_node
=
utils
::
cast
<
AnfNodePtr
>
(
values_pattern
[
0
]);
MS_EXCEPTION_IF_NULL
(
prim_node
);
if
(
!
IsValueNode
<
Primitive
>
(
prim_node
))
{
return
;
}
ValuePtr
value
=
GetValueNode
(
prim_node
);
MS_EXCEPTION_IF_NULL
(
value
);
auto
prim
=
value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
prim
);
auto
iter
=
primitive_vars
.
find
(
prim
);
if
(
iter
==
primitive_vars
.
end
())
{
return
;
}
(
*
equiv
)[
iter
->
second
]
=
real_node
;
}
EquivPtr
PatternEngine
::
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
{
int
svar_index
=
GetSVarStartIndex
(
values_pattern
);
if
(
svar_index
==
kInvalidVarIndex
)
{
return
nullptr
;
}
size_t
values_pattern_len
=
values_pattern
.
size
();
size_t
values_expr_len
=
values_expr
.
size
();
if
(
svar_index
==
-
1
)
{
if
(
values_pattern_len
!=
values_expr_len
)
{
// MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ",
// expr len " << values_expr_len;
return
nullptr
;
}
}
if
(
values_expr_len
<
values_pattern_len
-
1
)
{
MS_LOG
(
DEBUG
)
<<
"invalid size: pattern len "
<<
values_pattern_len
<<
", expr len "
<<
values_expr_len
;
return
nullptr
;
}
size_t
diff
=
values_expr_len
-
values_pattern_len
+
1
;
for
(
size_t
i
=
0
;
i
<
values_pattern_len
;
i
++
)
{
size_t
expr_i
=
i
;
if
(
svar_index
!=
-
1
&&
i
==
IntToSize
(
svar_index
))
{
auto
seq
=
std
::
vector
<
BaseRef
>
(
values_expr
.
begin
()
+
svar_index
,
values_expr
.
begin
()
+
svar_index
+
SizeToInt
(
diff
));
equiv
=
Match
(
values_pattern
[
svar_index
],
seq
,
primitive_vars
,
equiv
);
}
else
{
if
(
svar_index
!=
-
1
&&
i
>
IntToSize
(
svar_index
))
{
expr_i
=
i
+
diff
-
1
;
}
equiv
=
Match
(
values_pattern
[
i
],
values_expr
[
expr_i
],
primitive_vars
,
equiv
);
}
if
(
equiv
==
nullptr
)
{
return
nullptr
;
}
}
return
equiv
;
}
EquivPtr
PatternEngine
::
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
{
MS_LOG
(
DEBUG
)
<<
"-----[in Match]"
;
// MS_LOG(DEBUG) << "GetVar w";
BaseRef
pattern_ref
=
GetVar
(
pattern
);
// MS_LOG(DEBUG) << "GetVar v";
BaseRef
expr_ref
=
expr
;
if
(
equiv
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Equiv pointer is null"
;
}
MS_LOG
(
DEBUG
)
<<
"Pattern ref "
+
pattern_ref
.
ToString
()
+
", expr ref"
+
expr_ref
.
ToString
();
// 1. if pattern_ref is var and already in equiv, replace it.
if
(
utils
::
isa
<
VarPtr
>
(
pattern_ref
))
{
VarPtr
var
=
utils
::
cast
<
VarPtr
>
(
pattern_ref
);
auto
iter
=
equiv
->
find
(
var
);
if
(
iter
!=
equiv
->
end
())
{
pattern_ref
=
iter
->
second
;
}
}
// 2. check equal
if
(
eq_
(
pattern_ref
,
expr_ref
))
{
return
equiv
;
}
// 3. match var
EquivPtr
ret_equiv
=
MatchOnVar
(
pattern_ref
,
expr_ref
,
equiv
);
if
(
ret_equiv
)
{
return
ret_equiv
;
}
// 4. here the type can be std:vector, std:list,
// or cnode.
if
(
!
type_eq_
(
pattern_ref
,
expr_ref
))
{
MS_LOG
(
DEBUG
)
<<
"Type mismatch"
;
return
nullptr
;
}
// 5. transfer the Containers by visitor to std::vector
VectorRef
values_pattern
;
VectorRef
values_expr
;
if
(
!
ToVector
(
pattern_ref
,
expr_ref
,
&
values_pattern
,
&
values_expr
))
{
return
nullptr
;
}
// 6. if any svar in both side, find the SeqVar index,
// try to pack the Var s in std::vector to a Seq and match elements one by one.
// check svar
equiv
=
AlignSVar
(
values_pattern
,
values_expr
,
primitive_vars
,
equiv
);
UpdateEquivMap
(
values_pattern
,
expr_ref
,
primitive_vars
,
equiv
);
return
equiv
;
}
BaseRef
PatternEngine
::
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
equiv
);
MS_LOG
(
DEBUG
)
<<
"-----[in Replace]"
;
BaseRef
ref
=
GetVar
(
pattern
);
BaseRef
out
;
bool
is_match
=
false
;
// w is var
if
(
utils
::
isa
<
VarPtr
>
(
ref
))
{
const
VarPtr
&
var
=
utils
::
cast
<
VarPtr
>
(
ref
);
auto
iter
=
equiv
->
find
(
var
);
if
(
iter
!=
equiv
->
end
())
{
out
=
iter
->
second
;
is_match
=
true
;
}
}
if
(
is_match
)
{
return
out
;
}
// visitor to visit the list
std
::
function
<
BaseRef
(
BaseRef
)
>
fn
=
[
&
,
this
,
equiv
](
const
BaseRef
&
u
)
{
return
Replace
(
u
,
equiv
);
};
visitor_
->
SetFn
(
fn
);
BaseRef
visit_out
;
if
(
!
visitor_
->
Visit
(
pattern
,
&
visit_out
))
{
return
pattern
;
}
return
visit_out
;
}
}
// namespace mindspore
mindspore/lite/src/gllo/common/pattern_engine.h
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
#include <string>
#include <sstream>
#include <memory>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <initializer_list>
#include <iostream>
#include <algorithm>
#include <map>
#include <stdexcept>
#include <list>
#include <utility>
#include "src/gllo/common/visit.h"
#include "mindspore/core/base/base.h"
#include "utils/log_adapter.h"
#include "base/base_ref.h"
namespace
mindspore
{
class
CondVar
;
class
SeqVar
;
using
CondVarPtr
=
std
::
shared_ptr
<
CondVar
>
;
using
SVarPtr
=
std
::
shared_ptr
<
SeqVar
>
;
const
int
kInvalidVarIndex
=
-
2
;
using
ConditionFunc
=
std
::
function
<
bool
(
const
BaseRef
&
)
>
;
// Base wildcard variable which could match any anf node.
class
Var
:
public
Base
{
friend
class
VarHasher
;
public:
explicit
Var
(
std
::
string
tag
=
""
)
:
tag_
(
std
::
move
(
tag
)),
primitive_
(
nullptr
)
{
EnsureTag
();
}
explicit
Var
(
const
PrimitivePtr
&
primitive
,
std
::
string
tag
=
""
)
:
tag_
(
std
::
move
(
tag
)),
primitive_
(
primitive
)
{
EnsureTag
();
}
Var
(
const
Var
&
other
)
:
Base
(
other
),
tag_
(
other
.
tag_
)
{}
virtual
Var
&
operator
=
(
const
Var
&
other
)
{
if
(
&
other
==
this
)
{
return
*
this
;
}
this
->
tag_
=
other
.
tag_
;
return
*
this
;
}
~
Var
()
override
=
default
;
MS_DECLARE_PARENT
(
Var
,
Base
);
virtual
bool
matches
(
const
BaseRef
&
)
{
return
true
;
}
virtual
bool
operator
==
(
const
Var
&
other
)
const
{
return
tag_
==
other
.
tag_
;
}
bool
operator
!=
(
const
Var
&
other
)
const
{
return
!
(
&
other
==
this
);
}
std
::
string
tag
()
const
{
return
tag_
;
}
PrimitivePtr
primitive
()
const
{
return
primitive_
;
}
std
::
string
ToString
()
const
override
{
std
::
ostringstream
buffer
;
buffer
<<
"Var("
<<
tag_
<<
")"
;
return
buffer
.
str
();
}
std
::
size_t
hash
()
const
override
{
return
std
::
hash
<
std
::
string
>
()(
tag_
);
}
protected:
void
EnsureTag
();
std
::
string
tag_
;
PrimitivePtr
primitive_
;
};
// VarNode means variable node, a subclass of AnfNode
class
VarNode
:
public
AnfNode
{
public:
VarNode
(
const
VarPtr
&
value
,
const
FuncGraphPtr
&
func_graph
)
:
AnfNode
(
func_graph
),
var_
(
value
)
{}
~
VarNode
()
override
=
default
;
MS_DECLARE_PARENT
(
VarNode
,
AnfNode
);
const
VarPtr
var_
;
};
using
VarNodePtr
=
std
::
shared_ptr
<
VarNode
>
;
class
VarHasher
{
public:
std
::
size_t
operator
()(
const
Var
&
var
)
const
{
return
var
.
hash
();
}
};
// Condition Var, match an anf node when condition function return true.
class
CondVar
:
public
Var
{
public:
explicit
CondVar
(
const
ConditionFunc
&
cond
)
:
cond_fn_
(
cond
)
{}
~
CondVar
()
override
=
default
;
MS_DECLARE_PARENT
(
CondVar
,
Var
);
bool
matches
(
const
BaseRef
&
value
)
override
{
// MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
if
(
utils
::
isa
<
Var
>
(
value
))
{
return
false
;
}
return
cond_fn_
(
value
);
}
ConditionFunc
cond_fn_
;
};
using
Seq
=
VectorRef
;
using
SeqPtr
=
std
::
shared_ptr
<
Seq
>
;
// Sequence Var which could match multiple consecutive input nodes of a CNode.
class
SeqVar
:
public
Var
{
public:
SeqVar
()
:
subvar_
(
std
::
make_shared
<
Var
>
())
{}
~
SeqVar
()
override
=
default
;
MS_DECLARE_PARENT
(
SeqVar
,
Var
);
explicit
SeqVar
(
const
VarPtr
subvar
)
:
subvar_
(
subvar
)
{}
bool
matches
(
const
BaseRef
&
value
)
override
{
// match Seq.
if
(
utils
::
isa
<
Seq
>
(
value
))
{
const
Seq
&
seq
=
utils
::
cast
<
Seq
>
(
value
);
return
std
::
all_of
(
seq
.
begin
(),
seq
.
end
(),
[
this
](
const
BaseRef
&
v
)
{
auto
eq
=
subvar_
->
matches
(
v
);
return
eq
;
});
}
return
false
;
}
bool
operator
==
(
const
SeqVar
&
other
)
const
{
return
*
subvar_
==
*
other
.
subvar_
;
}
std
::
string
ToString
()
const
override
;
private:
VarPtr
subvar_
;
};
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
);
inline
bool
operator
!=
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
);
using
Equiv
=
std
::
map
<
VarPtr
,
BaseRef
>
;
using
EquivPtr
=
std
::
shared_ptr
<
Equiv
>
;
using
PrimitiveVarMap
=
std
::
unordered_map
<
PrimitivePtr
,
VarPtr
>
;
using
PrimitiveVarMapPtr
=
std
::
shared_ptr
<
PrimitiveVarMap
>
;
inline
bool
DefaultTypeEq
(
const
BaseRef
&
x
,
const
BaseRef
&
y
)
{
return
x
.
type
()
==
y
.
type
();
}
class
PatternEngine
{
public:
PatternEngine
(
const
std
::
shared_ptr
<
Visitor
>
&
visitor
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
&
eq
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
&
type_eq
=
DefaultTypeEq
)
:
visitor_
(
visitor
),
eq_
(
eq
),
type_eq_
(
type_eq
)
{}
~
PatternEngine
()
=
default
;
EquivPtr
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
;
// Replace pattern with equivalent
BaseRef
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
;
private:
EquivPtr
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
;
bool
ToVector
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
bool
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
std
::
shared_ptr
<
Visitor
>
visitor_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
eq_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
type_eq_
;
};
}
// namespace mindspore
namespace
std
{
using
mindspore
::
ERROR
;
using
mindspore
::
LogStream
;
using
mindspore
::
NoExceptionType
;
template
<
>
struct
hash
<
mindspore
::
VarPtr
>
{
std
::
size_t
operator
()(
const
mindspore
::
VarPtr
var
)
const
{
if
(
var
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Invalid var ptr"
;
return
0
;
}
return
std
::
hash
<
std
::
string
>
{}(
var
->
tag
());
}
};
}
// namespace std
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
mindspore/lite/src/gllo/common/visit.cc
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include <algorithm>
#include <utility>
#include "src/gllo/common/visit.h"
#include "src/gllo/common/pattern_engine.h"
#include "utils/any.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
bool
CheckIfNeedExpand
(
const
std
::
vector
<
BaseRef
>
&
list
)
{
return
std
::
any_of
(
list
.
begin
(),
list
.
end
(),
[](
const
BaseRef
&
any
)
{
return
utils
::
isa
<
Seq
>
(
any
);
});
}
std
::
shared_ptr
<
VectorRef
>
ExpandList
(
const
std
::
vector
<
BaseRef
>
&
list
)
{
std
::
shared_ptr
<
VectorRef
>
new_list
=
std
::
make_shared
<
VectorRef
>
();
for
(
auto
&
item
:
list
)
{
if
(
utils
::
isa
<
Seq
>
(
item
))
{
const
Seq
&
seq
=
utils
::
cast
<
Seq
>
(
item
);
new_list
->
insert
(
new_list
->
end
(),
seq
.
begin
(),
seq
.
end
());
}
else
{
new_list
->
push_back
(
item
);
}
}
return
new_list
;
}
bool
DefaultVisitor
::
Visit
(
const
VectorRef
&
v_any
,
BaseRef
*
const
visit_out
)
const
{
std
::
vector
<
BaseRef
>
out
;
(
void
)
std
::
transform
(
v_any
.
begin
(),
v_any
.
end
(),
std
::
back_inserter
(
out
),
[
this
](
const
BaseRef
&
item
)
{
return
fn_
(
item
);
});
if
(
visit_out
!=
nullptr
)
{
*
visit_out
=
ExpandList
(
out
);
}
return
true
;
}
bool
DefaultVisitor
::
Visit
(
const
BaseRef
&
any
,
BaseRef
*
const
visit_out
)
const
{
if
(
utils
::
isa
<
Seq
>
(
any
))
{
return
Visit
(
utils
::
cast
<
Seq
>
(
any
),
visit_out
);
}
else
if
(
utils
::
isa
<
AnfNodePtr
>
(
any
))
{
auto
nodeptr
=
utils
::
cast
<
AnfNodePtr
>
(
any
);
AnfNodePtr
output
;
AnfNodePtr
*
p_output
=
&
output
;
if
(
visit_out
==
nullptr
)
{
p_output
=
nullptr
;
}
Visit
(
nodeptr
,
fn_
,
p_output
);
if
(
visit_out
!=
nullptr
)
{
*
visit_out
=
output
;
}
return
true
;
}
MS_LOG
(
DEBUG
)
<<
"VisitError, not support type to Visit: "
+
any
.
ToString
();
return
false
;
}
void
DefaultVisitor
::
Visit
(
const
AnfNodePtr
&
node
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
{
if
(
node
->
isa
<
CNode
>
())
{
Visit
(
node
->
cast
<
CNodePtr
>
(),
fn
,
output
);
return
;
}
if
(
node
->
isa
<
ValueNode
>
())
{
Visit
(
node
->
cast
<
ValueNodePtr
>
(),
fn
,
output
);
return
;
}
if
(
output
!=
nullptr
)
{
*
output
=
node
;
}
}
void
DefaultVisitor
::
Visit
(
const
CNodePtr
&
cnode
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
{
// if output is nullptr, it's not required to make the new CNode node.
if
(
output
==
nullptr
)
{
for
(
auto
&
inp
:
cnode
->
inputs
())
{
(
void
)
fn
(
inp
);
}
if
(
cnode
->
func_graph
()
!=
nullptr
)
{
(
void
)
fn
(
cnode
->
func_graph
());
}
else
{
(
void
)
fn
(
cnode
->
func_graph_as_var
());
}
return
;
}
std
::
vector
<
AnfNodePtr
>
new_inputs
;
std
::
vector
<
BaseRef
>
after_cnode_fn
;
std
::
shared_ptr
<
VectorRef
>
out
;
(
void
)
std
::
transform
(
cnode
->
inputs
().
begin
(),
cnode
->
inputs
().
end
(),
std
::
back_inserter
(
after_cnode_fn
),
fn
);
if
(
CheckIfNeedExpand
(
after_cnode_fn
))
{
out
=
ExpandList
(
after_cnode_fn
);
}
std
::
vector
<
BaseRef
>
&
outs
=
after_cnode_fn
;
if
(
out
!=
nullptr
)
{
outs
=
out
->
elements
();
}
for
(
auto
&
any_item
:
outs
)
{
if
(
!
utils
::
isa
<
AnfNodePtr
>
(
any_item
))
{
MS_LOG
(
EXCEPTION
)
<<
"VisitError, fn not return the same type AnfNodePtr"
;
}
new_inputs
.
push_back
(
utils
::
cast
<
AnfNodePtr
>
(
any_item
));
}
BaseRef
any_fg
;
AnfNodePtr
new_cnode
=
nullptr
;
if
(
cnode
->
func_graph
()
!=
nullptr
)
{
any_fg
=
fn
(
cnode
->
func_graph
());
if
(
!
utils
::
isa
<
FuncGraphPtr
>
(
any_fg
))
{
MS_LOG
(
EXCEPTION
)
<<
"VisitError, fn not return the same type FuncGraphPtr"
;
}
new_cnode
=
std
::
make_shared
<
CNode
>
(
new_inputs
,
utils
::
cast
<
FuncGraphPtr
>
(
any_fg
));
}
else
{
any_fg
=
fn
(
cnode
->
func_graph_as_var
());
if
(
utils
::
isa
<
VarPtr
>
(
any_fg
))
{
new_cnode
=
std
::
make_shared
<
CNode
>
(
new_inputs
,
utils
::
cast
<
VarPtr
>
(
any_fg
));
}
else
if
(
utils
::
isa
<
FuncGraphPtr
>
(
any_fg
))
{
new_cnode
=
std
::
make_shared
<
CNode
>
(
new_inputs
,
utils
::
cast
<
FuncGraphPtr
>
(
any_fg
));
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"VisitError, fn not return VarPtr or FuncGraphPtr"
;
}
}
new_cnode
->
set_abstract
(
cnode
->
abstract
());
*
output
=
new_cnode
;
}
void
DefaultVisitor
::
Visit
(
const
ValueNodePtr
&
vnode
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
{
const
BaseRef
&
value
=
utils
::
cast
<
ValuePtr
>
(
fn
(
vnode
->
value
()));
if
(
utils
::
isa
<
ValuePtr
>
(
value
))
{
if
(
output
!=
nullptr
)
{
auto
ct
=
NewValueNode
(
utils
::
cast
<
ValuePtr
>
(
value
));
ct
->
set_abstract
(
vnode
->
abstract
());
*
output
=
ct
;
}
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Visit result is not ValuePtr."
;
}
}
// namespace mindspore
mindspore/lite/src/gllo/common/visit.h
已删除
100644 → 0
浏览文件 @
fe7141e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
#define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
#include <unordered_map>
#include <stdexcept>
#include <list>
#include <vector>
#include <string>
#include <memory>
#include "mindspore/core/base/base.h"
#include "base/base_ref.h"
namespace
mindspore
{
using
VisitFn
=
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
;
class
Visitor
{
public:
virtual
void
SetFn
(
VisitFn
fn
)
=
0
;
virtual
bool
Visit
(
const
BaseRef
&
e
,
BaseRef
*
out
)
const
=
0
;
virtual
bool
Visit
(
const
VectorRef
&
e
,
BaseRef
*
out
)
const
=
0
;
virtual
~
Visitor
()
=
default
;
};
class
DefaultVisitor
:
public
Visitor
{
public:
DefaultVisitor
()
:
fn_
(
nullptr
)
{}
~
DefaultVisitor
()
override
=
default
;
void
SetFn
(
VisitFn
fn
)
override
{
fn_
=
fn
;
};
bool
Visit
(
const
VectorRef
&
e
,
BaseRef
*
out
)
const
override
;
bool
Visit
(
const
BaseRef
&
e
,
BaseRef
*
out
)
const
override
;
void
Visit
(
const
AnfNodePtr
&
node
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
;
void
Visit
(
const
CNodePtr
&
cnode
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
;
void
Visit
(
const
ValueNodePtr
&
vnode
,
const
VisitFn
&
fn
,
AnfNodePtr
*
output
)
const
;
VisitFn
fn_
;
};
std
::
shared_ptr
<
VectorRef
>
ExpandList
(
const
std
::
vector
<
BaseRef
>
&
list
);
bool
CheckIfNeedExpand
(
const
std
::
vector
<
BaseRef
>
&
list
);
}
// namespace mindspore
#endif // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
浏览文件 @
6e57281c
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include "
mindspore/lite/
src/gllo/fusion/conv_activation_fusion.h"
#include "src/gllo/fusion/conv_activation_fusion.h"
#include <memory>
#include <memory>
#include "
mindspore/lite/
schema/inner/model_generated.h"
#include "schema/inner/model_generated.h"
#include "
mindspore/lite/
src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "
mindspore/ccsrc/
utils/utils.h"
#include "utils/utils.h"
#include "
mindspore/lite/src/gllo/common/
utils.h"
#include "
src/gllo/common/gllo_
utils.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
namespace
{
namespace
{
...
...
mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
浏览文件 @
6e57281c
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#include <string>
#include <string>
#include "
mindspore/lite/
src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
...
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
浏览文件 @
6e57281c
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
#include "src/gllo/fusion/conv_biasadd_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include <memory>
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "securec/include/securec.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
...
@@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
...
@@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
CheckIfCNodeIsNull
(
conv_node
);
CheckIfCNodeIsNull
(
conv_node
);
GenConvNewBias
(
func_graph
,
conv_node
,
add_node
);
GenConvNewBias
(
func_graph
,
conv_node
,
add_node
);
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
lite
::
PrimitiveTValue
>>
(
conv_node
->
input
(
0
));
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
lite
::
PrimitiveTValue
>>
(
conv_node
->
input
(
0
));
MS_ASSERT
(
primitiveT_value
);
MS_ASSERT
(
primitiveT_value
!=
nullptr
);
auto
type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
auto
type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
if
(
type
==
schema
::
PrimitiveType_Conv2D
)
{
if
(
type
==
schema
::
PrimitiveType_Conv2D
)
{
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsConv2D
()
->
hasBias
=
true
;
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsConv2D
()
->
hasBias
=
true
;
...
...
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
浏览文件 @
6e57281c
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#include "
mindspore/lite/
src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
...
mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
浏览文件 @
6e57281c
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
#include "src/gllo/fusion/conv_bn_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include <memory>
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "securec/include/securec.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
...
...
mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
浏览文件 @
6e57281c
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#include "
mindspore/lite/
src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
class
ConvBatchNormFusion
:
public
ConvTransformFusion
{
class
ConvBatchNormFusion
:
public
ConvTransformFusion
{
...
...
mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
浏览文件 @
6e57281c
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include "
mindspore/lite/
src/gllo/fusion/conv_scale_fusion.h"
#include "src/gllo/fusion/conv_scale_fusion.h"
#include <memory>
#include <memory>
#include "
mindspore/lite/
src/param_value_lite.h"
#include "src/param_value_lite.h"
#include "
mindspore/lite/
schema/inner/model_generated.h"
#include "schema/inner/model_generated.h"
#include "
mindspore/lite/
src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "
mindspore/ccsrc/
utils/utils.h"
#include "utils/utils.h"
#include "
mindspore/lite/src/gllo/common/
utils.h"
#include "
src/gllo/common/gllo_
utils.h"
#include "include/errorcode.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "securec/include/securec.h"
...
...
mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
浏览文件 @
6e57281c
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#include "
mindspore/lite/
src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
class
ConvScaleFusion
:
public
ConvTransformFusion
{
class
ConvScaleFusion
:
public
ConvTransformFusion
{
...
...
mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
浏览文件 @
6e57281c
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include "
mindspore/lite/
src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
#include <memory>
#include <memory>
#include "
mindspore/lite/
src/param_value_lite.h"
#include "src/param_value_lite.h"
#include "
mindspore/lite/
schema/inner/model_generated.h"
#include "schema/inner/model_generated.h"
#include "
mindspore/lite/
src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "
mindspore/ccsrc/
utils/utils.h"
#include "utils/utils.h"
#include "
mindspore/lite/src/gllo/common/
utils.h"
#include "
src/gllo/common/gllo_
utils.h"
#include "include/errorcode.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "securec/include/securec.h"
...
@@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
...
@@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
GenNewConvTensor
(
func_graph
,
conv_node
,
kernel_nums
,
trans_scale
,
trans_bias
);
GenNewConvTensor
(
func_graph
,
conv_node
,
kernel_nums
,
trans_scale
,
trans_bias
);
delete
[]
trans_bias
;
delete
[]
trans_bias
;
delete
[]
trans_scale
;
delete
[]
trans_scale
;
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
lite
::
PrimitiveTValue
>>
(
conv_node
->
input
(
0
));
MS_ASSERT
(
primitiveT_value
!=
nullptr
);
auto
type
=
primitiveT_value
->
GetPrimitiveT
()
->
value
.
type
;
if
(
type
==
schema
::
PrimitiveType_Conv2D
)
{
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsConv2D
()
->
hasBias
=
true
;
}
else
if
(
type
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
primitiveT_value
->
GetPrimitiveT
()
->
value
.
AsDepthwiseConv2D
()
->
hasBias
=
true
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Unsupported opType, "
<<
type
;
}
return
pre_node
;
return
pre_node
;
}
}
...
...
mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
浏览文件 @
6e57281c
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
#include <string>
#include <string>
#include "
mindspore/lite/
src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace
mindspore
::
opt
{
namespace
mindspore
::
opt
{
class
ConvTransformFusion
:
public
PatternProcessPass
{
class
ConvTransformFusion
:
public
PatternProcessPass
{
...
...
mindspore/lite/test/CMakeLists.txt
浏览文件 @
6e57281c
...
@@ -63,6 +63,8 @@ if(BUILD_CONVERTER)
...
@@ -63,6 +63,8 @@ if(BUILD_CONVERTER)
${
CCSRC_DIR
}
/pybind_api/export_flags.cc
${
CCSRC_DIR
}
/pybind_api/export_flags.cc
${
CCSRC_DIR
}
/utils/context/context_extends.cc
${
CCSRC_DIR
}
/utils/context/context_extends.cc
${
CCSRC_DIR
}
/frontend/parallel/costmodel_context.cc
${
CCSRC_DIR
}
/frontend/parallel/costmodel_context.cc
${
CCSRC_DIR
}
/backend/optimizer/common/pattern_engine.cc
${
CCSRC_DIR
}
/backend/optimizer/common/visit.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../src/common/graph_utils_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../src/common/graph_utils_extends.cc
)
)
else
()
else
()
...
@@ -202,12 +204,14 @@ if(BUILD_CONVERTER)
...
@@ -202,12 +204,14 @@ if(BUILD_CONVERTER)
${
LITE_DIR
}
/tools/converter/converter.cc
${
LITE_DIR
}
/tools/converter/converter.cc
${
LITE_DIR
}
/tools/converter/parser/onnx/onnx.pb.cc
${
LITE_DIR
}
/tools/converter/parser/onnx/onnx.pb.cc
${
LITE_DIR
}
/test/st/converter_test.cc
${
LITE_DIR
}
/test/st/converter_test.cc
${
LITE_DIR
}
/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
${
LITE_DIR
}
/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
${
LITE_DIR
}
/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
${
LITE_DIR
}
/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
${
LITE_DIR
}
/src/gllo/common/node_pass.cc
${
LITE_DIR
}
/src/gllo/common/node_pass.cc
${
LITE_DIR
}
/src/gllo/common/optimizer.cc
${
LITE_DIR
}
/src/gllo/common/optimizer.cc
${
LITE_DIR
}
/src/gllo/common/pass_manager.cc
${
LITE_DIR
}
/src/gllo/common/pass_manager.cc
${
LITE_DIR
}
/src/gllo/common/pattern_engine.cc
${
LITE_DIR
}
/src/gllo/common/gllo_utils.cc
${
LITE_DIR
}
/src/gllo/common/visit.cc
${
LITE_DIR
}
/src/gllo/common/utils.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_biasadd_fusion.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_biasadd_fusion.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_activation_fusion.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_activation_fusion.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_transform_fusion.cc
${
LITE_DIR
}
/src/gllo/fusion/conv_transform_fusion.cc
...
...
mindspore/lite/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
0 → 100644
浏览文件 @
6e57281c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace
mindspore
{
class
ConvActivationFusionTest
:
public
mindspore
::
Common
{
public:
ConvActivationFusionTest
()
=
default
;
};
using
MetaGraphTptr
=
std
::
shared_ptr
<
schema
::
MetaGraphT
>
;
using
CNodeTptr
=
std
::
unique_ptr
<
schema
::
CNodeT
>
;
namespace
{
CNodeTptr
BuildConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Conv2D
;
auto
prim1
=
new
schema
::
Conv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelOut
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
CNodeTptr
BuildDepthwiseConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
auto
prim1
=
new
schema
::
DepthwiseConv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelIn
=
1
;
prim1
->
channelMultiplier
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
MetaGraphTptr
BuildGraph
(
schema
::
PrimitiveType
conv_type
,
schema
::
ActivationType
activation_type
)
{
auto
meta_graph
=
std
::
make_shared
<
schema
::
MetaGraphT
>
();
meta_graph
->
name
=
"graph"
;
// conv node
CNodeTptr
convNode
;
if
(
conv_type
==
schema
::
PrimitiveType_Conv2D
)
{
convNode
=
BuildConv2D
();
}
else
{
convNode
=
BuildDepthwiseConv2D
();
}
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
convNode
));
// relu node
auto
next_node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
next_node
->
inputIndex
=
{
2
};
next_node
->
outputIndex
=
{
3
};
next_node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
next_node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Activation
;
auto
prim2
=
new
schema
::
ActivationT
;
prim2
->
type
=
activation_type
;
next_node
->
primitive
->
value
.
value
=
prim2
;
next_node
->
name
=
"activation"
;
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
next_node
));
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
3
};
// input 0: data
auto
input0
=
std
::
make_unique
<
schema
::
TensorT
>
();
input0
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input0
->
format
=
schema
::
Format_NHWC
;
input0
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input0
->
dims
=
{
1
,
5
,
5
,
3
};
input0
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input0
));
// input 1: weight
auto
input1
=
std
::
make_unique
<
schema
::
TensorT
>
();
input1
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input1
->
format
=
schema
::
Format_KHWC
;
input1
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input1
->
dims
=
{
8
,
3
,
3
,
3
};
input1
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input1
));
// conv output
auto
conv_output
=
std
::
make_unique
<
schema
::
TensorT
>
();
conv_output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
conv_output
->
format
=
schema
::
Format_NHWC
;
conv_output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
conv_output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
conv_output
));
// final output
auto
output
=
std
::
make_unique
<
schema
::
TensorT
>
();
output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
output
->
format
=
schema
::
Format_NHWC
;
output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
output
));
return
meta_graph
;
}
}
// namespace
TEST_F
(
ConvActivationFusionTest
,
TestConvReluNode
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_Conv2D
,
schema
::
ActivationType_RELU
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsConv2D
()
->
activationType
,
schema
::
ActivationType_RELU
);
}
}
TEST_F
(
ConvActivationFusionTest
,
TestConvRelu6Node
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_Conv2D
,
schema
::
ActivationType_RELU6
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsConv2D
()
->
activationType
,
schema
::
ActivationType_RELU6
);
}
}
TEST_F
(
ConvActivationFusionTest
,
TestBadCase_ConvRelu
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_DepthwiseConv2D
,
schema
::
ActivationType_LEAKY_RELU
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
2
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
if
(
cnode
->
primitive
->
value
.
type
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
activationType
,
schema
::
ActivationType_NO_ACTIVATION
);
}
}
}
}
// namespace mindspore
mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
0 → 100644
浏览文件 @
6e57281c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace
mindspore
{
class
ConvBiasAddFusionTest
:
public
mindspore
::
Common
{
public:
ConvBiasAddFusionTest
()
=
default
;
};
using
MetaGraphTptr
=
std
::
shared_ptr
<
schema
::
MetaGraphT
>
;
using
CNodeTptr
=
std
::
unique_ptr
<
schema
::
CNodeT
>
;
namespace
{
CNodeTptr
BuildConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Conv2D
;
auto
prim1
=
new
schema
::
Conv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelOut
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
CNodeTptr
BuildDepthwiseConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
auto
prim1
=
new
schema
::
DepthwiseConv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelIn
=
1
;
prim1
->
channelMultiplier
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
MetaGraphTptr
BuildGraph
(
schema
::
PrimitiveType
conv_type
,
schema
::
PrimitiveType
add_type
)
{
auto
meta_graph
=
std
::
make_shared
<
schema
::
MetaGraphT
>
();
meta_graph
->
name
=
"graph"
;
// conv node
CNodeTptr
convNode
;
if
(
conv_type
==
schema
::
PrimitiveType_Conv2D
)
{
convNode
=
BuildConv2D
();
}
else
{
convNode
=
BuildDepthwiseConv2D
();
}
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
convNode
));
// biasadd node
auto
biasadd_node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
biasadd_node
->
inputIndex
=
{
2
,
3
};
biasadd_node
->
outputIndex
=
{
4
};
biasadd_node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
biasadd_node
->
primitive
->
value
.
type
=
add_type
;
auto
prim2
=
new
schema
::
BiasAddT
;
biasadd_node
->
primitive
->
value
.
value
=
prim2
;
biasadd_node
->
name
=
"BiasAdd"
;
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
biasadd_node
));
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
4
};
// input 0: data
auto
input0
=
std
::
make_unique
<
schema
::
TensorT
>
();
input0
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input0
->
format
=
schema
::
Format_NHWC
;
input0
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input0
->
dims
=
{
1
,
5
,
5
,
3
};
input0
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input0
));
// input 1: weight
auto
input1
=
std
::
make_unique
<
schema
::
TensorT
>
();
input1
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input1
->
format
=
schema
::
Format_KHWC
;
input1
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input1
->
dims
=
{
8
,
3
,
3
,
3
};
input1
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input1
));
// conv output
auto
conv_output
=
std
::
make_unique
<
schema
::
TensorT
>
();
conv_output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
conv_output
->
format
=
schema
::
Format_NHWC
;
conv_output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
conv_output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
conv_output
));
// input2: bias
auto
input2
=
std
::
make_unique
<
schema
::
TensorT
>
();
input2
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input2
->
format
=
schema
::
Format_NHWC
;
input2
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input2
->
dims
=
{
1
,
5
,
5
,
8
};
input2
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input2
));
// final output
auto
output
=
std
::
make_unique
<
schema
::
TensorT
>
();
output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
output
->
format
=
schema
::
Format_NHWC
;
output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
output
));
return
meta_graph
;
}
}
// namespace
TEST_F
(
ConvBiasAddFusionTest
,
TestConvAddNode
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_Conv2D
,
schema
::
PrimitiveType_BiasAdd
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsConv2D
()
->
hasBias
,
true
);
}
MS_LOG
(
INFO
)
<<
"Passed"
;
}
TEST_F
(
ConvBiasAddFusionTest
,
TestDeptiwiseConvAddNode
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_DepthwiseConv2D
,
schema
::
PrimitiveType_Add
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
hasBias
,
true
);
}
}
TEST_F
(
ConvBiasAddFusionTest
,
TestBadCase_ConvAdd
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_DepthwiseConv2D
,
schema
::
PrimitiveType_MatMul
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
2
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
if
(
cnode
->
primitive
->
value
.
type
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
hasBias
,
false
);
}
}
}
}
// namespace mindspore
mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
0 → 100644
浏览文件 @
6e57281c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace
mindspore
{
class
ConvBNFusionTest
:
public
mindspore
::
Common
{
public:
ConvBNFusionTest
()
=
default
;
};
using
MetaGraphTptr
=
std
::
shared_ptr
<
schema
::
MetaGraphT
>
;
using
CNodeTptr
=
std
::
unique_ptr
<
schema
::
CNodeT
>
;
namespace
{
CNodeTptr
BuildConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Conv2D
;
auto
prim1
=
new
schema
::
Conv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelOut
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
CNodeTptr
BuildDepthwiseConv2D
()
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
convNode
->
inputIndex
=
{
0
,
1
,
2
};
convNode
->
outputIndex
=
{
3
};
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
auto
prim1
=
new
schema
::
DepthwiseConv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelIn
=
1
;
prim1
->
channelMultiplier
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
// caffe bn op has 3 inputs
MetaGraphTptr
BuildCaffeGraph
(
schema
::
PrimitiveType
conv_type
)
{
auto
meta_graph
=
std
::
make_shared
<
schema
::
MetaGraphT
>
();
meta_graph
->
name
=
"graph"
;
// conv node
CNodeTptr
convNode
;
if
(
conv_type
==
schema
::
PrimitiveType_Conv2D
)
{
convNode
=
BuildConv2D
();
}
else
{
convNode
=
BuildDepthwiseConv2D
();
}
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
convNode
));
// bn_node
auto
bn_node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
bn_node
->
inputIndex
=
{
2
,
3
,
4
};
bn_node
->
outputIndex
=
{
5
};
bn_node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
bn_node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_CaffeBatchNorm
;
auto
prim2
=
new
schema
::
CaffeBatchNormT
;
bn_node
->
primitive
->
value
.
value
=
prim2
;
bn_node
->
name
=
"bn"
;
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
bn_node
));
// input 0: data
auto
input0
=
std
::
make_unique
<
schema
::
TensorT
>
();
input0
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input0
->
format
=
schema
::
Format_NHWC
;
input0
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input0
->
dims
=
{
1
,
5
,
5
,
3
};
input0
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input0
));
// input 1: weight
auto
input1
=
std
::
make_unique
<
schema
::
TensorT
>
();
input1
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input1
->
format
=
schema
::
Format_KHWC
;
input1
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input1
->
dims
=
{
8
,
3
,
3
,
3
};
input1
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input1
));
// conv output
auto
conv_output
=
std
::
make_unique
<
schema
::
TensorT
>
();
conv_output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
conv_output
->
format
=
schema
::
Format_NHWC
;
conv_output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
conv_output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
conv_output
));
// caffe bn : mean
auto
input2
=
std
::
make_unique
<
schema
::
TensorT
>
();
input2
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input2
->
format
=
schema
::
Format_NHWC
;
input2
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input2
->
dims
=
{
1
,
5
,
5
,
8
};
input2
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input2
));
// caffe bn : var
auto
input3
=
std
::
make_unique
<
schema
::
TensorT
>
();
input3
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input3
->
format
=
schema
::
Format_NHWC
;
input3
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input3
->
dims
=
{
1
,
5
,
5
,
8
};
input3
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input3
));
// final bn output
auto
output
=
std
::
make_unique
<
schema
::
TensorT
>
();
output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
output
->
format
=
schema
::
Format_NHWC
;
output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
output
));
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
5
};
return
meta_graph
;
}
// tf bn op has 4 inputs
MetaGraphTptr
BuildTFGraph
(
schema
::
PrimitiveType
conv_type
)
{
auto
meta_graph
=
std
::
make_shared
<
schema
::
MetaGraphT
>
();
meta_graph
->
name
=
"graph"
;
// conv node
CNodeTptr
convNode
;
if
(
conv_type
==
schema
::
PrimitiveType_Conv2D
)
{
convNode
=
BuildConv2D
();
}
else
{
convNode
=
BuildDepthwiseConv2D
();
}
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
convNode
));
// bn_node
auto
bn_node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
bn_node
->
inputIndex
=
{
3
,
4
,
5
,
6
,
7
};
bn_node
->
outputIndex
=
{
8
};
bn_node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
bn_node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_FusedBatchNorm
;
auto
prim2
=
new
schema
::
FusedBatchNormT
;
bn_node
->
primitive
->
value
.
value
=
prim2
;
bn_node
->
name
=
"bn"
;
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
bn_node
));
// input 0: data
auto
input0
=
std
::
make_unique
<
schema
::
TensorT
>
();
input0
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input0
->
format
=
schema
::
Format_NHWC
;
input0
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input0
->
dims
=
{
1
,
5
,
5
,
3
};
input0
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input0
));
// input 1: conv_bias
auto
input11
=
std
::
make_unique
<
schema
::
TensorT
>
();
input11
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input11
->
format
=
schema
::
Format_KHWC
;
input11
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input11
->
dims
=
{
8
,
3
,
3
,
3
};
input11
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input11
));
// input 1: weight
auto
input1
=
std
::
make_unique
<
schema
::
TensorT
>
();
input1
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input1
->
format
=
schema
::
Format_KHWC
;
input1
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input1
->
dims
=
{
8
,
3
,
3
,
3
};
input1
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input1
));
// conv output
auto
conv_output
=
std
::
make_unique
<
schema
::
TensorT
>
();
conv_output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
conv_output
->
format
=
schema
::
Format_NHWC
;
conv_output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
conv_output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
conv_output
));
// tflite bn : scale
auto
input2
=
std
::
make_unique
<
schema
::
TensorT
>
();
input2
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input2
->
format
=
schema
::
Format_NHWC
;
input2
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input2
->
dims
=
{
1
,
5
,
5
,
8
};
input2
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input2
));
// tflite bn : bias
auto
input3
=
std
::
make_unique
<
schema
::
TensorT
>
();
input3
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input3
->
format
=
schema
::
Format_NHWC
;
input3
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input3
->
dims
=
{
1
,
5
,
5
,
8
};
input3
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input3
));
// tflite bn : mean
auto
input4
=
std
::
make_unique
<
schema
::
TensorT
>
();
input4
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input4
->
format
=
schema
::
Format_NHWC
;
input4
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input4
->
dims
=
{
1
,
5
,
5
,
8
};
input4
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input4
));
// tflite bn : var
auto
input5
=
std
::
make_unique
<
schema
::
TensorT
>
();
input5
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input5
->
format
=
schema
::
Format_NHWC
;
input5
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input5
->
dims
=
{
1
,
5
,
5
,
8
};
input5
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input5
));
// final output
auto
output
=
std
::
make_unique
<
schema
::
TensorT
>
();
output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
output
->
format
=
schema
::
Format_NHWC
;
output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
output
));
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
8
};
return
meta_graph
;
}
}
// namespace
TEST_F
(
ConvBNFusionTest
,
TestConvAddNode
)
{
auto
meta_graph
=
BuildCaffeGraph
(
schema
::
PrimitiveType_Conv2D
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsConv2D
()
->
hasBias
,
true
);
}
}
TEST_F
(
ConvBNFusionTest
,
TestDeptiwiseConvAddNode
)
{
auto
meta_graph
=
BuildTFGraph
(
schema
::
PrimitiveType_DepthwiseConv2D
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
hasBias
,
true
);
}
}
}
// namespace mindspore
mindspore/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
0 → 100644
浏览文件 @
6e57281c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace
mindspore
{
class
ConvScaleFusionTest
:
public
mindspore
::
Common
{
public:
ConvScaleFusionTest
()
=
default
;
};
using
MetaGraphTptr
=
std
::
shared_ptr
<
schema
::
MetaGraphT
>
;
using
CNodeTptr
=
std
::
unique_ptr
<
schema
::
CNodeT
>
;
namespace
{
// conv has 2 inputs
CNodeTptr
BuildConv2D
(
int
with_bias_flag
)
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
if
(
with_bias_flag
)
{
convNode
->
inputIndex
=
{
0
,
1
,
2
};
convNode
->
outputIndex
=
{
3
};
}
else
{
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
}
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Conv2D
;
auto
prim1
=
new
schema
::
Conv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelOut
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
// conv2d has 3 inputs
CNodeTptr
BuildDepthwiseConv2D
(
int
with_bias_flag
)
{
auto
convNode
=
std
::
make_unique
<
schema
::
CNodeT
>
();
if
(
with_bias_flag
)
{
convNode
->
inputIndex
=
{
0
,
1
,
2
};
convNode
->
outputIndex
=
{
3
};
}
else
{
convNode
->
inputIndex
=
{
0
,
1
};
convNode
->
outputIndex
=
{
2
};
}
convNode
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
convNode
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
auto
prim1
=
new
schema
::
DepthwiseConv2DT
;
prim1
->
padMode
=
schema
::
PadMode_SAME
;
prim1
->
format
=
schema
::
Format_NHWC
;
prim1
->
strideH
=
1
;
prim1
->
strideW
=
1
;
prim1
->
kernelH
=
3
;
prim1
->
kernelW
=
3
;
prim1
->
dilateH
=
1
;
prim1
->
dilateW
=
1
;
prim1
->
channelIn
=
1
;
prim1
->
channelMultiplier
=
3
;
convNode
->
primitive
->
value
.
value
=
prim1
;
convNode
->
name
=
"Conv2D"
;
return
convNode
;
}
MetaGraphTptr
BuildGraph
(
schema
::
PrimitiveType
conv_type
,
bool
conv_with_bias
)
{
auto
meta_graph
=
std
::
make_shared
<
schema
::
MetaGraphT
>
();
meta_graph
->
name
=
"graph"
;
// conv node
CNodeTptr
convNode
;
if
(
conv_type
==
schema
::
PrimitiveType_Conv2D
)
{
convNode
=
BuildConv2D
(
conv_with_bias
);
}
else
{
convNode
=
BuildDepthwiseConv2D
(
conv_with_bias
);
}
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
convNode
));
// scale_node weight bias
auto
scale_node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
if
(
conv_with_bias
)
{
scale_node
->
inputIndex
=
{
3
,
4
,
5
};
scale_node
->
outputIndex
=
{
6
};
}
else
{
scale_node
->
inputIndex
=
{
2
,
3
,
4
};
scale_node
->
outputIndex
=
{
5
};
}
scale_node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
scale_node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Scale
;
auto
prim2
=
new
schema
::
ScaleT
;
scale_node
->
primitive
->
value
.
value
=
prim2
;
scale_node
->
name
=
"scale"
;
meta_graph
->
nodes
.
emplace_back
(
std
::
move
(
scale_node
));
// input 0: data
auto
input0
=
std
::
make_unique
<
schema
::
TensorT
>
();
input0
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input0
->
format
=
schema
::
Format_NHWC
;
input0
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input0
->
dims
=
{
1
,
5
,
5
,
3
};
input0
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input0
));
// input 1: weight
auto
input1
=
std
::
make_unique
<
schema
::
TensorT
>
();
input1
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input1
->
format
=
schema
::
Format_KHWC
;
input1
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input1
->
dims
=
{
8
,
3
,
3
,
3
};
input1
->
data
.
resize
(
sizeof
(
float
)
*
8
*
3
*
3
*
3
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input1
));
if
(
conv_with_bias
)
{
// input 00: bias
auto
input00
=
std
::
make_unique
<
schema
::
TensorT
>
();
input00
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input00
->
format
=
schema
::
Format_NHWC
;
input00
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input00
->
dims
=
{
1
,
5
,
5
,
3
};
input00
->
offset
=
-
1
;
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input00
));
}
// conv output
auto
conv_output
=
std
::
make_unique
<
schema
::
TensorT
>
();
conv_output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
conv_output
->
format
=
schema
::
Format_NHWC
;
conv_output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
conv_output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
conv_output
));
// scale weight input
auto
input2
=
std
::
make_unique
<
schema
::
TensorT
>
();
input2
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input2
->
format
=
schema
::
Format_NHWC
;
input2
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input2
->
dims
=
{
1
,
5
,
5
,
8
};
input2
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input2
));
// scale bias input
auto
input3
=
std
::
make_unique
<
schema
::
TensorT
>
();
input3
->
nodeType
=
schema
::
NodeType
::
NodeType_ValueNode
;
input3
->
format
=
schema
::
Format_NHWC
;
input3
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
input3
->
dims
=
{
1
,
5
,
5
,
8
};
input3
->
data
.
resize
(
sizeof
(
float
)
*
8
*
5
*
5
);
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
input3
));
// final scale output
auto
output
=
std
::
make_unique
<
schema
::
TensorT
>
();
output
->
nodeType
=
schema
::
NodeType
::
NodeType_Parameter
;
output
->
format
=
schema
::
Format_NHWC
;
output
->
dataType
=
TypeId
::
kNumberTypeFloat32
;
output
->
dims
=
{
1
,
5
,
5
,
8
};
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
output
));
if
(
conv_with_bias
)
{
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
6
};
}
else
{
meta_graph
->
inputIndex
=
{
0
};
meta_graph
->
outputIndex
=
{
5
};
}
return
meta_graph
;
}
}
// namespace
TEST_F
(
ConvScaleFusionTest
,
TestConvScaleNode
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_Conv2D
,
true
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsConv2D
()
->
hasBias
,
true
);
}
}
TEST_F
(
ConvScaleFusionTest
,
TestDeptiwiseConvScaleNode
)
{
auto
meta_graph
=
BuildGraph
(
schema
::
PrimitiveType_DepthwiseConv2D
,
false
);
auto
func_graph
=
lite
::
ModelParser
::
Fb2Anf
(
meta_graph
.
get
());
auto
anf_transform
=
new
lite
::
AnfTransform
();
auto
new_graph
=
anf_transform
->
Transform
(
func_graph
);
ASSERT_NE
(
nullptr
,
new_graph
);
auto
new_meta_graph
=
lite
::
Export
(
new_graph
);
ASSERT_EQ
(
new_meta_graph
->
nodes
.
size
(),
1
);
for
(
auto
&
cnode
:
new_meta_graph
->
nodes
)
{
ASSERT_EQ
(
cnode
->
primitive
->
value
.
AsDepthwiseConv2D
()
->
hasBias
,
true
);
ASSERT_EQ
(
cnode
->
inputIndex
.
size
(),
3
);
}
}
}
// namespace mindspore
mindspore/lite/tools/converter/CMakeLists.txt
浏览文件 @
6e57281c
...
@@ -49,6 +49,8 @@ set(ANF_SRC
...
@@ -49,6 +49,8 @@ set(ANF_SRC
${
CCSRC_DIR
}
/pybind_api/export_flags.cc
${
CCSRC_DIR
}
/pybind_api/export_flags.cc
${
CCSRC_DIR
}
/utils/context/context_extends.cc
${
CCSRC_DIR
}
/utils/context/context_extends.cc
${
CCSRC_DIR
}
/frontend/parallel/costmodel_context.cc
${
CCSRC_DIR
}
/frontend/parallel/costmodel_context.cc
${
CCSRC_DIR
}
/backend/optimizer/common/pattern_engine.cc
${
CCSRC_DIR
}
/backend/optimizer/common/visit.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/common/graph_utils_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/common/graph_utils_extends.cc
)
)
...
@@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
...
@@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/node_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/node_pass.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/optimizer.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/optimizer.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/pass_manager.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/pass_manager.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/pattern_engine.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/gllo_utils.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/visit.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/common/utils.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_biasadd_fusion.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_biasadd_fusion.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_activation_fusion.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_activation_fusion.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_transform_fusion.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../src/gllo/fusion/conv_transform_fusion.cc
...
...
mindspore/lite/tools/converter/converter.cc
浏览文件 @
6e57281c
...
@@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
...
@@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return
nullptr
;
return
nullptr
;
}
}
// auto newG
raph = anfTransform->Transform(graph);
g
raph
=
anfTransform
->
Transform
(
graph
);
CreateQuantizer
(
graph
,
flag
);
CreateQuantizer
(
graph
,
flag
);
if
(
mQuantizer
!=
nullptr
)
{
if
(
mQuantizer
!=
nullptr
)
{
...
...
mindspore/lite/tools/converter/graphdef_transform.cc
浏览文件 @
6e57281c
...
@@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
...
@@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// }
// }
// fusion
// fusion
{
//
{
Optimizer
fusionOptimizer
;
//
Optimizer fusionOptimizer;
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
ConvBiasAddFusionPass
());
//
fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
ConvBNFusionPass
());
//
fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
ConvScaleFusionPass
());
//
fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
ConvReluFusionPass
());
//
fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
ConvRelu6FusionPass
());
//
fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
fusionOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
IsolatedNodeRemovePass
());
//
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status
=
fusionOptimizer
.
Run
(
graphDefT
);
//
status = fusionOptimizer.Run(graphDefT);
if
(
status
!=
RET_OK
&&
status
!=
RET_NO_CHANGE
)
{
//
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG
(
ERROR
)
<<
"Run fusionOptimizer graphPasses Failed"
;
//
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return
status
;
//
return status;
}
//
}
}
//
}
// weight format trans
// weight format trans
if
(
ctx
.
formatTrans
)
{
if
(
ctx
.
formatTrans
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录