Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c67c4b7d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c67c4b7d
编写于
9月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): add layout transform pass
GitOrigin-RevId: 9ed5f5782eafc8ac82ea2397075b9db9042020e4
上级
2ec7c167
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
472 addition
and
0 deletion
+472
-0
src/gopt/impl/layout_transform_pass.cpp
src/gopt/impl/layout_transform_pass.cpp
+169
-0
src/gopt/test/layout_transform_pass.cpp
src/gopt/test/layout_transform_pass.cpp
+303
-0
未找到文件。
src/gopt/impl/layout_transform_pass.cpp
0 → 100644
浏览文件 @
c67c4b7d
/**
* \file src/gopt/impl/layout_transform_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/serialization/sereg.h"
using
namespace
mgb
;
using
namespace
gopt
;
using
namespace
cg
;
/* =================== LayoutTransformPass ======================*/
void
LayoutTransformPass
::
apply
(
OptState
&
opt
)
const
{
opt
.
set_var_replace_check_flag
(
VarReplaceCheckFlag
::
CHECK_ALL
^
VarReplaceCheckFlag
::
CHECK_SHAPE
);
SubGraphExtractor
extractor
(
m_ctx
->
opr_list
());
auto
partitions
=
extractor
.
extract
(
opt
.
graph
().
endpoint_vars
());
using
Solution
=
SolverBase
::
Solution
;
Solution
solution
;
ThinHashSet
<
VarNode
*>
endpoint_vars
;
for
(
auto
&&
partition
:
partitions
)
{
if
(
solution
.
empty
())
{
solution
=
m_solver
->
solve
(
Problem
(
partition
,
*
m_ctx
));
}
else
{
auto
s
=
m_solver
->
solve
(
Problem
(
partition
,
*
m_ctx
));
for
(
auto
&&
kv
:
s
)
solution
.
insert
({
kv
.
first
,
kv
.
second
});
}
for
(
auto
&&
o
:
partition
.
output
())
{
endpoint_vars
.
insert
(
o
);
}
}
auto
&&
opr_configs
=
m_ctx
->
opr_configs
();
auto
&&
base_fmt
=
m_ctx
->
attribute
().
base_tensor_formats
;
auto
&&
reformat_attribute
=
m_ctx
->
attribute
().
reformat_attribute
;
ThinHashMap
<
VarNode
*
,
TensorFormats
>
var2fmts
;
static
ThinHashSet
<
Typeinfo
*>
format_aware_oprs
=
{
#define cb(_Opr) opr::_Opr::typeinfo(),
FOREACH_FORMAT_AWARE_OPR
(
cb
)
#undef cb
};
auto
rewriter
=
opt
.
graph
().
make_rewriter
();
auto
on_opr
=
[
this
,
&
opr_configs
,
&
base_fmt
,
&
reformat_attribute
,
&
rewriter
,
&
solution
,
&
var2fmts
,
&
endpoint_vars
](
OperatorNodeBase
*
opr
)
{
auto
it
=
solution
.
find
(
opr
);
if
(
it
!=
solution
.
end
())
{
auto
opr_fmt
=
it
->
second
;
auto
find
=
opr_configs
.
find
(
opr
->
dyn_typeinfo
());
Maybe
<
OprTensorFormatsConfiguration
>
fmtcfg
=
None
;
if
(
find
!=
opr_configs
.
end
())
{
fmtcfg
=
(
*
find
->
second
.
at
(
opr_fmt
))(
opr
);
}
VarNodeArray
new_inp
;
size_t
nr_inps
=
opr
->
input
().
size
();
TensorFormats
out_fmt
;
if
(
fmtcfg
.
valid
())
{
nr_inps
=
std
::
min
(
fmtcfg
.
val
().
input_tensor_formats
.
size
(),
nr_inps
);
out_fmt
=
fmtcfg
.
val
().
output_tensor_formats
[
0
];
}
else
{
out_fmt
=
opr_format_to_tensor_formats
(
opr_fmt
);
}
new_inp
.
resize
(
nr_inps
);
for
(
size_t
i
=
0
;
i
<
nr_inps
;
++
i
)
{
auto
&&
var
=
opr
->
input
(
i
);
auto
&&
new_var
=
rewriter
.
get_var
(
var
);
auto
find
=
var2fmts
.
find
(
new_var
);
TensorFormats
from
;
if
(
find
==
var2fmts
.
end
())
{
from
=
base_fmt
;
}
else
{
from
=
find
->
second
;
}
auto
to
=
fmtcfg
.
valid
()
?
fmtcfg
.
val
().
input_tensor_formats
[
i
]
:
opr_format_to_tensor_formats
(
opr_fmt
);
bool
is_parameter
=
fmtcfg
.
valid
()
&&
fmtcfg
.
val
().
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
;
ReformatManager
::
ReformatImpl
reformat
;
ReformatManager
::
ReformatKey
key
{
from
,
to
,
reformat_attribute
,
var
->
dtype
().
enumv
(),
var
->
dtype
().
enumv
()};
if
(
is_parameter
)
{
auto
aligned_desc
=
make_aligned_desc
(
base_fmt
,
out_fmt
);
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_weight
(
var
,
key
,
aligned_desc
);
}
else
{
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_featrue
(
var
,
base_fmt
,
key
);
}
if
(
from
!=
to
&&
!
new_var
->
shape
().
is_scalar
())
new_var
=
reformat
({
new_var
});
new_inp
[
i
]
=
new_var
;
}
VarNode
*
new_out
;
if
(
format_aware_oprs
.
count
(
opr
->
dyn_typeinfo
())
>
0
)
{
new_out
=
intl
::
modify_opr_format
(
opr_fmt
,
new_inp
,
opr
);
}
else
{
new_out
=
serialization
::
copy_opr_shallow
(
*
opr
,
new_inp
,
opr
->
config
())
->
output
(
0
);
}
if
(
endpoint_vars
.
count
(
opr
->
output
(
0
))
&&
out_fmt
!=
base_fmt
)
{
ReformatManager
::
ReformatKey
key
{
out_fmt
,
base_fmt
,
reformat_attribute
,
opr
->
output
(
0
)
->
dtype
().
enumv
(),
opr
->
output
(
0
)
->
dtype
().
enumv
()};
auto
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_featrue
(
opr
->
output
(
0
),
base_fmt
,
key
);
new_out
=
reformat
({
new_out
});
var2fmts
[
new_out
]
=
base_fmt
;
}
else
{
var2fmts
[
new_out
]
=
out_fmt
;
}
auto
&&
out0
=
opr
->
output
(),
&&
out1
=
new_out
->
owner_opr
()
->
output
();
mgb_assert
(
opr
->
usable_output
().
size
()
==
new_out
->
owner_opr
()
->
usable_output
().
size
(),
"bad opr replace: src=%s{%s} dst=%s{%s}, "
"src.size=%zu "
"dst.size=%zu"
,
opr
->
cname
(),
opr
->
dyn_typeinfo
()
->
name
,
new_out
->
owner_opr
()
->
cname
(),
new_out
->
owner_opr
()
->
dyn_typeinfo
()
->
name
,
out0
.
size
(),
out1
.
size
());
for
(
size_t
i
=
0
;
i
<
out0
.
size
();
++
i
)
{
if
(
!
out0
[
i
]
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
))
{
mgb_assert
(
!
out1
[
i
]
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
));
auto
src
=
out0
[
i
];
auto
dst
=
out1
[
i
];
rewriter
.
replace_var
(
src
,
dst
,
mgb_cstr_log
(
ssprintf
(
"replace opr(%s) to new opr "
"format(%s)"
,
opr
->
cname
(),
opr_format_to_string
(
opr_fmt
))
.
c_str
()));
}
}
}
else
{
auto
new_opr
=
rewriter
.
auto_replace_outputs
(
opr
);
var2fmts
[
new_opr
->
output
(
0
)]
=
base_fmt
;
}
};
opt
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
}
// vim: syntax=cpp.doxygen
src/gopt/test/layout_transform_pass.cpp
0 → 100644
浏览文件 @
c67c4b7d
/**
* \file src/gopt/test/layout_transform_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/serializer.h"
using
namespace
mgb
;
using
namespace
gopt
;
using
namespace
serialization
;
#if MGB_CUDA
TEST
(
TestLayoutTransform
,
Feature
)
{
auto
inp_file
=
InputFile
::
make_fs
(
"./feat.mdl"
);
auto
format
=
GraphLoader
::
identify_graph_dump_format
(
*
inp_file
);
ASSERT_TRUE
(
format
.
valid
());
auto
loader
=
GraphLoader
::
make
(
std
::
move
(
inp_file
),
format
.
val
());
GraphLoader
::
LoadConfig
load_config
;
load_config
.
comp_graph
=
ComputingGraph
::
make
();
auto
&&
graph_opt
=
load_config
.
comp_graph
->
options
();
graph_opt
.
graph_opt
.
enable_fuse_conv_bias_nonlinearity
();
graph_opt
.
graph_opt
.
enable_fuse_conv_bias_with_z
();
auto
ret
=
loader
->
load
(
load_config
,
false
);
using
S
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
S
strategy
=
S
::
PROFILE
;
gopt
::
modify_opr_algo_strategy_inplace
({
ret
.
output_var_list
},
strategy
);
using
OprFormat
=
LayoutTransformContext
::
OprFormat
;
using
OprList
=
LayoutTransformContext
::
OprList
;
using
ReformatAttribute
=
LayoutTransformContext
::
ReformatAttribute
;
using
Attribute
=
LayoutTransformContext
::
Attribute
;
OprList
opr_list
=
{
opr
::
ConvBiasForward
::
typeinfo
(),
opr
::
ElemwiseMultiType
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
opr
::
PoolingForward
::
typeinfo
(),
opr
::
WarpPerspectiveForward
::
typeinfo
(),
};
SmallVector
<
TensorFormats
>
available_tensor_formats
=
{
TensorFormats
::
NCHWc4
,
TensorFormats
::
NCHWc32
,
TensorFormats
::
CHWNc4
};
Attribute
attribute
=
{
OprFormat
::
NCHW4
,
TensorFormats
::
NCHWc4
,
ReformatAttribute
::
DEFAULT
};
auto
ctx
=
std
::
make_unique
<
LayoutTransformContext
>
(
std
::
move
(
opr_list
),
std
::
move
(
available_tensor_formats
),
attribute
);
ctx
->
add_opr_config
(
opr
::
ConvBiasForward
::
typeinfo
(),
{
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
PoolingForward
::
typeinfo
(),
{
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
WarpPerspectiveForward
::
typeinfo
(),
OprFormat
::
NCHW4
);
auto
profiler
=
ProfilerBase
::
make_profiler
();
auto
filter
=
[](
const
GraphPartition
&
partition
)
{
auto
has_nchw4_conv
=
false
;
for
(
auto
&&
opr
:
partition
.
all_oprs
())
{
if
(
opr
->
dyn_typeinfo
()
==
opr
::
ConvBiasForward
::
typeinfo
())
{
auto
&
conv
=
opr
->
cast_final_safe
<
opr
::
ConvBiasForward
>
();
if
(
conv
.
param
().
format
==
LayoutTransformContext
::
OprFormat
::
NCHW4
)
{
has_nchw4_conv
=
true
;
break
;
}
}
}
return
has_nchw4_conv
;
};
std
::
unique_ptr
<
SolverBase
>
solver
{
new
DynamicProgrammingSolver
(
std
::
move
(
profiler
),
std
::
move
(
filter
))};
auto
new_out_vars
=
gopt
::
GraphOptimizer
{}
.
add_pass
<
FuseConvBiasNonlinPass
>
()
.
add_pass
<
FuseConvBiasZPass
>
()
.
add_pass
<
LayoutTransformPass
>
(
std
::
move
(
ctx
),
std
::
move
(
solver
))
.
add_pass
<
ShuffleShuffleRemovePass
>
()
.
add_pass
(
FuseNCHW4Int8Preprocess
::
make
())
.
add_pass
<
FoldingConvBiasDimshufflePass
>
()
.
add_pass
<
ParamFusePass
>
()
.
add_pass
<
ParamMergePass
>
()
.
apply
(
ret
.
output_var_list
)
.
endpoint_vars
();
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
"model_opt.mgb"
));
dumper
->
dump
({
new_out_vars
});
}
TEST
(
TestLayoutTransform
,
Detection
)
{
auto
inp_file
=
InputFile
::
make_fs
(
"./det.mdl"
);
static
const
char
*
magic
=
"mgbteset0"
;
size_t
skip_size
=
sizeof
(
magic
)
+
sizeof
(
uint32_t
);
char
skip
[
skip_size
];
inp_file
->
read
(
skip
,
skip_size
);
auto
format
=
GraphLoader
::
identify_graph_dump_format
(
*
inp_file
);
ASSERT_TRUE
(
format
.
valid
());
auto
loader
=
GraphLoader
::
make
(
std
::
move
(
inp_file
),
format
.
val
());
GraphLoader
::
LoadConfig
load_config
;
load_config
.
comp_graph
=
ComputingGraph
::
make
();
auto
&&
graph_opt
=
load_config
.
comp_graph
->
options
();
graph_opt
.
graph_opt
.
enable_fuse_conv_bias_nonlinearity
();
graph_opt
.
graph_opt
.
enable_fuse_conv_bias_with_z
();
auto
ret
=
loader
->
load
(
load_config
,
false
);
using
S
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
S
strategy
=
S
::
PROFILE
;
gopt
::
modify_opr_algo_strategy_inplace
({
ret
.
output_var_list
},
strategy
);
using
OprFormat
=
LayoutTransformContext
::
OprFormat
;
using
OprList
=
LayoutTransformContext
::
OprList
;
using
ReformatAttribute
=
LayoutTransformContext
::
ReformatAttribute
;
using
Attribute
=
LayoutTransformContext
::
Attribute
;
OprList
opr_list
=
{
opr
::
ConvBiasForward
::
typeinfo
(),
opr
::
ConvolutionForward
::
typeinfo
(),
opr
::
ConvolutionBackwardData
::
typeinfo
(),
opr
::
ElemwiseMultiType
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
opr
::
PoolingForward
::
typeinfo
(),
opr
::
WarpPerspectiveForward
::
typeinfo
(),
};
SmallVector
<
TensorFormats
>
available_tensor_formats
=
{
TensorFormats
::
NCHW
,
TensorFormats
::
NHWC
,
TensorFormats
::
NCHWc4
,
TensorFormats
::
NCHWc32
,
TensorFormats
::
NCHWc64
,
TensorFormats
::
CHWNc4
};
Attribute
attribute
=
{
OprFormat
::
NCHW
,
TensorFormats
::
NCHW
,
ReformatAttribute
::
DEFAULT
};
auto
ctx
=
std
::
make_unique
<
LayoutTransformContext
>
(
std
::
move
(
opr_list
),
std
::
move
(
available_tensor_formats
),
attribute
);
ctx
->
add_opr_config
(
opr
::
ConvBiasForward
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NHWC
,
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
NCHW64
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
ConvolutionForward
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NCHW4
})
.
add_opr_config
(
opr
::
ConvolutionBackwardData
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NCHW4
})
.
add_opr_config
(
opr
::
PoolingForward
::
typeinfo
(),
{
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
NHWC
,
OprFormat
::
NCHW64
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
WarpPerspectiveForward
::
typeinfo
(),
{
OprFormat
::
NHWC
,
OprFormat
::
NCHW4
,
OprFormat
::
NCHW64
});
auto
profiler
=
ProfilerBase
::
make_profiler
();
std
::
unique_ptr
<
SolverBase
>
solver
{
new
DynamicProgrammingSolver
(
std
::
move
(
profiler
))};
auto
new_out_vars
=
gopt
::
GraphOptimizer
{}
.
add_pass
<
LayoutTransformPass
>
(
std
::
move
(
ctx
),
std
::
move
(
solver
))
.
add_pass
<
ShuffleShuffleRemovePass
>
()
.
add_pass
(
FuseNCHW4Int8Preprocess
::
make
())
.
add_pass
<
FoldingConvBiasDimshufflePass
>
()
.
add_pass
<
ParamFusePass
>
()
.
add_pass
<
ParamMergePass
>
()
.
apply
(
ret
.
output_var_list
)
.
endpoint_vars
();
using
OutputSpecItem
=
cg
::
ComputingGraph
::
OutputSpecItem
;
std
::
vector
<
OutputSpecItem
>
outs
(
new_out_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
new_out_vars
.
size
();
++
i
)
{
auto
cb
=
[](
DeviceTensorND
&
/* d */
)
{};
outs
[
i
]
=
std
::
make_pair
(
new_out_vars
[
i
],
cb
);
}
GraphProfiler
gprof
{
load_config
.
comp_graph
.
get
()};
auto
func
=
load_config
.
comp_graph
->
compile
(
outs
);
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
func
->
execute
();
func
->
wait
();
gprof
.
to_json_full
(
func
.
get
())
->
writeto_fpath
(
output_file
(
"det.json"
));
}
TEST
(
TestLayoutTransform
,
DetectionHead
)
{
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
cn
.
activate
();
REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ
(
7
,
5
);
constexpr
size_t
N
=
16
,
C
=
3
,
H
=
768
,
W
=
1280
;
HostTensorGenerator
<
dtype
::
Uint8
>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
h2d
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
({
N
,
C
,
H
,
W
},
cn
));
auto
data
=
opr
::
TypeCvt
::
make
(
h2d
,
dtype
::
Float32
());
auto
sub_128
=
data
+
(
-
128
);
auto
x
=
opr
::
TypeCvt
::
make
(
sub_128
,
dtype
::
QuantizedS8
(
1.
f
));
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
,
cn
))
.
rename
(
name
),
dtype
);
};
auto
w
=
mkcvar
(
"w"
,
{
16
,
3
,
3
,
3
},
dtype
::
QuantizedS8
(
1.
f
));
auto
b
=
mkcvar
(
"b"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
1.
f
));
opr
::
ConvBias
::
Param
param
;
param
.
format
=
opr
::
ConvBias
::
Param
::
Format
::
NCHW
;
param
.
nonlineMode
=
opr
::
ConvBias
::
Param
::
NonlineMode
::
RELU
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
conv_1
=
opr
::
ConvBias
::
make
(
x
,
w
,
b
,
param
,
{},
OperatorNodeConfig
(
dtype
::
QuantizedS8
(
1.
f
)));
conv_1
=
opr
::
TypeCvt
::
make
(
conv_1
,
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
)));
auto
w1
=
mkcvar
(
"w1"
,
{
16
,
16
,
3
,
3
},
dtype
::
QuantizedS4
(
1.
f
));
auto
b1
=
mkcvar
(
"b1"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
1.
f
));
auto
y
=
opr
::
ConvBias
::
make
(
conv_1
,
w1
,
b1
,
param
,
{},
OperatorNodeConfig
(
dtype
::
Quantized4Asymm
(
1.
f
,
static_cast
<
uint8_t
>
(
8
))));
using
S
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
S
strategy
=
S
::
PROFILE
;
gopt
::
modify_opr_algo_strategy_inplace
({
y
},
strategy
);
using
OprFormat
=
LayoutTransformContext
::
OprFormat
;
using
OprList
=
LayoutTransformContext
::
OprList
;
using
ReformatAttribute
=
LayoutTransformContext
::
ReformatAttribute
;
using
Attribute
=
LayoutTransformContext
::
Attribute
;
OprList
opr_list
=
{
opr
::
ConvBiasForward
::
typeinfo
(),
opr
::
ConvolutionForward
::
typeinfo
(),
opr
::
ConvolutionBackwardData
::
typeinfo
(),
opr
::
ElemwiseMultiType
::
typeinfo
(),
opr
::
Elemwise
::
typeinfo
(),
opr
::
TypeCvt
::
typeinfo
(),
opr
::
PoolingForward
::
typeinfo
(),
opr
::
WarpPerspectiveForward
::
typeinfo
(),
};
SmallVector
<
TensorFormats
>
available_tensor_formats
=
{
TensorFormats
::
NCHW
,
TensorFormats
::
NHWC
,
TensorFormats
::
NCHWc4
,
TensorFormats
::
NCHWc32
,
TensorFormats
::
NCHWc64
,
TensorFormats
::
CHWNc4
};
Attribute
attribute
=
{
OprFormat
::
NCHW
,
TensorFormats
::
NCHW
,
ReformatAttribute
::
DEFAULT
};
auto
ctx
=
std
::
make_unique
<
LayoutTransformContext
>
(
std
::
move
(
opr_list
),
std
::
move
(
available_tensor_formats
),
attribute
);
ctx
->
add_opr_config
(
opr
::
ConvBiasForward
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NHWC
,
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
NCHW64
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
ConvolutionForward
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NCHW4
})
.
add_opr_config
(
opr
::
ConvolutionBackwardData
::
typeinfo
(),
{
OprFormat
::
NCHW
,
OprFormat
::
NCHW4
})
.
add_opr_config
(
opr
::
PoolingForward
::
typeinfo
(),
{
OprFormat
::
NCHW4
,
OprFormat
::
NCHW32
,
OprFormat
::
NHWC
,
OprFormat
::
NCHW64
,
OprFormat
::
CHWN4
})
.
add_opr_config
(
opr
::
WarpPerspectiveForward
::
typeinfo
(),
{
OprFormat
::
NHWC
,
OprFormat
::
NCHW4
,
OprFormat
::
NCHW64
});
auto
profiler
=
ProfilerBase
::
make_profiler
();
std
::
unique_ptr
<
SolverBase
>
solver
{
new
DynamicProgrammingSolver
(
std
::
move
(
profiler
))};
auto
new_out_vars
=
gopt
::
GraphOptimizer
{}
.
add_pass
<
LayoutTransformPass
>
(
std
::
move
(
ctx
),
std
::
move
(
solver
))
.
add_pass
<
ShuffleShuffleRemovePass
>
()
.
add_pass
(
FuseNCHW4Int8Preprocess
::
make
())
.
add_pass
<
FoldingConvBiasDimshufflePass
>
()
.
add_pass
<
ParamFusePass
>
()
.
add_pass
<
ParamMergePass
>
()
.
apply
(
SymbolVarArray
{
y
})
.
endpoint_vars
();
using
OutputSpecItem
=
cg
::
ComputingGraph
::
OutputSpecItem
;
std
::
vector
<
OutputSpecItem
>
outs
(
new_out_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
new_out_vars
.
size
();
++
i
)
{
auto
cb
=
[](
DeviceTensorND
&
/* d */
)
{};
outs
[
i
]
=
std
::
make_pair
(
new_out_vars
[
i
],
cb
);
}
GraphProfiler
gprof
{
graph
.
get
()};
auto
func
=
graph
->
compile
(
outs
);
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
func
->
execute
();
func
->
wait
();
gprof
.
to_json_full
(
func
.
get
())
->
writeto_fpath
(
output_file
(
"det_head.json"
));
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录