Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8b0315b3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8b0315b3
编写于
4月 09, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): fix nhwcd4 optpass
GitOrigin-RevId: 9295abec77af2763d301ba372116e4b2281f442a
上级
b588d93e
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
241 addition
and
31 deletion
+241
-31
dnn/src/naive/handle.cpp
dnn/src/naive/handle.cpp
+5
-0
dnn/src/naive/handle.h
dnn/src/naive/handle.h
+18
-0
dnn/src/naive/local_share/algorithms.h
dnn/src/naive/local_share/algorithms.h
+41
-0
dnn/src/naive/local_share/opr_impl.cpp
dnn/src/naive/local_share/opr_impl.cpp
+73
-0
dnn/src/naive/local_share/opr_impl.h
dnn/src/naive/local_share/opr_impl.h
+6
-18
src/gopt/impl/inference.cpp
src/gopt/impl/inference.cpp
+34
-13
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+64
-0
未找到文件。
dnn/src/naive/handle.cpp
浏览文件 @
8b0315b3
...
...
@@ -94,6 +94,11 @@ DefaultConvolution3DBackwardFilterAlgorithm
HandleImpl
::
m_default_conv3d_bwd_filter_algo
;
DefaultBatchConvBiasForwardAlgorithm
HandleImpl
::
m_default_batch_conv_bias_fwd_algo
;
DefaultLocalShareForwardAlgorithm
HandleImpl
::
m_default_local_share_fwd_algo
;
DefaultLocalShareBackwardDataAlgorithm
HandleImpl
::
m_default_local_share_bwd_data_algo
;
DefaultLocalShareBackwardFilterAlgorithm
HandleImpl
::
m_default_local_share_bwd_filter_algo
;
HandleImpl
::
HandleImpl
(
megcoreComputingHandle_t
computing_handle
,
HandleType
type
)
...
...
dnn/src/naive/handle.h
浏览文件 @
8b0315b3
...
...
@@ -13,6 +13,7 @@
#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/local_share/algorithms.h"
#include "src/naive/convolution3d/algorithms.h"
#include <functional>
...
...
@@ -39,6 +40,11 @@ class HandleImpl : public HandleImplHelper {
m_default_conv3d_bwd_filter_algo
;
static
DefaultBatchConvBiasForwardAlgorithm
m_default_batch_conv_bias_fwd_algo
;
static
DefaultLocalShareForwardAlgorithm
m_default_local_share_fwd_algo
;
static
DefaultLocalShareBackwardDataAlgorithm
m_default_local_share_bwd_data_algo
;
static
DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo
;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template
<
typename
T
>
...
...
@@ -91,6 +97,18 @@ public:
return
&
m_default_batch_conv_bias_fwd_algo
;
}
LocalShareForward
::
Algorithm
*
default_local_share_fwd_algo
()
{
return
&
m_default_local_share_fwd_algo
;
}
LocalShareBackwardData
::
Algorithm
*
default_local_share_bwd_data_algo
()
{
return
&
m_default_local_share_bwd_data_algo
;
}
LocalShareBackwardFilter
::
Algorithm
*
default_local_share_bwd_filter_algo
()
{
return
&
m_default_local_share_bwd_filter_algo
;
}
Relayout
*
relayout_opr
()
override
{
return
get_helper_opr
<
Relayout
,
2
>
(
this
);
}
...
...
dnn/src/naive/local_share/algorithms.h
0 → 100644
浏览文件 @
8b0315b3
/**
* \file dnn/src/naive/local_share/algorithms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace
megdnn
{
namespace
naive
{
class
DefaultLocalShareForwardAlgorithm
final
:
public
megdnn
::
LocalShareForward
::
Algorithm
{
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
class
DefaultLocalShareBackwardDataAlgorithm
final
:
public
megdnn
::
LocalShareBackwardData
::
Algorithm
{
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
class
DefaultLocalShareBackwardFilterAlgorithm
final
:
public
megdnn
::
LocalShareBackwardFilter
::
Algorithm
{
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
}
// namespace naive
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/naive/local_share/opr_impl.cpp
浏览文件 @
8b0315b3
...
...
@@ -152,4 +152,77 @@ void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src,
StrategyBwdFlt
>
(
src
,
grad
,
diff
,
param
())););
}
std
::
vector
<
LocalShareForward
::
Algorithm
*>
LocalShareForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_fwd_algo
()};
}
LocalShareForward
::
Algorithm
*
LocalShareForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
const
TensorLayout
&
/* grad */
,
size_t
/* workspace_limit_in_bytes */
,
bool
reproducible
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_fwd_algo
();
if
(
reproducible
)
{
megdnn_assert
(
algo
->
is_reproducible
(),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible"
,
algo
->
name
());
}
return
algo
;
}
std
::
vector
<
LocalShareBackwardData
::
Algorithm
*>
LocalShareBackwardDataImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_data_algo
()};
}
LocalShareBackwardData
::
Algorithm
*
LocalShareBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* diff */
,
const
TensorLayout
&
/* grad */
,
size_t
/* workspace_limit_in_bytes */
,
bool
reproducible
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_data_algo
();
if
(
reproducible
)
{
megdnn_assert
(
algo
->
is_reproducible
(),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible"
,
algo
->
name
());
}
return
algo
;
}
std
::
vector
<
LocalShareBackwardFilter
::
Algorithm
*>
LocalShareBackwardFilterImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_filter_algo
()};
}
LocalShareBackwardFilter
::
Algorithm
*
LocalShareBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
const
TensorLayout
&
/* grad */
,
size_t
/* workspace_limit_in_bytes */
,
bool
reproducible
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_filter_algo
();
if
(
reproducible
)
{
megdnn_assert
(
algo
->
is_reproducible
(),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible"
,
algo
->
name
());
}
return
algo
;
}
// vim: syntax=cpp.doxygen
dnn/src/naive/local_share/opr_impl.h
浏览文件 @
8b0315b3
...
...
@@ -27,17 +27,13 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*dst*/
)
override
{
return
{};
}
const
TensorLayout
&
/*dst*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*dst*/
,
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
{
return
nullptr
;
}
bool
/*reproducible*/
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
@@ -55,17 +51,13 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
{
return
{};
}
const
TensorLayout
&
/*grad*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
,
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
{
return
nullptr
;
}
bool
/*reproducible*/
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
@@ -83,17 +75,13 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
{
return
{};
}
const
TensorLayout
&
/*grad*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
,
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
{
return
nullptr
;
}
bool
/*reproducible*/
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
src/gopt/impl/inference.cpp
浏览文件 @
8b0315b3
...
...
@@ -14,6 +14,7 @@
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
...
...
@@ -1358,23 +1359,28 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return
new_pooling_opr
.
node
()
->
owner_opr
();
};
auto
relayout_inp_to_chw
=
[](
OperatorNodeBase
*
opr
,
auto
var_to_chw
=
[](
VarNode
*
inp
,
VarNode
*
new_inp
)
{
if
(
!
inp
->
shape
().
eq_shape
(
new_inp
->
shape
()))
{
mgb_assert
(
inp
->
shape
().
ndim
==
4
&&
inp
->
format
().
type
()
!=
TensorFormat
::
Type
::
IMAGE2D_PACK4
);
mgb_assert
(
new_inp
->
shape
().
ndim
==
5
&&
new_inp
->
format
().
type
()
==
TensorFormat
::
Type
::
IMAGE2D_PACK4
);
auto
param
=
megdnn
::
param
::
RelayoutFormat
();
param
.
mode
=
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NHWCD4I_NCHW
;
auto
rf
=
opr
::
RelayoutFormat
::
make
(
new_inp
,
param
);
return
rf
.
node
();
}
return
new_inp
;
};
auto
relayout_inp_to_chw
=
[
var_to_chw
](
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
{
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
VarNodeArray
t_inp
=
new_inp
;
for
(
size_t
i
=
0
;
i
<
opr
->
input
().
size
();
i
++
)
{
if
(
!
opr
->
input
(
i
)
->
shape
().
eq_shape
(
new_inp
[
i
]
->
shape
()))
{
mgb_assert
(
opr
->
input
(
i
)
->
shape
().
ndim
==
4
&&
opr
->
input
(
i
)
->
format
().
type
()
!=
TensorFormat
::
Type
::
IMAGE2D_PACK4
);
mgb_assert
(
new_inp
[
i
]
->
shape
().
ndim
==
5
&&
new_inp
[
i
]
->
format
().
type
()
==
TensorFormat
::
Type
::
IMAGE2D_PACK4
);
auto
param
=
megdnn
::
param
::
RelayoutFormat
();
param
.
mode
=
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NHWCD4I_NCHW
;
auto
rf
=
opr
::
RelayoutFormat
::
make
(
new_inp
[
i
],
param
);
t_inp
[
i
]
=
rf
.
node
();
}
t_inp
[
i
]
=
var_to_chw
(
opr
->
input
(
i
),
new_inp
[
i
]);
}
auto
new_opr
=
serialization
::
copy_opr_shallow
(
*
opr
,
t_inp
,
opr
->
config
());
...
...
@@ -1415,6 +1421,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
}
};
/* This helper function converts the first input to the NCHW format to
* handle operations that do not support NHWCD4 format
*/
auto
relayout_first_inp_to_chw
=
[
var_to_chw
](
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
->
OperatorNodeBase
*
{
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
VarNodeArray
t_inp
=
new_inp
;
t_inp
[
0
]
=
var_to_chw
(
opr
->
input
(
0
),
new_inp
[
0
]);
return
serialization
::
copy_opr_shallow
(
*
opr
,
t_inp
,
opr
->
config
());
};
auto
ret
=
std
::
make_unique
<
ConvertFormatPass
>
();
ret
->
set_var_replace_check_flag
(
VarReplaceCheckFlag
::
NOCHECK
);
auto
&&
replace_func
=
ret
->
m_opr_replace_func
;
...
...
@@ -1436,6 +1454,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func
[
opr
::
WarpPerspectiveForward
::
typeinfo
()]
=
replace_warp_perspective_opr
;
replace_func
[
opr
::
WarpAffineForward
::
typeinfo
()]
=
replace_warp_affine_opr
;
replace_func
[
opr
::
LocalForward
::
typeinfo
()]
=
relayout_first_inp_to_chw
;
replace_func
[
opr
::
GroupLocalForward
::
typeinfo
()]
=
relayout_first_inp_to_chw
;
return
ret
;
}
...
...
src/gopt/test/inference.cpp
浏览文件 @
8b0315b3
...
...
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
#include "megbrain/gopt/inference.h"
...
...
@@ -919,6 +920,69 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
MGB_ASSERT_TENSOR_NEAR
(
host_y
,
host_y_opt
,
1e-3
);
}
TEST
(
TestGoptInference
,
ConvertFormatNHWCD4LOCAL
)
{
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope
naive_megdnn_handle
;
HostTensorGenerator
<>
gen
;
auto
cn
=
CompNode
::
load
(
"cpu0"
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
,
cn
))
.
rename
(
name
);
};
auto
host_x
=
gen
({
2
,
8
,
8
,
16
},
cn
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
);
opr
::
Convolution
::
Param
param
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
w1
=
mkcvar
(
"w1"
,
{
4
,
8
,
3
,
3
}),
conv1
=
opr
::
Convolution
::
make
(
x
,
w1
,
param
);
auto
w2
=
mkcvar
(
"w2"
,
{
8
,
16
,
4
,
3
,
3
,
4
}),
local
=
opr
::
Local
::
make
(
conv1
,
w2
,
param
);
auto
w3
=
mkcvar
(
"w3"
,
{
4
,
4
,
3
,
3
}),
conv2
=
opr
::
Convolution
::
make
(
local
,
w3
,
param
);
opr
::
GroupLocal
::
Param
param_group_local
;
param_group_local
.
pad_h
=
param_group_local
.
pad_w
=
1
;
auto
w4
=
mkcvar
(
"w4"
,
{
2
,
8
,
16
,
2
,
3
,
3
,
2
}),
group_local
=
opr
::
GroupLocal
::
make
(
conv2
,
w4
,
param_group_local
);
auto
w5
=
mkcvar
(
"w5"
,
{
4
,
4
,
3
,
3
}),
y
=
opr
::
Convolution
::
make
(
group_local
,
w5
,
param
);
SymbolVar
y_opt
;
unpack_vector
(
gopt
::
optimize_for_inference
(
{
y
},
gopt
::
OptimizeForInferenceOptions
{}.
enable_use_nhwcd4
()),
y_opt
);
ASSERT_EQ
(
opr
::
Convolution
::
Param
::
Format
::
NHWCD4
,
find_opr
<
opr
::
Convolution
>
(
y_opt
).
param
().
format
);
ASSERT_EQ
(
opr
::
Local
::
Param
::
Format
::
NCHW
,
find_opr
<
opr
::
Local
>
(
y_opt
).
param
().
format
);
ASSERT_EQ
(
opr
::
GroupLocal
::
Param
::
Format
::
NCHW
,
find_opr
<
opr
::
GroupLocal
>
(
y_opt
).
param
().
format
);
graph
->
compile
({{
y_opt
,
{}}})
->
to_json
()
->
writeto_fpath
(
output_file
(
"TestGoptInference.ConvertFormatNHWCD4LOCAL.json"
));
HostTensorND
host_y_opt
,
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
),
make_callback_copy
(
y_opt
,
host_y_opt
)});
func
->
execute
();
MGB_ASSERT_TENSOR_NEAR
(
host_y
,
host_y_opt
,
1e-3
);
}
TEST
(
TestGoptInference
,
ConvertFormatNHWCD4Deconv
)
{
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope
naive_megdnn_handle
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录