diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index a50de77d342dbd19bbcc5451616296bb703387d2..9ef038c78c8eb5c6f7d08479453b15ceb3e269ac 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush) # subpackages import megengine.amp import megengine.autodiff +import megengine.config import megengine.data import megengine.distributed import megengine.dtr diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py index e19d9f6855cb99964bf9b948bd5d43e9014f9d8b..49877f2a084d5b7525f201402b63e0c1c15e1f5c 100644 --- a/imperative/python/megengine/core/_config.py +++ b/imperative/python/megengine/core/_config.py @@ -2,7 +2,13 @@ import os from contextlib import contextmanager -from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option +from ._imperative_rt.core2 import ( + _clear_algorithm_cache, + get_auto_format_convert, + get_option, + set_auto_format_convert, + set_option, +) __compute_mode = "default" __conv_format = "default" @@ -24,8 +30,8 @@ __all__ = [ def benchmark_kernel(mod): r"""Whether or not run possible algorithms on real device to find the best one. The default option is false, which means use heuristic to choose the fastest algorithm. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -47,8 +53,8 @@ def benchmark_kernel(mod, option: bool): def deterministic_kernel(mod): r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false, which means the algorithm is not reproducible. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -67,8 +73,8 @@ def deterministic_kernel(mod, option: bool): def async_level(mod) -> int: r"""Get or set config whether raise error exactly when invoking op. The default level is 2, which means both device and user side errors are async. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -108,8 +114,8 @@ def _compute_mode(mod): which means that no special requirements will be placed on. When set to 'float32', it would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -137,8 +143,8 @@ def _conv_format(mod): ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -153,20 +159,41 @@ def _conv_format(mod, format: str): __conv_format = format +@property +def _auto_format_convert(mod): + r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order. + The default value is False, which means no convert. + + Examples: + .. code-block:: + + import megengine as mge + mge.config._auto_format_convert = True + """ + return get_auto_format_convert() + + +@_auto_format_convert.setter +def _auto_format_convert(mod, option: bool): + set_auto_format_convert(option) + + def _reset_execution_config( benchmark_kernel=None, deterministic_kernel=None, async_level=None, compute_mode=None, conv_format=None, + auto_format_convert=None, ): - global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format + global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format orig_flags = ( _benchmark_kernel, _deterministic_kernel, get_option("async_level"), __compute_mode, __conv_format, + get_auto_format_convert(), ) if benchmark_kernel is not None: _benchmark_kernel = benchmark_kernel @@ -178,6 +205,8 @@ def _reset_execution_config( __compute_mode = compute_mode if conv_format is not None: __conv_format = conv_format + if auto_format_convert is not None: + set_auto_format_convert(auto_format_convert) return orig_flags @@ -189,26 +218,33 @@ def _override( async_level=None, compute_mode=None, conv_format=None, + auto_format_convert=None, ): r"""A context manager that users can opt in by attaching the decorator to set the config of the global variable. - - Examples: + + Examples: .. code-block:: import megengine as mge - + @mge.config._override( benchmark_kernel = True, deterministic_kernel = Fasle, async_level=2, compute_mode="float32", conv_format="NHWC", + auto_format_convert=True, ) def train(): """ orig_flags = _reset_execution_config( - benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format, + benchmark_kernel, + deterministic_kernel, + async_level, + compute_mode, + conv_format, + auto_format_convert, ) try: yield diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index a05415a83373eb2b689445a60478a36d18458c71..f7b45693d7a3bc8ee2175db6d4d15eb2d1ddb6cc 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -564,7 +564,6 @@ def interpolate( if inp.dtype == np.float16: inp = inp.astype("float32") conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) - assert conv_format == "NCHW", "Currently resize only support NCHW mode" op = builtin.Resize(imode=mode_map[mode], format=conv_format) shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) (ret,) = apply(op, inp, shape) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 3f3fa41a9c9e0e5c1871bb78494aa10c3eed91aa..0aec35a5160e04a82b1a255c97db52d57d41f169 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -4,6 +4,7 @@ from typing import Union import numpy as np from .core._imperative_rt import CompNode +from .core._imperative_rt.core2 import FormatType from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import apply, set_py_tensor_type from .core._trace_option import use_symbolic_shape @@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin): is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`. no_cache: Whether cache it for memory sharing. name: Used to improve convenience in graph operation on dumped model. + format: Used to indicate which memory format Tensor uses. It will not affect actual memory order or stride, + but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc". .. note:: @@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin): is_const: bool = False, no_cache: bool = False, name: str = None, + format: str = "default", ): if name is None: name = "" @@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin): r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.""" return super().dtype + @property + def format(self) -> str: + return super().format + @property def qparams(self): r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.""" diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 4cd4a27f1aff78a0323a164f269e681a75f8a914..40852d7ae847e1fa59e011925fa18cace9452fb1 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -8,6 +8,7 @@ #include "megbrain/imperative/transformations/dim_expansion.h" #include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/eval.h" +#include "megbrain/imperative/transformations/format.h" #include "megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/transformations/scalar.h" #include "megbrain/imperative/transformations/symbol.h" @@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) { // name case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1; } + case 'f': + // format + return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1; } // clang-format on return -1; @@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { {"is_const", []() -> py::object { return py::bool_(false); }}, {"no_cache", []() -> py::object { return py::bool_(false); }}, {"name", []() -> py::object { return py::none(); }}, + {"format", []() -> py::object { return py::none(); }}, }, name2idx}; py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType @@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } else { tup = parse_args(tup, descs); } - mgb_assert(tup.size() == 6); + mgb_assert(tup.size() == 7); if (auto* t = try_cast(tup[0].ptr())) { m_tensor = t->m_tensor->copy(); } else { auto data = tup[0]; DType dtype = tup[1].cast(); + CompNode cn = as_comp_node(tup[2]); bool is_const = tup[3].cast(); bool no_cache = tup[4].cast(); std::string name; if (!tup[5].is_none()) { name = tup[5].cast(); } - CompNode cn = as_comp_node(tup[2]); + Format format; + if (!tup[6].is_none()) { + format = tup[6].cast(); + } { CreateTensor::Kind kind = is_const ? CreateTensor::Const @@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } else { auto&& hval = pyobj2hval(data, cn, dtype); val = imperative::apply( - CreateTensor(kind, cn, hval.dtype, hval.shape), + CreateTensor(kind, cn, hval.dtype, hval.shape, format), hval.storage)[0]; } m_tensor.emplace(val); @@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() { return py::cast(m_tensor->comp_node()).release().ptr(); } +PyObject* TensorWrapper::format() { + return py::cast(m_tensor->format().to_string()).release().ptr(); +} + PyObject* TensorWrapper::numpy() { auto hv = m_tensor->numpy(); if (!hv) { @@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp); void init_tensor(py::module m) { imperative::Tensor::static_initialize(); + // Transformations static auto& transformations = TransformationManager::get_instance(); using Segment = TransformationManager::Segment; @@ -755,6 +769,9 @@ void init_tensor(py::module m) { .register_at( std::make_shared()) .release()); + auto format_trans = std::make_shared(); + MGB_MARK_USED_VAR( + transformations.register_at(format_trans).release()); static py::exception py_async_error( m, "AsyncError", PyExc_RuntimeError); @@ -788,12 +805,14 @@ void init_tensor(py::module m) { } }); + // Tensor auto* tensor_type = TensorWrapper::wrap_t::type() .def<&TensorWrapper::numpy>("numpy") .def_getset<&TensorWrapper::shape>("shape") .def_getset<&TensorWrapper::dtype>("dtype") .def_getset<&TensorWrapper::device>("device") + .def_getset<&TensorWrapper::format>("format") .def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::isscalar>("_isscalar") .def<&TensorWrapper::detach>("detach") @@ -812,6 +831,11 @@ void init_tensor(py::module m) { if (!tensor_type) throw py::error_already_set(); py::setattr(m, "Tensor", tensor_type); + py::enum_(m, "FormatType") + .value("DEFAULT", Format::Type::DEFAULT) + .value("NCHW", Format::Type::NCHW) + .value("NHWC", Format::Type::NHWC) + .export_values(); py::class_(m, "TensorWeakRef") .def(py::init()) @@ -911,6 +935,7 @@ void init_tensor(py::module m) { sync_py_task_q(); }); + // GradTransformation py::handle grad_key_type = GradKeyWrapper::wrap_t::type() .def<&GradKeyWrapper::attach>("attach") @@ -1203,6 +1228,7 @@ void init_tensor(py::module m) { return wrapped_outputs; }); + // ModuleTraceTransformation static py::function module_trace_hook; static auto get_module_trace = [] { @@ -1309,6 +1335,12 @@ void init_tensor(py::module m) { m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); + // FormatTransformation + m.def("set_auto_format_convert", + [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); }); + m.def("get_auto_format_convert", + [format_trans]() { return format_trans->get_auto_convert(); }); + py::register_exception(m, "TraceError"); } diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 86394cbac1fc8fc1c01a5c4e23fcfc638dd15578..7e243631ee35dedbb916e0f19eba34a0f2c6dca9 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -1,10 +1,11 @@ #pragma once #pragma GCC diagnostic ignored "-Wmissing-field-initializers" -#include - #include #include +#include + +#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/interpreter.h" #include "pybind11/pybind11.h" @@ -57,6 +58,7 @@ public: } return *shape; } + inline Format format() { return *data().format(); } inline HostValue::ref_t numpy() { return data().numpy(); } inline void reset(ValueRef value) { m_data = value; @@ -116,6 +118,7 @@ public: PyObject* shape(); PyObject* dtype(); PyObject* device(); + PyObject* format(); PyObject* numpy(); void reset(PyObject*); PyObject* detach(); diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index ae5cf59aa6311b3e21259f1c0e779b1a222f1d9d..99c55cbcbcdd89cb77bc4b10738c26d21b6156dc 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -19,6 +19,7 @@ public: DTypePromote, DimExpansion, Grad, + Format, Scalar, Symbol, Trace, diff --git a/imperative/python/test/unit/amp/test_autocast.py b/imperative/python/test/unit/amp/test_autocast.py index 5b4d09e82b39013d60620249285ae7215175a8a0..e81245b3d11c3f354e0466e8f3daaa00aa8dd5e3 100644 --- a/imperative/python/test/unit/amp/test_autocast.py +++ b/imperative/python/test/unit/amp/test_autocast.py @@ -2,7 +2,7 @@ from megengine import amp from megengine.core.tensor import amp as origin_amp -def test_grad_scaler(): +def test_autocast(): def check(enabled, low, high): assert amp.enabled == enabled assert origin_amp._enabled == enabled diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..1f87d2f680586bca97fead388be7904a9189c4f0 --- /dev/null +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -0,0 +1,307 @@ +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +from megengine import tensor +from megengine.autodiff import GradManager + + +def test_basic(): + a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc") + assert a.format == "nhwc" + b = tensor(a) + assert b.format == "nhwc" + # TODO: fix Tensor init bug for another Tensor + # c = tensor(a, format="nchw") + # assert c.format == "nchw" + + +def _compare_nchw_nhwc(data, func): + x1 = tensor(data, format="nchw") + x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + out1 = func(x1) + with mge.config._override(auto_format_convert=True): + out2 = func(x2) + np.testing.assert_equal(out1, out2) + + +def test_dimshuffle(): + def func(x): + out = F.transpose(x, [2, 3, 0, 1]) + assert out.format == "default" + return out.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_reshape(): + # maintain NHWC format + def func(x): + out = F.reshape(x, (1, 2, 6, 2)) + if x.format == "nhwc": + assert out.format == "nhwc" + return out.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + # not maintain NHWC format + def func2(x): + out = F.reshape(x, (1, 24)) + assert out.format == "default" + return out.numpy() + + _compare_nchw_nhwc(data, func2) + + +def test_flatten(): + def func(x): + return F.flatten(x).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_broadcast(): + # maintain NHWC format + def func(x): + out = F.broadcast_to(x, (4, 3, 2, 3)) + if x.format == "nhwc": + assert out.format == "nhwc" + return out.numpy() + + data = np.arange(0, 24).reshape((4, 3, 2, 1)) + _compare_nchw_nhwc(data, func) + + # not maintain NHWC format + def func2(x): + out = F.broadcast_to(x, (3, 4, 3, 2, 1)) + assert out.format == "default" + return out.numpy() + + _compare_nchw_nhwc(data, func2) + + +@pytest.mark.skip("repeat cannot maintain format yet") +def test_repeat(): + def func(x): + rst = F.repeat(x, 3, axis=1) + assert rst.format == x.format + return rst.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_getshape(): + def func(x): + return x.shape + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.skip("symbolic shape is not supported yet") +def test_get_symbolic_shape(): + from megengine.core._trace_option import set_symbolic_shape + + origin_opt = set_symbolic_shape(True) + + def func(x): + return x.shape.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + set_symbolic_shape(origin_opt) + + +def test_getvalue(): + def func(x): + return x.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_get_set_subtensor(): + def get_subtensor(x): + return x[:, :1, :2, :3].numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, get_subtensor) + + def set_subtensor(x): + x[:, :1, :2, :3] = 0 + return x.numpy() + + _compare_nchw_nhwc(data, set_subtensor) + + +def test_get_set_advanced_indexing(): + def get_advanced_indexing(x): + x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy() + return x + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, get_advanced_indexing) + + def set_advanced_indexing(x): + x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0 + return x.numpy() + + _compare_nchw_nhwc(data, set_advanced_indexing) + + +def test_typecvt(): + def typecvt(x): + return x.astype("float16").numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, typecvt) + + +def test_elemwise(): + def elemwise(x): + return (x * 2 + x / 2).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, elemwise) + + +def test_concat(): + def func(x): + rst = F.concat([x / 2, x * 2], axis=1) + assert rst.format == x.format + return rst.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.parametrize( + "mode", ["bilinear", "nearest"], +) +def test_interpolate(mode): + def func(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + rst = F.vision.interpolate(x, scale_factor=3, mode=mode) + assert rst.format == "nhwc" + return rst.numpy() + else: + return F.vision.interpolate(x, scale_factor=3, mode=mode).numpy() + + # NHWC interpolate only suppoted channel is 1 or 3 + data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32") + _compare_nchw_nhwc(data, func) + + +def test_conv2d(): + def conv2d(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + x = F.conv2d( + x, + weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), + ) + assert x.format == "nhwc" + return x.numpy() + else: + return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, conv2d) + + +def test_group_conv2d(): + def conv2d(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + x = F.conv2d( + x, + weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), + groups=2, + ) + assert x.format == "nhwc" + return x.numpy() + else: + return F.conv2d( + x, F.ones((2, 2, 2, 1, 1)), F.ones((1, 4, 1, 1)), groups=2 + ).numpy() + + data = np.arange(0, 48).reshape((1, 4, 3, 4)) + _compare_nchw_nhwc(data, conv2d) + + +def test_bn(): + def func(x): + if x.format == "nhwc": + with mge.config._override(bn_format="dim_111c"): + oups = F.batch_norm( + x.astype("float32"), + running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + training=True, + inplace=False, + ) + assert oups[0].format == "nhwc", "y's format is wrong" + assert oups[1].format == "nhwc", "running_mean's format is wrong" + assert oups[2].format == "nhwc", "running_var's format is wrong" + return oups[0].numpy() + else: + return F.batch_norm( + x.astype("float32"), + running_mean=mge.tensor(np.ones((1, 2, 1, 1))), + running_var=mge.tensor(np.ones((1, 2, 1, 1))), + weight=mge.tensor(np.ones((1, 2, 1, 1))), + bias=mge.tensor(np.ones((1, 2, 1, 1))), + training=True, + inplace=False, + )[0].numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.parametrize( + "pooling", + [F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d], +) +def test_pooling2d(pooling): + def func(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + x = pooling(x.astype("float32"), 2) + assert x.format == "nhwc" + return x.numpy() + else: + return pooling(x.astype("float32"), 2).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_backward(): + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") + b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") + gm = GradManager().attach([w, b]) + with gm: + with mge.config._override(auto_format_convert=True, conv_format="NHWC"): + x = F.conv2d(x, w, b) + gm.backward(x) + # TODO: backward grad has no format yet + np.testing.assert_equal( + w.grad.numpy(), + np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), + ) + np.testing.assert_equal( + b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3)) + ) diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index bb4c265acacd5c0046e20ff9d8dd697f0992c9dc..a91c228f2c00af1fcace4f952b1f8f516edd6b33 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -33,14 +33,20 @@ std::string GetAttr::to_string() const { return ssprintf("GetAttr{attr=%s}", attr_name); } -CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape) - : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {} +CreateTensor::CreateTensor( + Kind kind, CompNode device, DType dtype, ValueShape shape, Format format) + : m_kind(kind), + m_device(device), + m_dtype(dtype), + m_shape(shape), + m_format(format) {} CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) : m_kind(kind), m_device(device), m_dtype(layout.dtype), - m_shape(ValueShape::from(layout)) { + m_shape(ValueShape::from(layout)), + m_format(Format::Type::DEFAULT) { mgb_assert( layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); } @@ -74,8 +80,9 @@ auto CreateTensor::parse(Span inputs) const -> Args { std::string CreateTensor::to_string() const { return ssprintf( - "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind, - m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str()); + "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}", + (int)m_kind, m_device.to_string().c_str(), m_dtype.name(), + m_shape.to_string().c_str(), m_format.to_string().c_str()); } std::string DTRCommand::to_string() const { diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa431665ef101c2a7782994cfe115fdafc0d46ab --- /dev/null +++ b/imperative/src/impl/transformations/format.cpp @@ -0,0 +1,406 @@ +#include "megbrain/imperative/transformations/format.h" + +#include "megbrain/imperative/ops/autogen.h" + +namespace mgb { +namespace imperative { + +using FT = Format::Type; + +TypedValueRef FormattedTensorValue::as(const FT& target) const { + return FormattedTensorValue::make(m_value, target); +} + +TypedValueRef FormattedTensorValue::to( + const FT& target, const std::string& scope) const { + std::vector pattern; + if (m_format == FT::NHWC && target == FT::NCHW) { + pattern = {0, 3, 1, 2}; + } else if (m_format == FT::NCHW && target == FT::NHWC) { + pattern = {0, 2, 3, 1}; + } else { + mgb_throw( + MegBrainError, "Unsupport format conversion from %s to %s", + m_format.to_string().c_str(), Format(target).to_string().c_str()); + } + auto output = imperative::apply( + *Dimshuffle::make(pattern, scope), std::vector{m_value})[0]; + return FormattedTensorValue::make(output, target); +} + +namespace { + +ValueRef unwrap_input(const ValueRef& input) { + if (auto format_input = input.as_ref()) { + return format_input->value(); + } else { + return input; + } +} + +std::vector unwrap_inputs(const Span& inputs) { + std::vector unwrapped_inputs; + for (auto&& input : inputs) { + unwrapped_inputs.push_back(unwrap_input(input)); + } + return unwrapped_inputs; +} + +std::vector wrap_outputs( + const std::vector& outputs, FT type = FT::DEFAULT) { + std::vector wrapped_outputs; + for (auto&& output : outputs) { + wrapped_outputs.push_back(FormattedTensorValue::make(output, type)); + } + return wrapped_outputs; +} + +ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { + mgb_assert(shape.ndim == 4); + auto out = ValueShape(shape); + out[3] = shape[2]; + out[2] = shape[1]; + out[1] = shape[3]; + return out; +} + +using FormatRule = std::function( + const OpDef&, Span&, const bool&)>; +static std::unordered_map format_rules; + +template +void register_format_rule( + std::vector (*rule)(const T&, Span&, const bool&)) { + format_rules[T::typeinfo()] = [rule](const OpDef& def, Span& inputs, + const bool& auto_convert) { + return (*rule)(def.cast_final_safe(), inputs, auto_convert); + }; +} + +auto convert_nchw2nhwc_pattern(const std::vector& pattern) { + mgb_assert(pattern.size() == 4); + auto nhwc_pattern = pattern; + for (size_t idx = 0; idx < 4; ++idx) { + auto dim = pattern[idx]; + if (dim == 1) { + nhwc_pattern[idx] = 3; + } else if (dim == 2) { + nhwc_pattern[idx] = 1; + } else if (dim == 3) { + nhwc_pattern[idx] = 2; + } + } + return nhwc_pattern; +} + +std::vector dimshuffle_rule( + const Dimshuffle& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 1); + auto& src = inputs[0].cast(); + // Only support converting pattern from NCHW to NHWC currently. + if (auto_convert && src.format() == FT::NHWC) { + auto pattern = convert_nchw2nhwc_pattern(op.pattern); + // dimshuffle will not maintain NHWC Format + return wrap_outputs(imperative::apply( + *Dimshuffle::make(std::move(pattern), op.scope()), + unwrap_inputs(inputs))); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { + mgb_assert(shape.layout().total_nr_elems() == 4); + auto* old_ptr = shape.ptr(); + auto cn = shape.comp_node(); + auto layout = shape.layout(); + auto nhwc_shape = HostTensorND(cn, layout); + auto* new_ptr = nhwc_shape.ptr(); + new_ptr[0] = old_ptr[0]; + new_ptr[1] = old_ptr[2]; + new_ptr[2] = old_ptr[3]; + new_ptr[3] = old_ptr[1]; + auto hv = HostStorage::make(nhwc_shape.storage()); + auto nhwc_shape_input = + imperative::apply(CreateTensor(CreateTensor::Const, cn, layout), hv)[0]; + return nhwc_shape_input; +} + +std::vector reshape_rule( + const Reshape& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 2); + auto& src = inputs[0].cast(); + if (auto_convert && src.format() == FT::NHWC) { + auto shape = unwrap_input(inputs[1]).numpy().cast().as_nd(); + if (shape.layout().total_nr_elems() == 4) { + // output is still NHWC format + auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); + auto outputs = imperative::apply( + op, std::vector{unwrap_input(inputs[0]), nhwc_shape}); + return wrap_outputs(outputs, FT::NHWC); + } else { + // will not maintain src's format + auto nchw_src = src.to(FT::NCHW, op.scope())->value(); + auto outputs = imperative::apply( + op, std::vector{nchw_src, unwrap_input(inputs[1])}); + return wrap_outputs(outputs); + } + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +std::vector broadcast_rule( + const Broadcast& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 2); + auto& src = inputs[0].cast(); + if (auto_convert && src.format() == FT::NHWC) { + auto shape = unwrap_input(inputs[1]).numpy().cast().as_nd(); + if (shape.layout().total_nr_elems() == 4) { + // output is still NHWC format + auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); + auto outputs = imperative::apply( + op, std::vector{unwrap_input(inputs[0]), nhwc_shape}); + return wrap_outputs(outputs, FT::NHWC); + } else { + // will not maintain src's format + auto nchw_src = src.to(FT::NCHW, op.scope())->value(); + auto outputs = imperative::apply( + op, std::vector{nchw_src, unwrap_input(inputs[1])}); + return wrap_outputs(outputs); + } + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +bool is_reduce_ndim_idx_items( + const std::vector>& items, + const Span& inputs) { + for (auto i = 0; i < items.size(); ++i) { + auto&& [axis, begin, end, step, idx] = items[i]; + if (idx) { + // if inputs[i] contains more than one value, ndim will not be reduced. + return inputs[i].is_scalar(); + } + } + return false; +} + +auto convert_nchw2nhwc_idx_items( + const std::vector>& items) { + auto nhwc_items = items; + for (auto i = 0; i < nhwc_items.size(); ++i) { + auto&& [axis, begin, end, step, idx] = nhwc_items[i]; + if (axis == 2 || axis == 3) { + nhwc_items[i] = {axis - 1, begin, end, step, idx}; + } else if (axis == 1) { + nhwc_items[i] = {3, begin, end, step, idx}; + } + } + return nhwc_items; +} + +template +std::vector subtensor_rule( + const T& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() >= 1); + auto& src = inputs[0].cast(); + bool is_reduce_ndim = is_reduce_ndim_idx_items( + op.items, {&inputs[1], &inputs[inputs.size() - 1]}); + if (!is_reduce_ndim) { + // only support NHWC2NCHW convert, otherwise maintain src's format + if (!(auto_convert && src.format() == FT::NHWC)) { + return {FormattedTensorValue::make( + imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; + } + auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); + auto outputs = imperative::apply( + *T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs)); + return wrap_outputs(outputs, FT::NHWC); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +template +std::vector setsubtensor_rule( + const T& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() >= 2); + auto& src = inputs[0].cast(); + bool is_reduce_ndim = is_reduce_ndim_idx_items( + op.items, {&inputs[2], &inputs[inputs.size() - 1]}); + if (!is_reduce_ndim) { + // only support NHWC2NCHW convert, otherwise maintain src's format + if (!(auto_convert && src.format() == FT::NHWC)) { + return {FormattedTensorValue::make( + imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; + } + // value has been broadcasted to src's fake NCHW shape. + auto& value = inputs[1].cast(); + auto& format = value.format(); + auto nhwc_inputs = std::vector(inputs.size()); + if (format == FT::DEFAULT || format == FT::NCHW) { + // value for setsubtensor should transpose to match shape. + auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC); + // make new inputs for setsubtensor + nhwc_inputs[0] = src.value(); + nhwc_inputs[1] = nhwc_value->value(); + for (auto i = 2; i < inputs.size(); ++i) { + nhwc_inputs[i] = inputs[i].as_ref()->value(); + } + } else if (format != FT::NHWC) { + mgb_throw( + MegBrainError, "Unsupported format(%s) of value for setsubtensor.", + format.to_string().c_str()); + } + auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); + auto outputs = imperative::apply( + *T::make(std::move(nhwc_items), op.scope()), nhwc_inputs); + return wrap_outputs(outputs, FT::NHWC); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +FT get_inputs_format(Span& inputs) { + FT format(FT::DEFAULT); + for (auto& inp : inputs) { + auto& inp_format = inp.cast().format(); + if (inp_format != FT::DEFAULT) { + mgb_assert(format == FT::DEFAULT || inp_format == format); + format = inp_format.type(); + } + } + return format; +} + +std::vector concat_rule( + const Concat& op, Span& inputs, const bool& auto_convert) { + FT format = get_inputs_format(inputs); + if (!(format == FT::NHWC && auto_convert)) { + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); + } + // TODO: handle 5D NHWC Tensor from group conv + auto axis = op.axis; + if (axis == 2 || axis == 3) { + axis = axis - 1; + } else if (axis == 1) { + axis = 3; + } + return wrap_outputs( + imperative::apply( + *Concat::make(axis, op.comp_node, op.scope()), + unwrap_inputs(inputs)), + format); +} + +std::vector elemwise_rule( + const Elemwise& op, Span& inputs, const bool& auto_convert) { + FT format = get_inputs_format(inputs); + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); +} + +std::vector identity_rule_helper( + const OpDef& op, const Span& inputs) { + // mgb_assert(inputs.size() == 1); + auto& src = inputs[0].cast(); + return wrap_outputs( + imperative::apply(op, unwrap_inputs(inputs)), src.format().type()); +} + +// clang-format off +#define FOREACH_IDENTITY_OP(cb) \ + cb(Copy) \ + cb(FastpathCopy) \ + cb(TypeCvt) \ + cb(Pooling) \ + cb(AdaptivePooling) \ + cb(Dropout) \ + cb(Convolution) \ + cb(BatchNorm) \ + cb(Resize) \ + cb(Identity) +// clang-format on + +#define CREATE_IDENTITY_OP_RULE(op) \ + std::vector op##_rule( \ + const op& _op, Span& inputs, const bool& auto_convert) { \ + return identity_rule_helper(_op, inputs); \ + } +FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE) +#undef CREATE_IDENTITY_OP_RULE + +#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule); +struct FormatRuleRegistry { + FormatRuleRegistry() { + register_format_rule(dimshuffle_rule); + register_format_rule(reshape_rule); + register_format_rule(broadcast_rule); + register_format_rule(subtensor_rule); + register_format_rule(subtensor_rule); + register_format_rule(setsubtensor_rule); + register_format_rule(setsubtensor_rule); + register_format_rule(concat_rule); + register_format_rule(elemwise_rule); + FOREACH_IDENTITY_OP(REGISTER_IDENTITY_OP_RULE) + } +} _; +#undef REGISTER_IDENTITY_OP_RULE +} // namespace + +std::vector FormatTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* apply_op = op.as()) { + // all inputs should be FormattedTensorValue + auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); + if (iter != format_rules.end()) { + return iter->second(apply_op->op(), inputs, m_auto_convert); + } else { + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); + } + } else if (auto* create_tensor = op.as()) { + auto format = create_tensor->format(); + return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)}; + } else if (auto* get_attr = op.as()) { + auto* src = inputs.as_array<1>()[0].as(); + if (!m_auto_convert || !src || src->format() != FT::NHWC) { + return imperative::apply(op, unwrap_inputs(inputs)); + } + switch (get_attr->attr()) { + case GetAttr::Shape: { + auto output = imperative::apply(op, unwrap_inputs(inputs))[0]; + auto shape = convert_nhwc2nchw_shape(output.cast()); + return {ShapeValue::make(shape)}; + } + case GetAttr::Value: { + auto nchw_src = unwrap_input(src->to(FT::NCHW, "")); + return imperative::apply(op, std::vector{nchw_src}); + } + default: + return imperative::apply(op, unwrap_inputs(inputs)); + } + } else if (op.is()) { + bool is_formatted_tensor = inputs.as_array<1>()[0].is(); + if (is_formatted_tensor) { + return {FormatValue::make(inputs[0].cast().format())}; + } else { + mgb_log_warn( + "Not FormattedTensorValue input for GetFormat op: %s", + inputs[0].to_string().c_str()); + return {FormatValue::make(FT::DEFAULT)}; + } + } else if (op.is()) { + bool is_formatted_tensor = inputs.as_array<1>()[0].is(); + if (is_formatted_tensor) { + auto& format = inputs[0].cast().format(); + return wrap_outputs( + imperative::apply(op, unwrap_inputs(inputs)), format.type()); + } else { + mgb_log_warn( + "Not FormattedTensorValue input for IdentityLike op: %s", + inputs[0].to_string().c_str()); + return imperative::apply(op, inputs); + } + } else { + return imperative::apply(op, unwrap_inputs(inputs)); + } +}; + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 8b4cb4dd3f863b6e5fe07cc22d724348e0c5aa01..e4344e2128c79c048520ab852dc7b49092ff43e7 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -58,6 +58,10 @@ TypedValueRef ValueRef::dtype() const { return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref(); } +TypedValueRef ValueRef::format() const { + return imperative::apply(GetFormat(), *this)[0].as_ref(); +} + TypedValueRef ValueRef::name() const { return imperative::apply(GetName(), *this)[0].cast_ref(); } diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h index 963f24d111b40b78ec26b8a49a1db66d17cbb67b..4d35c746f2393a60c6aed2de5ef2fc4827fba64e 100644 --- a/imperative/src/include/megbrain/imperative/basic_operators.h +++ b/imperative/src/include/megbrain/imperative/basic_operators.h @@ -5,6 +5,7 @@ #include "megbrain/imperative/op_def.h" #include "megbrain/imperative/operator.h" +#include "megbrain/imperative/utils/data_format.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/value_shape.h" @@ -82,9 +83,12 @@ private: CompNode m_device; DType m_dtype; ValueShape m_shape; + Format m_format; public: - CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape); + CreateTensor( + Kind kind, CompNode device, DType dtype, ValueShape shape, + Format format = Format::Type::DEFAULT); CreateTensor(Kind kind, CompNode device, TensorLayout layout); /** @@ -99,6 +103,7 @@ public: CompNode device() const { return m_device; } DType dtype() const { return m_dtype; } ValueShape shape() const { return m_shape; } + Format format() const { return m_format; } std::string to_string() const override; }; @@ -157,6 +162,11 @@ public: std::string to_string() const override; }; +class GetFormat final : public OperatorImpl { +public: + std::string to_string() const override { return "GetFormat{}"; } +}; + class GetVarVal final : public OperatorImpl { public: std::string to_string() const override; diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 3aa0d3d166d71d0048a85cfd19e8703ee334186b..ad4adae56a6094f050ada76a46772064939b4f13 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -3,6 +3,7 @@ #include #include +#include "megbrain/imperative/utils/data_format.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/value_shape.h" #include "megbrain/imperative/value.h" @@ -148,6 +149,13 @@ public: std::string to_string() const override; }; +class FormatValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override { return Format::to_string(); } +}; + class StringValue final : public PrimitiveValue { public: using PrimitiveValue::PrimitiveValue; diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h new file mode 100644 index 0000000000000000000000000000000000000000..81fd83bfaee5c735f7ebdae988908a53326cd81b --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/format.h @@ -0,0 +1,70 @@ +#pragma once + +#include "megbrain/imperative/basic_values.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/utils/data_format.h" + +namespace mgb::imperative { + +class FormattedTensorValue final : public ValueImpl { +private: + ValueRef m_value; + Format m_format; + +public: + FormattedTensorValue(ValueRef value, Format format) + : m_value(value), m_format(format) {} + + std::string to_string() const override { + return ssprintf( + "FormattedTensorValue{value=%s, format=%s}", + m_value.to_string().c_str(), m_format.to_string().c_str()); + } + + ValueRef value() const { return m_value; } + + const Format& format() const { return m_format; } + + TypedValueRef as(const Format::Type& target) const; + TypedValueRef to( + const Format::Type& target, const std::string& scope = "") const; + + void clear() override { + m_value = {}; + m_format = {}; + } + + void on_watch() override { m_value.watch(); } + + void on_unwatch() override { m_value.unwatch(); } +}; + +/** + * \brief simulates scalar because megbrain graph system don't support scalar + * + * Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'. + * This transformation simulates scalars with a flag. If a value is ScalarValue, it is + * scalar, vice versa. So there is not scalar down this layer. + */ +class FormatTransformation final : public Transformation { +private: + bool m_auto_convert = false; + +public: + std::vector apply_transformation( + const Operator& op, Span inputs) override; + + ValueRef unwrap(ValueRef value) override { + mgb_assert(!value.is()); + return value; + } + + std::string name() const override { + return ssprintf("FormatTransformation{auto_convert=%d}", m_auto_convert); + } + void set_auto_convert(bool enabled) { m_auto_convert = enabled; } + bool get_auto_convert() const { return m_auto_convert; } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/data_format.h b/imperative/src/include/megbrain/imperative/utils/data_format.h new file mode 100644 index 0000000000000000000000000000000000000000..f8e66477777036964609d4dc5e39e1447060cacd --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/data_format.h @@ -0,0 +1,56 @@ +#pragma once + +#include "megbrain/tensor.h" + +namespace mgb::imperative { + +/** + * \brief like TensorFormats, but only including common formats and DEFAULT. + * + */ +class Format { +public: + enum class Type { + DEFAULT = 0, + NCHW = 1, ///< [N, C, H, W] + NHWC = 2, ///< [N, H, W, C] + }; + std::string to_string() const { + switch (m_type) { + case Type::DEFAULT: + return "default"; + case Type::NCHW: + return "nchw"; + case Type::NHWC: + return "nhwc"; + default: + mgb_throw(MegBrainError, "bad format type"); + } + } + Format() : m_type(Type::DEFAULT) {} + Format(std::string str) { + if (str == "default") { + m_type = Type::DEFAULT; + } else if (str == "nchw") { + m_type = Type::NCHW; + } else if (str == "nhwc") { + m_type = Type::NHWC; + } else { + mgb_throw( + MegBrainError, + "Invalid format type." + " Only support \"default\", \"nchw\" and \"nhwc\""); + } + } + Format(Type type) : m_type(type) {} + Type type() const { return m_type; } + bool operator==(const Format& b) const { return m_type == b.type(); } + bool operator==(const Format::Type& b) const { return m_type == b; } + bool operator!=(const Format& b) const { return m_type != b.type(); } + bool operator!=(const Format::Type& b) const { return m_type != b; } + +private: + Type m_type = Type::DEFAULT; +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index 44970e86b5fc4ee41b60462f8422af22010d7007..ecf63b48d654b3a11a77672ffc2c4b4e5b4f6223 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -31,6 +31,7 @@ class HostValue; class DeviceValue; class ShapeValue; class DTypeValue; +class FormatValue; class CompNodeValue; class StringValue; class NodeValue; @@ -219,6 +220,7 @@ public: TypedValueRef device() const; TypedValueRef shape() const; TypedValueRef dtype() const; + TypedValueRef format() const; TypedValueRef name() const; bool is_scalar() const; @@ -431,9 +433,11 @@ inline const TypedValueRef& ValueRef::cast_ref(const Type& type) inline void ValueRef::on_cast_failure(const IType& type) const { // if this is ErrorValue, rethrow directly storage()->try_rethrow(); - mgb_assert( - storage()->type() != type, "expect type %s, got %s", type.name().c_str(), - to_string().c_str()); + if (storage()->type() != type) { + mgb_throw( + MegBrainError, "Unable to cast ValueRef: expect type %s, got %s", + type.name().c_str(), to_string().c_str()); + } } /** diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index a58276a091de19b7c3bb0ecca46402947cb735ee..c6964f33525bb485f69ad2a15b741915c04342eb 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape( bias_c = inp_shape[2][channel_idx]; mgb_assert( inp_c == scale_c && inp_c == bias_c, - "inconsistent channel size, input chennel: %zu, scale channel: %zu, bias " + "inconsistent channel size, input channel: %zu, scale channel: %zu, bias " "channel: %zu", inp_c, scale_c, bias_c);