提交 bb8f2928 编写于 作者: M Megvii Engine Team

refactor(mge/func): change `conv_execution_strategy` to `execution_strategy`

GitOrigin-RevId: 7aef9935ee918ddb967e12fb84453a3b2c918e1b
上级 914af286
......@@ -8,23 +8,31 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
_conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC")
from ..logger import get_logger
from ..utils.deprecation import deprecated
_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")
def get_conv_execution_strategy() -> str:
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
get_logger().warning(
"Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
)
def get_execution_strategy() -> str:
"""
Returns the execuation strategy of :class:`~.Conv2d`.
Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'
See :func:`~.set_conv_execution_strategy` for possible return values
See :func:`~.set_execution_strategy` for possible return values
"""
return _conv_execution_strategy
return _execution_strategy
def set_conv_execution_strategy(option: str):
def set_execution_strategy(option: str):
"""
Sets the execuation strategy of :class:`~.Conv2d`.
Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'
:param option: Decides how :class:`~.Conv2d` algorithm is chosen.
:param option: Decides how :class:`~.Conv2d` and :func:'~.matmul' algorithms are chosen.
Available values:
* 'HEURISTIC' uses heuristic to choose the fastest algorithm.
......@@ -35,7 +43,7 @@ def set_conv_execution_strategy(option: str):
The default strategy is 'HEURISTIC'.
It can also be set through the environment variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'.
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
"""
valid_option = (
"HEURISTIC",
......@@ -47,5 +55,15 @@ def set_conv_execution_strategy(option: str):
if not option in valid_option:
raise ValueError("Valid option can only be one of {}".format(valid_option))
global _conv_execution_strategy # pylint: disable=global-statement
_conv_execution_strategy = option
global _execution_strategy # pylint: disable=global-statement
_execution_strategy = option
@deprecated(version="1.3", reason="use get_execution_strategy() instead")
def get_conv_execution_strategy() -> str:
return get_execution_strategy()
@deprecated(version="1.3", reason="use set_execution_strategy() instead")
def set_conv_execution_strategy(option: str):
return set_execution_strategy(option)
......@@ -22,7 +22,7 @@ from ..jit.tracing import is_tracing
from ..random import uniform
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, matmul, max, prod, sum
......@@ -149,7 +149,7 @@ def conv2d(
pad_w=pad_w,
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
......@@ -217,7 +217,7 @@ def conv_transpose2d(
pad_w=pad_w,
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
......@@ -282,7 +282,7 @@ def deformable_conv2d(
pad_w=pad_w,
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
......@@ -1658,7 +1658,7 @@ def conv1d(
pad_w=0,
dilate_h=dilate_h,
dilate_w=1,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
......
......@@ -12,7 +12,7 @@ from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .debug_param import get_execution_strategy
def conv_bias_activation(
......@@ -65,7 +65,7 @@ def conv_bias_activation(
dilate_w=dw,
dtype=dtype,
format="NCHW",
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode,
mode=conv_mode,
compute_mode=compute_mode,
......@@ -125,7 +125,7 @@ def batch_conv_bias_activation(
dilate_w=dw,
dtype=dtype,
format="NCHW",
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode,
mode=conv_mode,
compute_mode=compute_mode,
......
......@@ -20,7 +20,7 @@ import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.functional.debug_param import set_execution_strategy
from megengine.jit import SublinearMemoryConfig
from megengine.module import (
AdaptiveAvgPool2d,
......@@ -242,7 +242,7 @@ def test_correctness():
else:
model_name = "mnist_model_with_test_cpu.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
run_train(model_path, False, False, max_err=1e-5)
run_train(model_path, True, False, max_err=1e-5)
......@@ -265,7 +265,7 @@ def test_correctness_use_adaptive_pooling():
else:
model_name = "mnist_model_with_test_cpu.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True)
run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True)
......
......@@ -21,7 +21,7 @@ import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
from megengine.device import get_default_device, set_default_device
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.functional.debug_param import set_execution_strategy
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
from megengine.tensor import Tensor
......@@ -198,5 +198,5 @@ def run_test(
def test_dp_correctness():
model_name = "mnist_model_with_test.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
run_test(model_path, False, False, max_err=1e-5)
......@@ -336,8 +336,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
rst.workspace, rst.time);
prof_rst.push_back(rst);
}
mgb_assert(!prof_rst.empty(), "no usable convolution algorithm %s",
str_on_inp_shape.c_str());
std::string msg = ssprintf("no usable %s algorithm %s",
ctx.mgb_opr()->dyn_typeinfo()->name,
str_on_inp_shape.c_str());
mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
FixedTensorLayouts origin_layouts = ctx.layouts();
typename Opr::Param origin_param = ctx.megdnn_opr()->param();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册