Commit bb8f2928 authored by Megvii Engine Team

refactor(mge/func): change `conv_execution_strategy` to `execution_strategy`

GitOrigin-RevId: 7aef9935ee918ddb967e12fb84453a3b2c918e1b
Parent 914af286
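
The rename leaves a single execution-strategy knob shared by `Conv2d` and `matmul`. Below is a minimal usage sketch of the renamed API (not part of the diff, and assuming a MegEngine build that already contains this change):

    import os

    # The new variable replaces MEGENGINE_CONV_EXECUTION_STRATEGY and is read when
    # megengine.functional.debug_param is imported, so set it before importing.
    os.environ["MEGENGINE_EXECUTION_STRATEGY"] = "HEURISTIC"

    from megengine.functional.debug_param import (
        get_execution_strategy,
        set_execution_strategy,
    )

    # Switch to the reproducible heuristic, as the correctness tests below do.
    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
    assert get_execution_strategy() == "HEURISTIC_REPRODUCIBLE"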
@@ -8,23 +8,31 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
 
-_conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC")
+from ..logger import get_logger
+from ..utils.deprecation import deprecated
+
+_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")
+
+if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
+    get_logger().warning(
+        "Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
+    )
 
 
-def get_conv_execution_strategy() -> str:
+def get_execution_strategy() -> str:
     """
-    Returns the execuation strategy of :class:`~.Conv2d`.
+    Returns the execution strategy of :class:`~.Conv2d` and :func:`~.matmul`.
 
-    See :func:`~.set_conv_execution_strategy` for possible return values
+    See :func:`~.set_execution_strategy` for possible return values
     """
-    return _conv_execution_strategy
+    return _execution_strategy
 
 
-def set_conv_execution_strategy(option: str):
+def set_execution_strategy(option: str):
     """
-    Sets the execuation strategy of :class:`~.Conv2d`.
+    Sets the execution strategy of :class:`~.Conv2d` and :func:`~.matmul`.
 
-    :param option: Decides how :class:`~.Conv2d` algorithm is chosen.
+    :param option: Decides how :class:`~.Conv2d` and :func:`~.matmul` algorithms are chosen.
         Available values:
 
         * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
@@ -35,7 +43,7 @@ def set_conv_execution_strategy(option: str):
 
     The default strategy is 'HEURISTIC'.
 
-    It can also be set through the environment variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'.
+    It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
     """
     valid_option = (
         "HEURISTIC",
@@ -47,5 +55,15 @@ def set_conv_execution_strategy(option: str):
     if not option in valid_option:
         raise ValueError("Valid option can only be one of {}".format(valid_option))
 
-    global _conv_execution_strategy  # pylint: disable=global-statement
-    _conv_execution_strategy = option
+    global _execution_strategy  # pylint: disable=global-statement
+    _execution_strategy = option
+
+
+@deprecated(version="1.3", reason="use get_execution_strategy() instead")
+def get_conv_execution_strategy() -> str:
+    return get_execution_strategy()
+
+
+@deprecated(version="1.3", reason="use set_execution_strategy() instead")
+def set_conv_execution_strategy(option: str):
+    return set_execution_strategy(option)
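
The `@deprecated` wrappers above keep the old names importable. A short sketch of the expected backward-compatibility behaviour (not part of the diff; it assumes the decorator from `..utils.deprecation` emits a standard `DeprecationWarning` when the old name is called):

    import warnings

    from megengine.functional.debug_param import (
        get_execution_strategy,
        set_conv_execution_strategy,  # deprecated alias, forwards to set_execution_strategy()
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        set_conv_execution_strategy("HEURISTIC")

    assert get_execution_strategy() == "HEURISTIC"   # the old name still takes effect
    print([str(w.message) for w in caught])          # a deprecation notice is expected here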
@@ -22,7 +22,7 @@ from ..jit.tracing import is_tracing
 from ..random import uniform
 from ..tensor import Tensor
 from ..utils.tuple_function import _pair, _pair_nonzero
-from .debug_param import get_conv_execution_strategy
+from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, matmul, max, prod, sum
@@ -149,7 +149,7 @@ def conv2d(
         pad_w=pad_w,
         dilate_h=dilate_h,
         dilate_w=dilate_w,
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
@@ -217,7 +217,7 @@ def conv_transpose2d(
         pad_w=pad_w,
         dilate_h=dilate_h,
         dilate_w=dilate_w,
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
     )
     weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
@@ -282,7 +282,7 @@ def deformable_conv2d(
         pad_w=pad_w,
         dilate_h=dilate_h,
         dilate_w=dilate_w,
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
@@ -1658,7 +1658,7 @@ def conv1d(
         pad_w=0,
         dilate_h=dilate_h,
         dilate_w=1,
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
         mode=conv_mode,
         compute_mode=compute_mode,
         sparse=sparse_type,
...
@@ -12,7 +12,7 @@ from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..tensor import Tensor
 from ..utils.tuple_function import _pair, _pair_nonzero
-from .debug_param import get_conv_execution_strategy
+from .debug_param import get_execution_strategy
 
 
 def conv_bias_activation(
@@ -65,7 +65,7 @@ def conv_bias_activation(
         dilate_w=dw,
         dtype=dtype,
         format="NCHW",
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
         compute_mode=compute_mode,
@@ -125,7 +125,7 @@ def batch_conv_bias_activation(
         dilate_w=dw,
         dtype=dtype,
         format="NCHW",
-        strategy=get_conv_execution_strategy(),
+        strategy=get_execution_strategy(),
         nonlineMode=nonlinear_mode,
         mode=conv_mode,
         compute_mode=compute_mode,
...
@@ -20,7 +20,7 @@ import megengine.functional as F
 from megengine import jit
 from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.tensor.utils import make_shape_tuple
-from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.functional.debug_param import set_execution_strategy
 from megengine.jit import SublinearMemoryConfig
 from megengine.module import (
     AdaptiveAvgPool2d,
@@ -242,7 +242,7 @@ def test_correctness():
     else:
         model_name = "mnist_model_with_test_cpu.mge"
     model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
+    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
 
     run_train(model_path, False, False, max_err=1e-5)
     run_train(model_path, True, False, max_err=1e-5)
@@ -265,7 +265,7 @@ def test_correctness_use_adaptive_pooling():
     else:
         model_name = "mnist_model_with_test_cpu.mge"
     model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
+    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
 
     run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True)
     run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True)
...
@@ -21,7 +21,7 @@ import megengine.autodiff as ad
 import megengine.distributed as dist
 import megengine.functional as F
 from megengine.device import get_default_device, set_default_device
-from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.functional.debug_param import set_execution_strategy
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.tensor import Tensor
@@ -198,5 +198,5 @@ def run_test(
 def test_dp_correctness():
     model_name = "mnist_model_with_test.mge"
     model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
+    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
     run_test(model_path, False, False, max_err=1e-5)
@@ -336,8 +336,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
                   rst.workspace, rst.time);
         prof_rst.push_back(rst);
     }
-    mgb_assert(!prof_rst.empty(), "no usable convolution algorithm %s",
-               str_on_inp_shape.c_str());
+    std::string msg = ssprintf("no usable %s algorithm %s",
+                               ctx.mgb_opr()->dyn_typeinfo()->name,
+                               str_on_inp_shape.c_str());
+    mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
 
     FixedTensorLayouts origin_layouts = ctx.layouts();
     typename Opr::Param origin_param = ctx.megdnn_opr()->param();
...