Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bb8f2928
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bb8f2928
编写于
1月 22, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/func): change `conv_execution_strategy` to `execution_strategy`
GitOrigin-RevId: 7aef9935ee918ddb967e12fb84453a3b2c918e1b
上级
914af286
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
46 addition
and
26 deletion
+46
-26
imperative/python/megengine/functional/debug_param.py
imperative/python/megengine/functional/debug_param.py
+29
-11
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+5
-5
imperative/python/megengine/functional/quantized.py
imperative/python/megengine/functional/quantized.py
+3
-3
imperative/python/test/integration/test_correctness_mnistnet.py
...tive/python/test/integration/test_correctness_mnistnet.py
+3
-3
imperative/python/test/integration/test_dp_correctness.py
imperative/python/test/integration/test_dp_correctness.py
+2
-2
src/opr/impl/search_policy/algo_chooser.cpp
src/opr/impl/search_policy/algo_chooser.cpp
+4
-2
未找到文件。
imperative/python/megengine/functional/debug_param.py
浏览文件 @
bb8f2928
...
...
@@ -8,23 +8,31 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
os
_conv_execution_strategy
=
os
.
getenv
(
"MEGENGINE_CONV_EXECUTION_STRATEGY"
,
"HEURISTIC"
)
from
..logger
import
get_logger
from
..utils.deprecation
import
deprecated
_execution_strategy
=
os
.
getenv
(
"MEGENGINE_EXECUTION_STRATEGY"
,
"HEURISTIC"
)
def
get_conv_execution_strategy
()
->
str
:
if
os
.
getenv
(
"MEGENGINE_CONV_EXECUTION_STRATEGY"
)
!=
None
:
get_logger
().
warning
(
"Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
)
def
get_execution_strategy
()
->
str
:
"""
Returns the execu
ation strategy of :class:`~.Conv2d`.
Returns the execu
tion strategy of :class:`~.Conv2d` and :func:'~.matmul'
See :func:`~.set_
conv_
execution_strategy` for possible return values
See :func:`~.set_execution_strategy` for possible return values
"""
return
_
conv_
execution_strategy
return
_execution_strategy
def
set_
conv_
execution_strategy
(
option
:
str
):
def
set_execution_strategy
(
option
:
str
):
"""
Sets the execu
ation strategy of :class:`~.Conv2d`.
Sets the execu
tion strategy of :class:`~.Conv2d` and :func:'~.matmul'
:param option: Decides how :class:`~.Conv2d` a
lgorithm is
chosen.
:param option: Decides how :class:`~.Conv2d` a
nd :func:'~.matmul' algorithms are
chosen.
Available values:
* 'HEURISTIC' uses heuristic to choose the fastest algorithm.
...
...
@@ -35,7 +43,7 @@ def set_conv_execution_strategy(option: str):
The default strategy is 'HEURISTIC'.
It can also be set through the environment variable 'MEGENGINE_
CONV_
EXECUTION_STRATEGY'.
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
"""
valid_option
=
(
"HEURISTIC"
,
...
...
@@ -47,5 +55,15 @@ def set_conv_execution_strategy(option: str):
if
not
option
in
valid_option
:
raise
ValueError
(
"Valid option can only be one of {}"
.
format
(
valid_option
))
global
_conv_execution_strategy
# pylint: disable=global-statement
_conv_execution_strategy
=
option
global
_execution_strategy
# pylint: disable=global-statement
_execution_strategy
=
option
@
deprecated
(
version
=
"1.3"
,
reason
=
"use get_execution_strategy() instead"
)
def
get_conv_execution_strategy
()
->
str
:
return
get_execution_strategy
()
@
deprecated
(
version
=
"1.3"
,
reason
=
"use set_execution_strategy() instead"
)
def
set_conv_execution_strategy
(
option
:
str
):
return
set_execution_strategy
(
option
)
imperative/python/megengine/functional/nn.py
浏览文件 @
bb8f2928
...
...
@@ -22,7 +22,7 @@ from ..jit.tracing import is_tracing
from
..random
import
uniform
from
..tensor
import
Tensor
from
..utils.tuple_function
import
_pair
,
_pair_nonzero
from
.debug_param
import
get_
conv_
execution_strategy
from
.debug_param
import
get_execution_strategy
from
.distributed
import
all_reduce_sum
from
.elemwise
import
exp
,
floor
,
log
,
log1p
,
maximum
,
minimum
,
relu
from
.math
import
argsort
,
matmul
,
max
,
prod
,
sum
...
...
@@ -149,7 +149,7 @@ def conv2d(
pad_w
=
pad_w
,
dilate_h
=
dilate_h
,
dilate_w
=
dilate_w
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
sparse
=
sparse_type
,
...
...
@@ -217,7 +217,7 @@ def conv_transpose2d(
pad_w
=
pad_w
,
dilate_h
=
dilate_h
,
dilate_w
=
dilate_w
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
)
weight
,
inp
=
utils
.
convert_inputs
(
weight
,
inp
)
(
output
,)
=
apply
(
op
,
weight
,
inp
)
...
...
@@ -282,7 +282,7 @@ def deformable_conv2d(
pad_w
=
pad_w
,
dilate_h
=
dilate_h
,
dilate_w
=
dilate_w
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
sparse
=
sparse_type
,
...
...
@@ -1658,7 +1658,7 @@ def conv1d(
pad_w
=
0
,
dilate_h
=
dilate_h
,
dilate_w
=
1
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
sparse
=
sparse_type
,
...
...
imperative/python/megengine/functional/quantized.py
浏览文件 @
bb8f2928
...
...
@@ -12,7 +12,7 @@ from ..core._imperative_rt.core2 import apply
from
..core.ops
import
builtin
from
..tensor
import
Tensor
from
..utils.tuple_function
import
_pair
,
_pair_nonzero
from
.debug_param
import
get_
conv_
execution_strategy
from
.debug_param
import
get_execution_strategy
def
conv_bias_activation
(
...
...
@@ -65,7 +65,7 @@ def conv_bias_activation(
dilate_w
=
dw
,
dtype
=
dtype
,
format
=
"NCHW"
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
nonlineMode
=
nonlinear_mode
,
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
...
...
@@ -125,7 +125,7 @@ def batch_conv_bias_activation(
dilate_w
=
dw
,
dtype
=
dtype
,
format
=
"NCHW"
,
strategy
=
get_
conv_
execution_strategy
(),
strategy
=
get_execution_strategy
(),
nonlineMode
=
nonlinear_mode
,
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
...
...
imperative/python/test/integration/test_correctness_mnistnet.py
浏览文件 @
bb8f2928
...
...
@@ -20,7 +20,7 @@ import megengine.functional as F
from
megengine
import
jit
from
megengine.core._trace_option
import
set_symbolic_shape
from
megengine.core.tensor.utils
import
make_shape_tuple
from
megengine.functional.debug_param
import
set_
conv_
execution_strategy
from
megengine.functional.debug_param
import
set_execution_strategy
from
megengine.jit
import
SublinearMemoryConfig
from
megengine.module
import
(
AdaptiveAvgPool2d
,
...
...
@@ -242,7 +242,7 @@ def test_correctness():
else
:
model_name
=
"mnist_model_with_test_cpu.mge"
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
set_
conv_
execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
set_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
run_train
(
model_path
,
False
,
False
,
max_err
=
1e-5
)
run_train
(
model_path
,
True
,
False
,
max_err
=
1e-5
)
...
...
@@ -265,7 +265,7 @@ def test_correctness_use_adaptive_pooling():
else
:
model_name
=
"mnist_model_with_test_cpu.mge"
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
set_
conv_
execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
set_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
run_train
(
model_path
,
False
,
False
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
)
run_train
(
model_path
,
True
,
False
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
)
...
...
imperative/python/test/integration/test_dp_correctness.py
浏览文件 @
bb8f2928
...
...
@@ -21,7 +21,7 @@ import megengine.autodiff as ad
import
megengine.distributed
as
dist
import
megengine.functional
as
F
from
megengine.device
import
get_default_device
,
set_default_device
from
megengine.functional.debug_param
import
set_
conv_
execution_strategy
from
megengine.functional.debug_param
import
set_execution_strategy
from
megengine.module
import
AvgPool2d
,
BatchNorm2d
,
Conv2d
,
Linear
,
Module
from
megengine.optimizer
import
SGD
from
megengine.tensor
import
Tensor
...
...
@@ -198,5 +198,5 @@ def run_test(
def
test_dp_correctness
():
model_name
=
"mnist_model_with_test.mge"
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
set_
conv_
execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
set_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
run_test
(
model_path
,
False
,
False
,
max_err
=
1e-5
)
src/opr/impl/search_policy/algo_chooser.cpp
浏览文件 @
bb8f2928
...
...
@@ -336,8 +336,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
rst
.
workspace
,
rst
.
time
);
prof_rst
.
push_back
(
rst
);
}
mgb_assert
(
!
prof_rst
.
empty
(),
"no usable convolution algorithm %s"
,
str_on_inp_shape
.
c_str
());
std
::
string
msg
=
ssprintf
(
"no usable %s algorithm %s"
,
ctx
.
mgb_opr
()
->
dyn_typeinfo
()
->
name
,
str_on_inp_shape
.
c_str
());
mgb_assert
(
!
prof_rst
.
empty
(),
"%s"
,
msg
.
c_str
());
FixedTensorLayouts
origin_layouts
=
ctx
.
layouts
();
typename
Opr
::
Param
origin_param
=
ctx
.
megdnn_opr
()
->
param
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录