Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ed76b16
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2ed76b16
编写于
7月 23, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): add graph dumper for graph partition
GitOrigin-RevId: 6dbcb67009678ce9a3c895d2115db8c429531cfb
上级
76b28408
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
543 addition
and
35 deletion
+543
-35
src/gopt/impl/subgraph_extractor.cpp
src/gopt/impl/subgraph_extractor.cpp
+243
-30
src/gopt/include/megbrain/gopt/subgraph_extractor.h
src/gopt/include/megbrain/gopt/subgraph_extractor.h
+25
-5
src/gopt/test/subgraph_extractor.cpp
src/gopt/test/subgraph_extractor.cpp
+275
-0
未找到文件。
src/gopt/impl/subgraph_extractor.cpp
浏览文件 @
2ed76b16
...
@@ -11,17 +11,214 @@
...
@@ -11,17 +11,214 @@
*/
*/
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include <atomic>
#include "megbrain/serialization/opr_shallow_copy.h"
using
namespace
mgb
;
using
namespace
mgb
;
using
namespace
cg
;
using
namespace
cg
;
using
namespace
gopt
;
using
namespace
gopt
;
/* ================== GraphPartition::InputPlaceholder =================*/
// clang-format off
MGB_DEFINE_OPR_CLASS
(
GraphPartition
::
InputPlaceholder
,
cg
::
SingleCNOperatorNodeBase
)
// {
public
:
InputPlaceholder
(
VarNode
*
src_var
,
const
TensorShape
&
infer_shp
,
std
::
unique_ptr
<
HostTensorND
>
infer_val
=
nullptr
);
static
SymbolVar
make
(
VarNode
*
src_var
,
const
TensorShape
&
infer_shp
,
std
::
unique_ptr
<
HostTensorND
>
infer_val
=
nullptr
);
size_t
input_id
()
const
{
return
m_id
;
}
private
:
void
init_output_static_infer_desc
()
override
;
void
scn_do_execute
()
override
;
void
init_output_comp_node
()
override
;
const
size_t
m_id
;
TensorShape
m_infer_shp
;
std
::
unique_ptr
<
HostTensorND
>
m_infer_val
;
static
std
::
atomic_size_t
sm_id
;
}
;
// clang-format on
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
GraphPartition
::
InputPlaceholder
);
std
::
atomic_size_t
GraphPartition
::
InputPlaceholder
::
sm_id
{
0
};
GraphPartition
::
InputPlaceholder
::
InputPlaceholder
(
VarNode
*
src_var
,
const
TensorShape
&
infer_shp
,
std
::
unique_ptr
<
HostTensorND
>
infer_val
)
:
Super
(
src_var
->
owner_graph
(),
{},
{},
{}),
m_id
{
sm_id
.
fetch_add
(
1
,
std
::
memory_order_relaxed
)},
m_infer_shp
{
infer_shp
},
m_infer_val
{
std
::
move
(
infer_val
)}
{
name
(
ssprintf
(
"InputPlaceholder@%zu"
,
m_id
));
add_equivalence_component
<
ScalarHash
<
DTypeEnum
>>
(
src_var
->
dtype
().
enumv
());
add_equivalence_component
<
ScalarHash
<
size_t
>>
(
m_id
);
add_output
(
None
)
->
dtype
(
src_var
->
dtype
());
}
void
GraphPartition
::
InputPlaceholder
::
init_output_comp_node
()
{
output
(
0
)
->
comp_node
(
CompNode
::
default_cpu
());
}
void
GraphPartition
::
InputPlaceholder
::
scn_do_execute
()
{
mgb_throw
(
InternalError
,
"InputPlaceholder opr can not be executed"
);
}
void
GraphPartition
::
InputPlaceholder
::
init_output_static_infer_desc
()
{
using
namespace
cg
::
static_infer
;
auto
&&
mgr
=
owner_graph
()
->
static_infer_manager
();
if
(
m_infer_shp
.
ndim
==
0
)
{
auto
infer_shape
=
[](
TensorShape
&
,
const
InpVal
&
)
{
return
false
;
};
mgr
.
register_shape_infer
(
output
(
0
),
{
SourceType
::
MUTABLE
,
{},
infer_shape
});
}
else
{
mgr
.
register_shape_infer
(
output
(
0
),
ShapeInferDesc
::
make_const
(
m_infer_shp
));
}
if
(
m_infer_val
==
nullptr
)
{
auto
infer_value
=
[](
DeviceTensorND
&
,
const
InpVal
&
)
{
return
false
;
};
mgr
.
register_value_infer
(
output
(
0
),
{
SourceType
::
MUTABLE
,
{},
infer_value
});
}
else
{
auto
infer_value
=
[
this
](
DeviceTensorND
&
dest
,
const
InpVal
&
)
{
dest
.
copy_from
(
*
m_infer_val
).
sync
();
return
true
;
};
mgr
.
register_value_infer
(
output
(
0
),
{
SourceType
::
CONSTANT
,
{},
infer_value
});
}
}
SymbolVar
GraphPartition
::
InputPlaceholder
::
make
(
VarNode
*
src_var
,
const
TensorShape
&
infer_shp
,
std
::
unique_ptr
<
HostTensorND
>
infer_val
)
{
return
src_var
->
owner_graph
()
->
insert_opr
(
std
::
make_unique
<
InputPlaceholder
>
(
src_var
,
infer_shp
,
std
::
move
(
infer_val
)))
->
output
(
0
);
}
/* ================== GraphPartition =================*/
#if MGB_ENABLE_JSON
std
::
shared_ptr
<
json
::
Value
>
GraphPartition
::
to_json
()
const
{
auto
replaced_outputs
=
std
::
get
<
1
>
(
replace_graph_by_placeholder
());
ThinHashSet
<
VarNode
*>
all_var_node
;
ThinHashSet
<
OperatorNodeBase
*>
all_opr_node
;
auto
comp_seq
=
json
::
Array
::
make
();
auto
cb
=
[
&
](
OperatorNodeBase
*
opr
)
{
comp_seq
->
add
(
json
::
String
::
make
(
opr
->
id_str
()));
for
(
const
auto
&
i
:
opr
->
input
())
{
if
(
all_var_node
.
count
(
i
)
==
0
)
{
all_var_node
.
insert
(
i
);
}
}
all_opr_node
.
insert
(
opr
);
for
(
const
auto
&
o
:
opr
->
output
())
{
all_var_node
.
insert
(
o
);
}
};
cg
::
DepOprIter
iter
{
cb
};
for
(
const
auto
&
o
:
replaced_outputs
)
iter
.
add
(
o
->
owner_opr
());
auto
dump_node_coll
=
[](
auto
&&
collection
)
{
auto
objptr
=
json
::
Object
::
make
();
auto
&&
obj
=
*
objptr
;
for
(
auto
&&
i
:
collection
)
obj
[
i
->
id_str
()]
=
i
->
to_json
();
return
objptr
;
};
return
json
::
Object
::
make
({{
"operator"
,
dump_node_coll
(
all_opr_node
)},
{
"var"
,
dump_node_coll
(
all_var_node
)},
{
"comp_seq"
,
comp_seq
}});
}
#endif
std
::
pair
<
VarNodeArray
,
VarNodeArray
>
GraphPartition
::
replace_graph_by_placeholder
()
const
{
ThinHashMap
<
VarNode
*
,
VarNode
*>
old2new
;
auto
graph_partition_copy_opr_shallow
=
[](
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
inps
)
{
OperatorNodeConfig
config
=
opr
->
config
();
return
serialization
::
copy_opr_shallow
(
*
opr
,
inps
,
config
)
->
output
(
0
);
};
OperatorNodeSet
input_opr_set
;
for
(
const
auto
&
i
:
m_inputs
)
input_opr_set
.
insert
(
i
->
owner_opr
());
VarNodeArray
placeholders
;
VarNodeArray
replaced_outputs
;
VarNodeArray
new_i
;
auto
cb
=
[
&
](
OperatorNodeBase
*
opr
)
{
for
(
const
auto
&
o
:
opr
->
output
())
{
if
(
o
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
)
||
(
input_opr_set
.
count
(
opr
)
&&
!
m_inputs
.
count
(
o
)))
{
continue
;
}
VarNode
*
new_o
;
if
(
m_inputs
.
count
(
o
))
{
auto
&&
mgr
=
opr
->
owner_graph
()
->
static_infer_manager
();
const
TensorShape
*
shp_ptr
=
nullptr
;
if
(
cg
::
is_static_var_shape
(
o
))
{
shp_ptr
=
mgr
.
infer_shape_fallible
(
o
);
}
TensorShape
infer_shp
;
if
(
shp_ptr
)
infer_shp
=
*
shp_ptr
;
std
::
unique_ptr
<
HostTensorND
>
hval
=
nullptr
;
const
DeviceTensorND
*
dval_ptr
=
nullptr
;
if
(
cg
::
is_static_var_value
(
o
))
{
dval_ptr
=
mgr
.
infer_value_fallible
(
o
);
}
if
(
dval_ptr
)
{
hval
.
reset
(
new
HostTensorND
(
CompNode
::
default_cpu
(),
dval_ptr
->
dtype
()));
hval
->
resize
(
dval_ptr
->
shape
()).
copy_from
(
*
dval_ptr
).
sync
();
}
new_o
=
InputPlaceholder
::
make
(
o
,
infer_shp
,
std
::
move
(
hval
))
.
node
();
placeholders
.
push_back
(
new_o
);
}
else
{
new_i
.
clear
();
for
(
const
auto
&
i
:
opr
->
input
())
{
new_i
.
push_back
(
old2new
.
at
(
i
));
}
new_o
=
graph_partition_copy_opr_shallow
(
o
->
owner_opr
(),
new_i
);
}
old2new
[
o
]
=
new_o
;
}
};
cg
::
DepOprIter
iter
{
cb
};
for
(
auto
&&
i
:
m_inputs
)
{
for
(
auto
&&
j
:
i
->
owner_opr
()
->
input
())
{
if
(
!
input_opr_set
.
count
(
j
->
owner_opr
())
&&
!
m_opr_set
.
count
(
j
->
owner_opr
()))
{
iter
.
set_visited
(
j
->
owner_opr
());
}
}
}
for
(
auto
&&
o
:
m_outputs
)
iter
.
add
(
o
->
owner_opr
());
for
(
auto
&&
o
:
m_outputs
)
{
replaced_outputs
.
push_back
(
old2new
.
at
(
o
));
}
return
std
::
make_pair
(
placeholders
,
replaced_outputs
);
}
/* ================== SubGraphExtractor =================*/
/* ================== SubGraphExtractor =================*/
std
::
vector
<
InternalGraph
>
SubGraphExtractor
::
extract
(
std
::
vector
<
GraphPartition
>
SubGraphExtractor
::
extract
(
const
SymbolVarArray
&
endpoint_vars
)
const
{
const
SymbolVarArray
&
endpoint_vars
)
const
{
ThinHashMap
<
OperatorNodeBase
*
,
std
::
pair
<
OperatorNodeBase
*
,
int
>>
parent
;
ThinHashMap
<
OperatorNodeBase
*
,
std
::
pair
<
OperatorNodeBase
*
,
int
>>
parent
;
thin_function
<
OperatorNodeBase
*
(
OperatorNodeBase
*
)
>
union_find
;
thin_function
<
OperatorNodeBase
*
(
OperatorNodeBase
*
)
>
union_find
;
auto
union_find
=
[
&
parent
,
&
union_find
](
OperatorNodeBase
*
o
)
{
union_find
=
[
&
parent
,
&
union_find
](
OperatorNodeBase
*
o
)
{
if
(
parent
[
o
].
first
==
o
)
if
(
parent
[
o
].
first
==
o
)
return
o
;
return
o
;
else
{
else
{
...
@@ -34,7 +231,7 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
...
@@ -34,7 +231,7 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
OperatorNodeBase
*
y
)
{
OperatorNodeBase
*
y
)
{
auto
root_x
=
union_find
(
x
),
root_y
=
union_find
(
y
);
auto
root_x
=
union_find
(
x
),
root_y
=
union_find
(
y
);
if
(
root_x
!=
root_y
)
{
if
(
root_x
!=
root_y
)
{
OperatorNodeBase
*
large
,
small
;
OperatorNodeBase
*
large
,
*
small
;
if
(
parent
[
root_x
].
second
<
parent
[
root_y
].
second
)
{
if
(
parent
[
root_x
].
second
<
parent
[
root_y
].
second
)
{
small
=
root_x
,
large
=
root_y
;
small
=
root_x
,
large
=
root_y
;
}
else
{
}
else
{
...
@@ -42,25 +239,23 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
...
@@ -42,25 +239,23 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
}
}
parent
[
small
].
first
=
large
;
parent
[
small
].
first
=
large
;
if
(
parent
[
large
].
second
==
parent
[
small
].
second
)
{
if
(
parent
[
large
].
second
==
parent
[
small
].
second
)
{
paren
d
[
large
].
second
+=
1
;
paren
t
[
large
].
second
+=
1
;
}
}
}
}
};
};
std
::
vector
<
OperatorNodeBase
*>
topo
;
std
::
vector
<
OperatorNodeBase
*>
topo
;
auto
cb
=
[
&
topo
](
OperatorNodeBase
*
opr
)
{
auto
cb
=
[
this
,
&
parent
,
&
union_merge
,
&
topo
](
OperatorNodeBase
*
opr
)
{
topo
.
push_back
(
opr
);
topo
.
push_back
(
opr
);
if
(
opr_list
.
count
(
opr
->
dyn_typeinfo
())
==
0
)
if
(
m_
opr_list
.
count
(
opr
->
dyn_typeinfo
())
==
0
)
return
;
return
;
auto
find
=
parent
.
find
(
opr
);
auto
find
=
parent
.
find
(
opr
);
if
(
find
==
parent
.
end
())
{
if
(
find
==
parent
.
end
())
{
auto
insert
=
parent
.
insert
(
std
::
make_pair
(
opr
,
std
::
make_pair
(
opr
,
0
)));
parent
.
insert
(
std
::
make_pair
(
opr
,
std
::
make_pair
(
opr
,
0
)));
find
=
insert
.
first
;
}
}
for
(
auto
&&
i
:
opr
->
input
())
{
for
(
auto
&&
i
:
opr
->
input
())
{
auto
&&
o
=
i
->
owner_opr
();
auto
&&
o
=
i
->
owner_opr
();
if
(
opr_list
.
count
(
o
->
dyn_typeinfo
())
==
0
)
if
(
m_
opr_list
.
count
(
o
->
dyn_typeinfo
())
==
0
)
continue
;
continue
;
union_merge
(
opr
,
o
);
union_merge
(
opr
,
o
);
}
}
...
@@ -69,33 +264,51 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
...
@@ -69,33 +264,51 @@ std::vector<InternalGraph> SubGraphExtractor::extract(
for
(
const
auto
&
v
:
endpoint_vars
)
for
(
const
auto
&
v
:
endpoint_vars
)
iter
.
add
(
v
.
node
()
->
owner_opr
());
iter
.
add
(
v
.
node
()
->
owner_opr
());
std
::
vector
<
InternalGraph
>
partitions
;
std
::
vector
<
GraphPartition
>
partitions
;
ThinHashMap
<
OperatorNodeBase
*
,
InternalGraph
*>
roots
;
partitions
.
reserve
(
topo
.
size
());
ThinHashMap
<
OperatorNodeBase
*
,
GraphPartition
*>
roots
;
for
(
const
auto
&
opr
:
reverse_adaptor
(
topo
))
{
for
(
const
auto
&
opr
:
reverse_adaptor
(
topo
))
{
auto
root
=
union_find
(
opr
);
if
(
m_opr_list
.
count
(
opr
->
dyn_typeinfo
())
==
0
)
{
auto
find
=
roots
.
find
(
root
);
for
(
const
auto
&
i
:
opr
->
input
())
{
InternalGraph
*
internal_graph
=
nullptr
;
if
(
m_opr_list
.
count
(
i
->
owner_opr
()
->
dyn_typeinfo
()))
{
if
(
find
==
roots
.
end
())
{
auto
root
=
union_find
(
i
->
owner_opr
());
partitions
.
emplace_back
(
InternalGraph
{});
GraphPartition
*
partition
;
auto
insert
=
auto
find
=
roots
.
find
(
root
);
roots
.
insert
(
std
::
make_pair
(
root
,
&
partitions
.
back
()));
if
(
find
!=
roots
.
end
())
{
internal_graph
=
insert
.
first
->
second
;
partition
=
find
->
second
;
internal_graph
->
m_outputs
.
insert
(
opr
->
output
(
0
));
partition
->
output
().
insert
(
i
);
}
}
}
}
else
{
}
else
{
internal_graph
=
find
->
second
;
auto
root
=
union_find
(
opr
);
auto
erase
=
internal_graph
->
m_inputs
.
erase
(
opr
->
output
(
0
));
auto
find
=
roots
.
find
(
root
);
if
(
erase
>
0
)
{
GraphPartition
*
partition
=
nullptr
;
internal_graph
->
m_internals
.
insert
(
opr
->
output
(
0
));
if
(
find
==
roots
.
end
())
{
partitions
.
emplace_back
(
GraphPartition
{});
auto
insert
=
roots
.
insert
(
std
::
make_pair
(
root
,
&
partitions
.
back
()));
partition
=
insert
.
first
->
second
;
for
(
auto
&&
o
:
opr
->
output
())
{
if
(
!
o
->
contain_flag
(
cg
::
VarNode
::
Flag
::
VOLATILE_CONTENT
))
partition
->
output
().
insert
(
o
);
}
}
else
{
}
else
{
internal_graph
->
m_outputs
.
insert
(
opr
->
output
(
0
));
partition
=
find
->
second
;
for
(
auto
&&
o
:
opr
->
output
())
{
if
(
!
o
->
contain_flag
(
cg
::
VarNode
::
Flag
::
VOLATILE_CONTENT
))
{
auto
erase
=
partition
->
input
().
erase
(
o
);
if
(
erase
==
0
)
partition
->
output
().
insert
(
o
);
}
}
}
}
partition
->
opr_set
().
insert
(
opr
);
for
(
const
auto
&
i
:
opr
->
input
())
partition
->
input
().
insert
(
i
);
}
}
for
(
const
auto
&
i
:
opr
->
input
())
internal_graph
->
m_inputs
.
insert
(
i
);
}
}
return
partitions
;
return
partitions
;
}
}
/* ============= SubGraphExtractor =================*/
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
src/gopt/include/megbrain/gopt/subgraph_extractor.h
浏览文件 @
2ed76b16
...
@@ -16,17 +16,37 @@
...
@@ -16,17 +16,37 @@
namespace
mgb
{
namespace
mgb
{
namespace
gopt
{
namespace
gopt
{
struct
InternalGraph
{
class
GraphPartition
{
ThinHashSet
<
VarNode
*>
m_internals
;
public:
ThinHashSet
<
VarNode
*>
m_inputs
;
using
VarNodeSet
=
ThinHashSet
<
VarNode
*>
;
ThinHashSet
<
VarNode
*>
m_outputs
;
using
OperatorNodeSet
=
ThinHashSet
<
cg
::
OperatorNodeBase
*>
;
class
InputPlaceholder
;
GraphPartition
()
=
default
;
#if MGB_ENABLE_JSON
std
::
shared_ptr
<
json
::
Value
>
to_json
()
const
;
#endif
const
OperatorNodeSet
&
opr_set
()
const
{
return
m_opr_set
;
}
const
VarNodeSet
&
input
()
const
{
return
m_inputs
;
}
const
VarNodeSet
&
output
()
const
{
return
m_outputs
;
}
OperatorNodeSet
&
opr_set
()
{
return
m_opr_set
;
}
VarNodeSet
&
input
()
{
return
m_inputs
;
}
VarNodeSet
&
output
()
{
return
m_outputs
;
}
private:
OperatorNodeSet
m_opr_set
;
VarNodeSet
m_inputs
;
VarNodeSet
m_outputs
;
std
::
pair
<
VarNodeArray
,
VarNodeArray
>
replace_graph_by_placeholder
()
const
;
};
};
class
SubGraphExtractor
{
class
SubGraphExtractor
{
public:
public:
using
OprList
=
ThinHashSet
<
Typeinfo
*>
;
using
OprList
=
ThinHashSet
<
Typeinfo
*>
;
SubGraphExtractor
(
OprList
opr_list
)
:
m_opr_list
{
opr_list
}
{};
SubGraphExtractor
(
OprList
opr_list
)
:
m_opr_list
{
opr_list
}
{};
std
::
vector
<
InternalGraph
>
extract
(
std
::
vector
<
GraphPartition
>
extract
(
const
SymbolVarArray
&
endpoint_vars
)
const
;
const
SymbolVarArray
&
endpoint_vars
)
const
;
private:
private:
...
...
src/gopt/test/subgraph_extractor.cpp
0 → 100644
浏览文件 @
2ed76b16
/**
* \file src/gopt/test/subgraph_extractor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
using
namespace
mgb
;
using
namespace
gopt
;
using
namespace
serialization
;
namespace
{
// clang-format off
MGB_DEFINE_OPR_CLASS
(
MultipleInputOutput
,
cg
::
SingleCNOperatorNodeBase
)
// {
public:
MultipleInputOutput
(
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
);
static
SymbolVarArray
make
(
const
SymbolVarArray
&
inputs
,
const
OperatorNodeConfig
&
config
=
{});
private:
void
scn_do_execute
()
override
{
}
void
init_output_static_infer_desc
()
override
{
}
};
// clang-format on
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
MultipleInputOutput
);
MultipleInputOutput
::
MultipleInputOutput
(
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
:
Super
(
inputs
[
0
]
->
owner_graph
(),
config
,
"multiple_input_output"
,
inputs
)
{
for
(
auto
&&
i
:
inputs
)
add_input
({
i
});
if
(
inputs
.
size
()
==
1
)
{
add_output
(
None
);
}
else
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
add_output
(
ssprintf
(
"o%zu"
,
i
));
}
cg
::
add_workspace_output
(
this
);
}
SymbolVarArray
MultipleInputOutput
::
make
(
const
SymbolVarArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
auto
src
=
cg
::
to_var_node_array
(
inputs
);
auto
multiple_io
=
std
::
make_unique
<
MultipleInputOutput
>
(
src
,
config
);
auto
ret
=
cg
::
to_symbol_var_array
(
src
[
0
]
->
owner_graph
()
->
insert_opr
(
std
::
move
(
multiple_io
))
->
output
());
ret
.
pop_back
();
return
ret
;
}
}
TEST
(
TestSubGraphExtractor
,
MultipleOutputs
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
)).
rename
(
name
);
};
graph
->
options
().
graph_opt_level
=
0
;
auto
x
=
mkvar
(
"x"
,
{
8
,
8
,
8
,
8
}),
w1
=
mkcvar
(
"w1"
,
{
4
,
8
,
3
,
3
});
auto
y
=
mkvar
(
"y"
,
{
1
,
8
,
1
,
1
});
auto
add
=
x
+
y
;
opr
::
Convolution
::
Param
param
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
c1
=
opr
::
Convolution
::
make
(
add
,
w1
,
param
);
auto
w2
=
mkcvar
(
"w2"
,
{
8
,
4
,
3
,
3
});
auto
c2
=
opr
::
ConvolutionBackwardData
::
make
(
w2
,
add
,
param
,
{},
{});
auto
sym_var_arr
=
MultipleInputOutput
::
make
({
c1
,
c2
});
auto
z
=
sym_var_arr
[
1
];
z
=
z
+
(
-
128
);
using
OprList
=
SubGraphExtractor
::
OprList
;
static
const
OprList
opr_list
=
{
opr
::
ConvolutionForward
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
MultipleInputOutput
::
typeinfo
(),
};
SubGraphExtractor
extractor
(
opr_list
);
auto
partitions
=
extractor
.
extract
({
z
});
ASSERT_EQ
(
partitions
.
size
(),
1u
);
// outputs: sym_var_arr[0], z, add
ASSERT_EQ
(
partitions
[
0
].
output
().
size
(),
3u
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
add
.
node
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
z
.
node
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
sym_var_arr
[
0
].
node
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
sym_var_arr
[
1
].
node
())
==
0
);
// inputs: x, y, w1, c2, (-128)
ASSERT_EQ
(
partitions
[
0
].
input
().
size
(),
5u
);
ASSERT_TRUE
(
partitions
[
0
].
input
().
count
(
x
.
node
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
input
().
count
(
c2
.
node
())
>
0
);
// opr: (x + y) conv1 multi_io, (z - 128)
ASSERT_EQ
(
partitions
[
0
].
opr_set
().
size
(),
4u
);
ASSERT_TRUE
(
partitions
[
0
].
opr_set
().
count
(
add
.
node
()
->
owner_opr
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
opr_set
().
count
(
c1
.
node
()
->
owner_opr
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
opr_set
().
count
(
sym_var_arr
[
0
].
node
()
->
owner_opr
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
opr_set
().
count
(
z
.
node
()
->
owner_opr
())
>
0
);
}
TEST
(
TestSubGraphExtractor
,
MultipleReaders
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
)).
rename
(
name
);
};
graph
->
options
().
graph_opt_level
=
0
;
auto
x
=
mkvar
(
"x"
,
{
8
,
8
,
8
,
8
}),
w1
=
mkcvar
(
"w1"
,
{
4
,
8
,
3
,
3
});
auto
y
=
mkvar
(
"y"
,
{
1
,
8
,
1
,
1
});
auto
add
=
x
+
y
;
opr
::
Convolution
::
Param
param
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
c1
=
opr
::
Convolution
::
make
(
add
,
w1
,
param
);
auto
w2
=
mkcvar
(
"w2"
,
{
8
,
4
,
3
,
3
});
auto
c2
=
opr
::
ConvolutionBackwardData
::
make
(
w2
,
add
,
param
,
{},
{});
auto
z
=
c1
+
c2
;
using
OprList
=
SubGraphExtractor
::
OprList
;
static
const
OprList
opr_list
=
{
opr
::
ConvolutionForward
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
};
SubGraphExtractor
extractor
(
opr_list
);
auto
partitions
=
extractor
.
extract
({
z
});
ASSERT_EQ
(
partitions
.
size
(),
1u
);
ASSERT_EQ
(
partitions
[
0
].
output
().
size
(),
2u
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
add
.
node
())
>
0
);
ASSERT_TRUE
(
partitions
[
0
].
output
().
count
(
z
.
node
())
>
0
);
ASSERT_EQ
(
partitions
[
0
].
input
().
size
(),
4u
);
ASSERT_TRUE
(
partitions
[
0
].
input
().
count
(
x
.
node
())
>
0
);
partitions
[
0
].
to_json
()
->
writeto_fpath
(
output_file
(
"TestSubGraphExtractor.MultipleReaders.json"
));
}
TEST
(
TestSubGraphExtractor
,
Complicated
)
{
const
size_t
N
=
16
,
C
=
3
,
H
=
768
,
W
=
1280
;
HostTensorGenerator
<
dtype
::
Uint8
>
gen
;
auto
graph
=
ComputingGraph
::
make
();
/* h2d
|
v
astype(f32)
|
add(-128)
|
v
astype(q8)
|
v
conv1
|
v
astype(u4)
|
/ \
conv2 conv3 -> astype(q32) -> output
\ /
qadd
|
v
astype(q8)
/ \
deconv conv4
\ /
concat -> output */
auto
h2d
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
({
N
,
C
,
H
,
W
}));
auto
data
=
opr
::
TypeCvt
::
make
(
h2d
,
dtype
::
Float32
());
auto
sub_128
=
data
+
(
-
128
);
auto
x
=
opr
::
TypeCvt
::
make
(
sub_128
,
dtype
::
QuantizedS8
(
1.
f
));
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
)).
rename
(
name
),
dtype
);
};
auto
w1
=
mkcvar
(
"w1"
,
{
16
,
3
,
3
,
3
},
dtype
::
QuantizedS8
(
1.
f
));
auto
b1
=
mkcvar
(
"b1"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
1.
f
));
opr
::
ConvBias
::
Param
param
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
conv1
=
opr
::
ConvBias
::
make
(
x
,
w1
,
b1
,
param
,
{},
OperatorNodeConfig
(
dtype
::
QuantizedS8
(
1.
f
)));
conv1
=
opr
::
TypeCvt
::
make
(
conv1
,
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
)));
auto
w2
=
mkcvar
(
"w2"
,
{
16
,
16
,
3
,
3
},
dtype
::
QuantizedS4
(
1.
f
));
auto
b2
=
mkcvar
(
"b2"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
1.
f
));
auto
conv2
=
opr
::
ConvBias
::
make
(
conv1
,
w2
,
b2
,
param
,
{},
OperatorNodeConfig
(
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
))));
param
.
pad_h
=
param
.
pad_w
=
0
;
auto
w3
=
mkcvar
(
"w3"
,
{
16
,
16
,
1
,
1
},
dtype
::
QuantizedS4
(
1.
f
));
auto
b3
=
mkcvar
(
"b3"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
1.
f
));
auto
conv3
=
opr
::
ConvBias
::
make
(
conv1
,
w3
,
b3
,
param
,
{},
OperatorNodeConfig
(
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
))));
auto
conv3f
=
opr
::
TypeCvt
::
make
(
conv3
,
dtype
::
Float32
());
auto
qadd
=
opr
::
ElemwiseMultiType
::
make
(
{
conv2
,
conv3
},
{
opr
::
ElemwiseMultiType
::
Mode
::
QADD
},
OperatorNodeConfig
(
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
))));
auto
q8
=
opr
::
TypeCvt
::
make
(
qadd
,
dtype
::
QuantizedS8
(
1.
f
));
auto
w4
=
mkcvar
(
"w4"
,
{
16
,
16
,
3
,
3
},
dtype
::
QuantizedS8
(
1.
f
));
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
conv4
=
opr
::
ConvBiasForward
::
make
(
q8
,
w4
,
param
,
{},
OperatorNodeConfig
(
dtype
::
QuantizedS8
(
1.
f
)));
conv4
=
opr
::
TypeCvt
::
make
(
conv4
,
dtype
::
Float32
());
opr
::
Convolution
::
Param
conv_param
;
conv_param
.
stride_h
=
param
.
stride_w
=
1
;
conv_param
.
pad_h
=
param
.
pad_w
=
0
;
auto
w5
=
mkcvar
(
"w4"
,
{
16
,
16
,
1
,
1
},
dtype
::
QuantizedS8
(
1.
f
));
auto
deconv
=
opr
::
ConvolutionBackwardData
::
make
(
w5
,
q8
,
conv_param
,
{},
OperatorNodeConfig
(
dtype
::
QuantizedS8
(
1.
f
)));
deconv
=
opr
::
TypeCvt
::
make
(
deconv
,
dtype
::
Float32
());
auto
z
=
opr
::
Concat
::
make
({
conv4
,
deconv
},
1
);
using
OprList
=
SubGraphExtractor
::
OprList
;
static
const
OprList
opr_list
=
{
opr
::
ConvBiasForward
::
typeinfo
(),
opr
::
ConvolutionForward
::
typeinfo
(),
opr
::
ConvolutionBackwardData
::
typeinfo
(),
opr
::
ElemwiseMultiType
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
opr
::
PoolingForward
::
typeinfo
(),
opr
::
WarpPerspectiveForward
::
typeinfo
(),
};
SubGraphExtractor
extractor
(
opr_list
);
auto
partitions
=
extractor
.
extract
({
conv3f
.
node
(),
z
.
node
()});
ASSERT_EQ
(
partitions
.
size
(),
1u
);
const
char
*
prefix
=
"TestSubGraphExtractor.Complicated"
;
partitions
[
0
].
to_json
()
->
writeto_fpath
(
output_file
(
ssprintf
(
"%s.json"
,
prefix
).
c_str
()));
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录