From 533fb5bf49fde96ba7aa9399bbf39a8a38b64855 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 7 Jan 2022 16:48:40 +0800 Subject: [PATCH] feat(imperative): support formatted tensor and add special op rules GitOrigin-RevId: 77ff909f2371f768442fb103ec7038832f9310f6 --- imperative/python/megengine/__init__.py | 1 + imperative/python/megengine/core/_config.py | 68 ++- .../python/megengine/functional/vision.py | 1 - imperative/python/megengine/tensor.py | 8 + imperative/python/src/tensor.cpp | 38 +- imperative/python/src/tensor.h | 7 +- imperative/python/src/transformation.h | 1 + .../python/test/unit/amp/test_autocast.py | 2 +- .../test/unit/core/test_formatted_tensor.py | 307 +++++++++++++ imperative/src/impl/basic_operators.cpp | 17 +- .../src/impl/transformations/format.cpp | 406 ++++++++++++++++++ imperative/src/impl/value.cpp | 4 + .../megbrain/imperative/basic_operators.h | 12 +- .../megbrain/imperative/basic_values.h | 8 + .../imperative/transformations/format.h | 70 +++ .../megbrain/imperative/utils/data_format.h | 56 +++ .../src/include/megbrain/imperative/value.h | 10 +- src/opr/impl/dnn/batch_norm.cpp | 2 +- 18 files changed, 985 insertions(+), 33 deletions(-) create mode 100644 imperative/python/test/unit/core/test_formatted_tensor.py create mode 100644 imperative/src/impl/transformations/format.cpp create mode 100644 imperative/src/include/megbrain/imperative/transformations/format.h create mode 100644 imperative/src/include/megbrain/imperative/utils/data_format.h diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index a50de77d3..9ef038c78 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush) # subpackages import megengine.amp import megengine.autodiff +import megengine.config import megengine.data import megengine.distributed import megengine.dtr diff --git a/imperative/python/megengine/core/_config.py 
b/imperative/python/megengine/core/_config.py index e19d9f685..49877f2a0 100644 --- a/imperative/python/megengine/core/_config.py +++ b/imperative/python/megengine/core/_config.py @@ -2,7 +2,13 @@ import os from contextlib import contextmanager -from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option +from ._imperative_rt.core2 import ( + _clear_algorithm_cache, + get_auto_format_convert, + get_option, + set_auto_format_convert, + set_option, +) __compute_mode = "default" __conv_format = "default" @@ -24,8 +30,8 @@ __all__ = [ def benchmark_kernel(mod): r"""Whether or not run possible algorithms on real device to find the best one. The default option is false, which means use heuristic to choose the fastest algorithm. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -47,8 +53,8 @@ def benchmark_kernel(mod, option: bool): def deterministic_kernel(mod): r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false, which means the algorithm is not reproducible. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -67,8 +73,8 @@ def deterministic_kernel(mod, option: bool): def async_level(mod) -> int: r"""Get or set config whether raise error exactly when invoking op. The default level is 2, which means both device and user side errors are async. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -108,8 +114,8 @@ def _compute_mode(mod): which means that no special requirements will be placed on. When set to 'float32', it would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. - - Examples: + + Examples: .. code-block:: import megengine as mge @@ -137,8 +143,8 @@ def _conv_format(mod): ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` - - Examples: + + Examples: .. 
code-block:: import megengine as mge @@ -153,20 +159,41 @@ def _conv_format(mod, format: str): __conv_format = format +@property +def _auto_format_convert(mod): + r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order. + The default value is False, which means no convert. + + Examples: + .. code-block:: + + import megengine as mge + mge.config._auto_format_convert = True + """ + return get_auto_format_convert() + + +@_auto_format_convert.setter +def _auto_format_convert(mod, option: bool): + set_auto_format_convert(option) + + def _reset_execution_config( benchmark_kernel=None, deterministic_kernel=None, async_level=None, compute_mode=None, conv_format=None, + auto_format_convert=None, ): - global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format + global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format orig_flags = ( _benchmark_kernel, _deterministic_kernel, get_option("async_level"), __compute_mode, __conv_format, + get_auto_format_convert(), ) if benchmark_kernel is not None: _benchmark_kernel = benchmark_kernel @@ -178,6 +205,8 @@ def _reset_execution_config( __compute_mode = compute_mode if conv_format is not None: __conv_format = conv_format + if auto_format_convert is not None: + set_auto_format_convert(auto_format_convert) return orig_flags @@ -189,26 +218,33 @@ def _override( async_level=None, compute_mode=None, conv_format=None, + auto_format_convert=None, ): r"""A context manager that users can opt in by attaching the decorator to set the config of the global variable. - - Examples: + + Examples: .. 
code-block:: import megengine as mge - + @mge.config._override( benchmark_kernel = True, deterministic_kernel = Fasle, async_level=2, compute_mode="float32", conv_format="NHWC", + auto_format_convert=True, ) def train(): """ orig_flags = _reset_execution_config( - benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format, + benchmark_kernel, + deterministic_kernel, + async_level, + compute_mode, + conv_format, + auto_format_convert, ) try: yield diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index a05415a83..f7b45693d 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -564,7 +564,6 @@ def interpolate( if inp.dtype == np.float16: inp = inp.astype("float32") conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) - assert conv_format == "NCHW", "Currently resize only support NCHW mode" op = builtin.Resize(imode=mode_map[mode], format=conv_format) shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) (ret,) = apply(op, inp, shape) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 3f3fa41a9..0aec35a51 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -4,6 +4,7 @@ from typing import Union import numpy as np from .core._imperative_rt import CompNode +from .core._imperative_rt.core2 import FormatType from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import apply, set_py_tensor_type from .core._trace_option import use_symbolic_shape @@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin): is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`. no_cache: Whether cache it for memory sharing. name: Used to improve convenience in graph operation on dumped model. + format: Used to indicate which memory format Tensor uses. 
It will not affect actual memory order or stride, + but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc". .. note:: @@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin): is_const: bool = False, no_cache: bool = False, name: str = None, + format: str = "default", ): if name is None: name = "" @@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin): r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.""" return super().dtype + @property + def format(self) -> str: + return super().format + @property def qparams(self): r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.""" diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 4cd4a27f1..40852d7ae 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -8,6 +8,7 @@ #include "megbrain/imperative/transformations/dim_expansion.h" #include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/eval.h" +#include "megbrain/imperative/transformations/format.h" #include "megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/transformations/scalar.h" #include "megbrain/imperative/transformations/symbol.h" @@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) { // name case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1; } + case 'f': + // format + return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 
6 : -1; } // clang-format on return -1; @@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { {"is_const", []() -> py::object { return py::bool_(false); }}, {"no_cache", []() -> py::object { return py::bool_(false); }}, {"name", []() -> py::object { return py::none(); }}, + {"format", []() -> py::object { return py::none(); }}, }, name2idx}; py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType @@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } else { tup = parse_args(tup, descs); } - mgb_assert(tup.size() == 6); + mgb_assert(tup.size() == 7); if (auto* t = try_cast(tup[0].ptr())) { m_tensor = t->m_tensor->copy(); } else { auto data = tup[0]; DType dtype = tup[1].cast(); + CompNode cn = as_comp_node(tup[2]); bool is_const = tup[3].cast(); bool no_cache = tup[4].cast(); std::string name; if (!tup[5].is_none()) { name = tup[5].cast(); } - CompNode cn = as_comp_node(tup[2]); + Format format; + if (!tup[6].is_none()) { + format = tup[6].cast(); + } { CreateTensor::Kind kind = is_const ? 
CreateTensor::Const @@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } else { auto&& hval = pyobj2hval(data, cn, dtype); val = imperative::apply( - CreateTensor(kind, cn, hval.dtype, hval.shape), + CreateTensor(kind, cn, hval.dtype, hval.shape, format), hval.storage)[0]; } m_tensor.emplace(val); @@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() { return py::cast(m_tensor->comp_node()).release().ptr(); } +PyObject* TensorWrapper::format() { + return py::cast(m_tensor->format().to_string()).release().ptr(); +} + PyObject* TensorWrapper::numpy() { auto hv = m_tensor->numpy(); if (!hv) { @@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp); void init_tensor(py::module m) { imperative::Tensor::static_initialize(); + // Transformations static auto& transformations = TransformationManager::get_instance(); using Segment = TransformationManager::Segment; @@ -755,6 +769,9 @@ void init_tensor(py::module m) { .register_at( std::make_shared()) .release()); + auto format_trans = std::make_shared(); + MGB_MARK_USED_VAR( + transformations.register_at(format_trans).release()); static py::exception py_async_error( m, "AsyncError", PyExc_RuntimeError); @@ -788,12 +805,14 @@ void init_tensor(py::module m) { } }); + // Tensor auto* tensor_type = TensorWrapper::wrap_t::type() .def<&TensorWrapper::numpy>("numpy") .def_getset<&TensorWrapper::shape>("shape") .def_getset<&TensorWrapper::dtype>("dtype") .def_getset<&TensorWrapper::device>("device") + .def_getset<&TensorWrapper::format>("format") .def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::isscalar>("_isscalar") .def<&TensorWrapper::detach>("detach") @@ -812,6 +831,11 @@ void init_tensor(py::module m) { if (!tensor_type) throw py::error_already_set(); py::setattr(m, "Tensor", tensor_type); + py::enum_(m, "FormatType") + .value("DEFAULT", Format::Type::DEFAULT) + .value("NCHW", Format::Type::NCHW) + .value("NHWC", Format::Type::NHWC) + .export_values(); py::class_(m, "TensorWeakRef") 
.def(py::init()) @@ -911,6 +935,7 @@ void init_tensor(py::module m) { sync_py_task_q(); }); + // GradTransformation py::handle grad_key_type = GradKeyWrapper::wrap_t::type() .def<&GradKeyWrapper::attach>("attach") @@ -1203,6 +1228,7 @@ void init_tensor(py::module m) { return wrapped_outputs; }); + // ModuleTraceTransformation static py::function module_trace_hook; static auto get_module_trace = [] { @@ -1309,6 +1335,12 @@ void init_tensor(py::module m) { m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); + // FormatTransformation + m.def("set_auto_format_convert", + [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); }); + m.def("get_auto_format_convert", + [format_trans]() { return format_trans->get_auto_convert(); }); + py::register_exception(m, "TraceError"); } diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 86394cbac..7e243631e 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -1,10 +1,11 @@ #pragma once #pragma GCC diagnostic ignored "-Wmissing-field-initializers" -#include - #include #include +#include + +#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/interpreter.h" #include "pybind11/pybind11.h" @@ -57,6 +58,7 @@ public: } return *shape; } + inline Format format() { return *data().format(); } inline HostValue::ref_t numpy() { return data().numpy(); } inline void reset(ValueRef value) { m_data = value; @@ -116,6 +118,7 @@ public: PyObject* shape(); PyObject* dtype(); PyObject* device(); + PyObject* format(); PyObject* numpy(); void reset(PyObject*); PyObject* detach(); diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index ae5cf59aa..99c55cbcb 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -19,6 +19,7 @@ public: DTypePromote, DimExpansion, Grad, + Format, Scalar, Symbol, Trace, diff --git 
a/imperative/python/test/unit/amp/test_autocast.py b/imperative/python/test/unit/amp/test_autocast.py index 5b4d09e82..e81245b3d 100644 --- a/imperative/python/test/unit/amp/test_autocast.py +++ b/imperative/python/test/unit/amp/test_autocast.py @@ -2,7 +2,7 @@ from megengine import amp from megengine.core.tensor import amp as origin_amp -def test_grad_scaler(): +def test_autocast(): def check(enabled, low, high): assert amp.enabled == enabled assert origin_amp._enabled == enabled diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py new file mode 100644 index 000000000..1f87d2f68 --- /dev/null +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -0,0 +1,307 @@ +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +from megengine import tensor +from megengine.autodiff import GradManager + + +def test_basic(): + a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc") + assert a.format == "nhwc" + b = tensor(a) + assert b.format == "nhwc" + # TODO: fix Tensor init bug for another Tensor + # c = tensor(a, format="nchw") + # assert c.format == "nchw" + + +def _compare_nchw_nhwc(data, func): + x1 = tensor(data, format="nchw") + x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + out1 = func(x1) + with mge.config._override(auto_format_convert=True): + out2 = func(x2) + np.testing.assert_equal(out1, out2) + + +def test_dimshuffle(): + def func(x): + out = F.transpose(x, [2, 3, 0, 1]) + assert out.format == "default" + return out.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_reshape(): + # maintain NHWC format + def func(x): + out = F.reshape(x, (1, 2, 6, 2)) + if x.format == "nhwc": + assert out.format == "nhwc" + return out.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + # not maintain NHWC format + def 
func2(x): + out = F.reshape(x, (1, 24)) + assert out.format == "default" + return out.numpy() + + _compare_nchw_nhwc(data, func2) + + +def test_flatten(): + def func(x): + return F.flatten(x).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_broadcast(): + # maintain NHWC format + def func(x): + out = F.broadcast_to(x, (4, 3, 2, 3)) + if x.format == "nhwc": + assert out.format == "nhwc" + return out.numpy() + + data = np.arange(0, 24).reshape((4, 3, 2, 1)) + _compare_nchw_nhwc(data, func) + + # not maintain NHWC format + def func2(x): + out = F.broadcast_to(x, (3, 4, 3, 2, 1)) + assert out.format == "default" + return out.numpy() + + _compare_nchw_nhwc(data, func2) + + +@pytest.mark.skip("repeat cannot maintain format yet") +def test_repeat(): + def func(x): + rst = F.repeat(x, 3, axis=1) + assert rst.format == x.format + return rst.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_getshape(): + def func(x): + return x.shape + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.skip("symbolic shape is not supported yet") +def test_get_symbolic_shape(): + from megengine.core._trace_option import set_symbolic_shape + + origin_opt = set_symbolic_shape(True) + + def func(x): + return x.shape.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + set_symbolic_shape(origin_opt) + + +def test_getvalue(): + def func(x): + return x.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_get_set_subtensor(): + def get_subtensor(x): + return x[:, :1, :2, :3].numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, get_subtensor) + + def set_subtensor(x): + x[:, :1, :2, :3] = 0 + return x.numpy() + + _compare_nchw_nhwc(data, set_subtensor) + + +def test_get_set_advanced_indexing(): + def 
get_advanced_indexing(x): + x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy() + return x + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, get_advanced_indexing) + + def set_advanced_indexing(x): + x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0 + return x.numpy() + + _compare_nchw_nhwc(data, set_advanced_indexing) + + +def test_typecvt(): + def typecvt(x): + return x.astype("float16").numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, typecvt) + + +def test_elemwise(): + def elemwise(x): + return (x * 2 + x / 2).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, elemwise) + + +def test_concat(): + def func(x): + rst = F.concat([x / 2, x * 2], axis=1) + assert rst.format == x.format + return rst.numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.parametrize( + "mode", ["bilinear", "nearest"], +) +def test_interpolate(mode): + def func(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + rst = F.vision.interpolate(x, scale_factor=3, mode=mode) + assert rst.format == "nhwc" + return rst.numpy() + else: + return F.vision.interpolate(x, scale_factor=3, mode=mode).numpy() + + # NHWC interpolate only suppoted channel is 1 or 3 + data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32") + _compare_nchw_nhwc(data, func) + + +def test_conv2d(): + def conv2d(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + x = F.conv2d( + x, + weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), + ) + assert x.format == "nhwc" + return x.numpy() + else: + return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, conv2d) + + +def test_group_conv2d(): + def conv2d(x): + if x.format == "nhwc": + with 
mge.config._override(conv_format="NHWC"): + x = F.conv2d( + x, + weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), + groups=2, + ) + assert x.format == "nhwc" + return x.numpy() + else: + return F.conv2d( + x, F.ones((2, 2, 2, 1, 1)), F.ones((1, 4, 1, 1)), groups=2 + ).numpy() + + data = np.arange(0, 48).reshape((1, 4, 3, 4)) + _compare_nchw_nhwc(data, conv2d) + + +def test_bn(): + def func(x): + if x.format == "nhwc": + with mge.config._override(bn_format="dim_111c"): + oups = F.batch_norm( + x.astype("float32"), + running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + training=True, + inplace=False, + ) + assert oups[0].format == "nhwc", "y's format is wrong" + assert oups[1].format == "nhwc", "running_mean's format is wrong" + assert oups[2].format == "nhwc", "running_var's format is wrong" + return oups[0].numpy() + else: + return F.batch_norm( + x.astype("float32"), + running_mean=mge.tensor(np.ones((1, 2, 1, 1))), + running_var=mge.tensor(np.ones((1, 2, 1, 1))), + weight=mge.tensor(np.ones((1, 2, 1, 1))), + bias=mge.tensor(np.ones((1, 2, 1, 1))), + training=True, + inplace=False, + )[0].numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +@pytest.mark.parametrize( + "pooling", + [F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d], +) +def test_pooling2d(pooling): + def func(x): + if x.format == "nhwc": + with mge.config._override(conv_format="NHWC"): + x = pooling(x.astype("float32"), 2) + assert x.format == "nhwc" + return x.numpy() + else: + return pooling(x.astype("float32"), 2).numpy() + + data = np.arange(0, 24).reshape((1, 2, 3, 4)) + _compare_nchw_nhwc(data, func) + + +def test_backward(): + data = np.arange(0, 24).reshape((1, 2, 3, 
4)) + x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") + b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") + gm = GradManager().attach([w, b]) + with gm: + with mge.config._override(auto_format_convert=True, conv_format="NHWC"): + x = F.conv2d(x, w, b) + gm.backward(x) + # TODO: backward grad has no format yet + np.testing.assert_equal( + w.grad.numpy(), + np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), + ) + np.testing.assert_equal( + b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3)) + ) diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index bb4c265ac..a91c228f2 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -33,14 +33,20 @@ std::string GetAttr::to_string() const { return ssprintf("GetAttr{attr=%s}", attr_name); } -CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape) - : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {} +CreateTensor::CreateTensor( + Kind kind, CompNode device, DType dtype, ValueShape shape, Format format) + : m_kind(kind), + m_device(device), + m_dtype(dtype), + m_shape(shape), + m_format(format) {} CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) : m_kind(kind), m_device(device), m_dtype(layout.dtype), - m_shape(ValueShape::from(layout)) { + m_shape(ValueShape::from(layout)), + m_format(Format::Type::DEFAULT) { mgb_assert( layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); } @@ -74,8 +80,9 @@ auto CreateTensor::parse(Span inputs) const -> Args { std::string CreateTensor::to_string() const { return ssprintf( - "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind, - m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str()); + "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}", + (int)m_kind, 
m_device.to_string().c_str(), m_dtype.name(), + m_shape.to_string().c_str(), m_format.to_string().c_str()); } std::string DTRCommand::to_string() const { diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp new file mode 100644 index 000000000..fa431665e --- /dev/null +++ b/imperative/src/impl/transformations/format.cpp @@ -0,0 +1,406 @@ +#include "megbrain/imperative/transformations/format.h" + +#include "megbrain/imperative/ops/autogen.h" + +namespace mgb { +namespace imperative { + +using FT = Format::Type; + +TypedValueRef FormattedTensorValue::as(const FT& target) const { + return FormattedTensorValue::make(m_value, target); +} + +TypedValueRef FormattedTensorValue::to( + const FT& target, const std::string& scope) const { + std::vector pattern; + if (m_format == FT::NHWC && target == FT::NCHW) { + pattern = {0, 3, 1, 2}; + } else if (m_format == FT::NCHW && target == FT::NHWC) { + pattern = {0, 2, 3, 1}; + } else { + mgb_throw( + MegBrainError, "Unsupport format conversion from %s to %s", + m_format.to_string().c_str(), Format(target).to_string().c_str()); + } + auto output = imperative::apply( + *Dimshuffle::make(pattern, scope), std::vector{m_value})[0]; + return FormattedTensorValue::make(output, target); +} + +namespace { + +ValueRef unwrap_input(const ValueRef& input) { + if (auto format_input = input.as_ref()) { + return format_input->value(); + } else { + return input; + } +} + +std::vector unwrap_inputs(const Span& inputs) { + std::vector unwrapped_inputs; + for (auto&& input : inputs) { + unwrapped_inputs.push_back(unwrap_input(input)); + } + return unwrapped_inputs; +} + +std::vector wrap_outputs( + const std::vector& outputs, FT type = FT::DEFAULT) { + std::vector wrapped_outputs; + for (auto&& output : outputs) { + wrapped_outputs.push_back(FormattedTensorValue::make(output, type)); + } + return wrapped_outputs; +} + +ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { + 
mgb_assert(shape.ndim == 4); + auto out = ValueShape(shape); + out[3] = shape[2]; + out[2] = shape[1]; + out[1] = shape[3]; + return out; +} + +using FormatRule = std::function( + const OpDef&, Span&, const bool&)>; +static std::unordered_map format_rules; + +template +void register_format_rule( + std::vector (*rule)(const T&, Span&, const bool&)) { + format_rules[T::typeinfo()] = [rule](const OpDef& def, Span& inputs, + const bool& auto_convert) { + return (*rule)(def.cast_final_safe(), inputs, auto_convert); + }; +} + +auto convert_nchw2nhwc_pattern(const std::vector& pattern) { + mgb_assert(pattern.size() == 4); + auto nhwc_pattern = pattern; + for (size_t idx = 0; idx < 4; ++idx) { + auto dim = pattern[idx]; + if (dim == 1) { + nhwc_pattern[idx] = 3; + } else if (dim == 2) { + nhwc_pattern[idx] = 1; + } else if (dim == 3) { + nhwc_pattern[idx] = 2; + } + } + return nhwc_pattern; +} + +std::vector dimshuffle_rule( + const Dimshuffle& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 1); + auto& src = inputs[0].cast(); + // Only support converting pattern from NCHW to NHWC currently. 
+ if (auto_convert && src.format() == FT::NHWC) { + auto pattern = convert_nchw2nhwc_pattern(op.pattern); + // dimshuffle will not maintain NHWC Format + return wrap_outputs(imperative::apply( + *Dimshuffle::make(std::move(pattern), op.scope()), + unwrap_inputs(inputs))); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { + mgb_assert(shape.layout().total_nr_elems() == 4); + auto* old_ptr = shape.ptr(); + auto cn = shape.comp_node(); + auto layout = shape.layout(); + auto nhwc_shape = HostTensorND(cn, layout); + auto* new_ptr = nhwc_shape.ptr(); + new_ptr[0] = old_ptr[0]; + new_ptr[1] = old_ptr[2]; + new_ptr[2] = old_ptr[3]; + new_ptr[3] = old_ptr[1]; + auto hv = HostStorage::make(nhwc_shape.storage()); + auto nhwc_shape_input = + imperative::apply(CreateTensor(CreateTensor::Const, cn, layout), hv)[0]; + return nhwc_shape_input; +} + +std::vector reshape_rule( + const Reshape& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 2); + auto& src = inputs[0].cast(); + if (auto_convert && src.format() == FT::NHWC) { + auto shape = unwrap_input(inputs[1]).numpy().cast().as_nd(); + if (shape.layout().total_nr_elems() == 4) { + // output is still NHWC format + auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); + auto outputs = imperative::apply( + op, std::vector{unwrap_input(inputs[0]), nhwc_shape}); + return wrap_outputs(outputs, FT::NHWC); + } else { + // will not maintain src's format + auto nchw_src = src.to(FT::NCHW, op.scope())->value(); + auto outputs = imperative::apply( + op, std::vector{nchw_src, unwrap_input(inputs[1])}); + return wrap_outputs(outputs); + } + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +std::vector broadcast_rule( + const Broadcast& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() == 2); + auto& src = inputs[0].cast(); + if (auto_convert && src.format() == FT::NHWC) { + 
auto shape = unwrap_input(inputs[1]).numpy().cast().as_nd(); + if (shape.layout().total_nr_elems() == 4) { + // output is still NHWC format + auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); + auto outputs = imperative::apply( + op, std::vector{unwrap_input(inputs[0]), nhwc_shape}); + return wrap_outputs(outputs, FT::NHWC); + } else { + // will not maintain src's format + auto nchw_src = src.to(FT::NCHW, op.scope())->value(); + auto outputs = imperative::apply( + op, std::vector{nchw_src, unwrap_input(inputs[1])}); + return wrap_outputs(outputs); + } + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +bool is_reduce_ndim_idx_items( + const std::vector>& items, + const Span& inputs) { + for (auto i = 0; i < items.size(); ++i) { + auto&& [axis, begin, end, step, idx] = items[i]; + if (idx) { + // if inputs[i] contains more than one value, ndim will not be reduced. + return inputs[i].is_scalar(); + } + } + return false; +} + +auto convert_nchw2nhwc_idx_items( + const std::vector>& items) { + auto nhwc_items = items; + for (auto i = 0; i < nhwc_items.size(); ++i) { + auto&& [axis, begin, end, step, idx] = nhwc_items[i]; + if (axis == 2 || axis == 3) { + nhwc_items[i] = {axis - 1, begin, end, step, idx}; + } else if (axis == 1) { + nhwc_items[i] = {3, begin, end, step, idx}; + } + } + return nhwc_items; +} + +template +std::vector subtensor_rule( + const T& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() >= 1); + auto& src = inputs[0].cast(); + bool is_reduce_ndim = is_reduce_ndim_idx_items( + op.items, {&inputs[1], &inputs[inputs.size() - 1]}); + if (!is_reduce_ndim) { + // only support NHWC2NCHW convert, otherwise maintain src's format + if (!(auto_convert && src.format() == FT::NHWC)) { + return {FormattedTensorValue::make( + imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; + } + auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); + auto outputs = imperative::apply( + 
*T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs)); + return wrap_outputs(outputs, FT::NHWC); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +template +std::vector setsubtensor_rule( + const T& op, Span& inputs, const bool& auto_convert) { + mgb_assert(inputs.size() >= 2); + auto& src = inputs[0].cast(); + bool is_reduce_ndim = is_reduce_ndim_idx_items( + op.items, {&inputs[2], &inputs[inputs.size() - 1]}); + if (!is_reduce_ndim) { + // only support NHWC2NCHW convert, otherwise maintain src's format + if (!(auto_convert && src.format() == FT::NHWC)) { + return {FormattedTensorValue::make( + imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; + } + // value has been broadcasted to src's fake NCHW shape. + auto& value = inputs[1].cast(); + auto& format = value.format(); + auto nhwc_inputs = std::vector(inputs.size()); + if (format == FT::DEFAULT || format == FT::NCHW) { + // value for setsubtensor should transpose to match shape. 
+ auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC); + // make new inputs for setsubtensor + nhwc_inputs[0] = src.value(); + nhwc_inputs[1] = nhwc_value->value(); + for (auto i = 2; i < inputs.size(); ++i) { + nhwc_inputs[i] = inputs[i].as_ref()->value(); + } + } else if (format != FT::NHWC) { + mgb_throw( + MegBrainError, "Unsupported format(%s) of value for setsubtensor.", + format.to_string().c_str()); + } + auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); + auto outputs = imperative::apply( + *T::make(std::move(nhwc_items), op.scope()), nhwc_inputs); + return wrap_outputs(outputs, FT::NHWC); + } + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); +} + +FT get_inputs_format(Span& inputs) { + FT format(FT::DEFAULT); + for (auto& inp : inputs) { + auto& inp_format = inp.cast().format(); + if (inp_format != FT::DEFAULT) { + mgb_assert(format == FT::DEFAULT || inp_format == format); + format = inp_format.type(); + } + } + return format; +} + +std::vector concat_rule( + const Concat& op, Span& inputs, const bool& auto_convert) { + FT format = get_inputs_format(inputs); + if (!(format == FT::NHWC && auto_convert)) { + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); + } + // TODO: handle 5D NHWC Tensor from group conv + auto axis = op.axis; + if (axis == 2 || axis == 3) { + axis = axis - 1; + } else if (axis == 1) { + axis = 3; + } + return wrap_outputs( + imperative::apply( + *Concat::make(axis, op.comp_node, op.scope()), + unwrap_inputs(inputs)), + format); +} + +std::vector elemwise_rule( + const Elemwise& op, Span& inputs, const bool& auto_convert) { + FT format = get_inputs_format(inputs); + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); +} + +std::vector identity_rule_helper( + const OpDef& op, const Span& inputs) { + // mgb_assert(inputs.size() == 1); + auto& src = inputs[0].cast(); + return wrap_outputs( + imperative::apply(op, unwrap_inputs(inputs)), src.format().type()); +} + +// 
clang-format off +#define FOREACH_IDENTITY_OP(cb) \ + cb(Copy) \ + cb(FastpathCopy) \ + cb(TypeCvt) \ + cb(Pooling) \ + cb(AdaptivePooling) \ + cb(Dropout) \ + cb(Convolution) \ + cb(BatchNorm) \ + cb(Resize) \ + cb(Identity) +// clang-format on + +#define CREATE_IDENTITY_OP_RULE(op) \ + std::vector op##_rule( \ + const op& _op, Span& inputs, const bool& auto_convert) { \ + return identity_rule_helper(_op, inputs); \ + } +FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE) +#undef CREATE_IDENTITY_OP_RULE + +#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule); +struct FormatRuleRegistry { + FormatRuleRegistry() { + register_format_rule(dimshuffle_rule); + register_format_rule(reshape_rule); + register_format_rule(broadcast_rule); + register_format_rule(subtensor_rule); + register_format_rule(subtensor_rule); + register_format_rule(setsubtensor_rule); + register_format_rule(setsubtensor_rule); + register_format_rule(concat_rule); + register_format_rule(elemwise_rule); + FOREACH_IDENTITY_OP(REGISTER_IDENTITY_OP_RULE) + } +} _; +#undef REGISTER_IDENTITY_OP_RULE +} // namespace + +std::vector FormatTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* apply_op = op.as()) { + // all inputs should be FormattedTensorValue + auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); + if (iter != format_rules.end()) { + return iter->second(apply_op->op(), inputs, m_auto_convert); + } else { + return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); + } + } else if (auto* create_tensor = op.as()) { + auto format = create_tensor->format(); + return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)}; + } else if (auto* get_attr = op.as()) { + auto* src = inputs.as_array<1>()[0].as(); + if (!m_auto_convert || !src || src->format() != FT::NHWC) { + return imperative::apply(op, unwrap_inputs(inputs)); + } + switch (get_attr->attr()) { + case GetAttr::Shape: { + auto output = imperative::apply(op, 
unwrap_inputs(inputs))[0]; + auto shape = convert_nhwc2nchw_shape(output.cast()); + return {ShapeValue::make(shape)}; + } + case GetAttr::Value: { + auto nchw_src = unwrap_input(src->to(FT::NCHW, "")); + return imperative::apply(op, std::vector{nchw_src}); + } + default: + return imperative::apply(op, unwrap_inputs(inputs)); + } + } else if (op.is()) { + bool is_formatted_tensor = inputs.as_array<1>()[0].is(); + if (is_formatted_tensor) { + return {FormatValue::make(inputs[0].cast().format())}; + } else { + mgb_log_warn( + "Not FormattedTensorValue input for GetFormat op: %s", + inputs[0].to_string().c_str()); + return {FormatValue::make(FT::DEFAULT)}; + } + } else if (op.is()) { + bool is_formatted_tensor = inputs.as_array<1>()[0].is(); + if (is_formatted_tensor) { + auto& format = inputs[0].cast().format(); + return wrap_outputs( + imperative::apply(op, unwrap_inputs(inputs)), format.type()); + } else { + mgb_log_warn( + "Not FormattedTensorValue input for IdentityLike op: %s", + inputs[0].to_string().c_str()); + return imperative::apply(op, inputs); + } + } else { + return imperative::apply(op, unwrap_inputs(inputs)); + } +}; + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 8b4cb4dd3..e4344e212 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -58,6 +58,10 @@ TypedValueRef ValueRef::dtype() const { return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref(); } +TypedValueRef ValueRef::format() const { + return imperative::apply(GetFormat(), *this)[0].as_ref(); +} + TypedValueRef ValueRef::name() const { return imperative::apply(GetName(), *this)[0].cast_ref(); } diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h index 963f24d11..4d35c746f 100644 --- a/imperative/src/include/megbrain/imperative/basic_operators.h +++ 
b/imperative/src/include/megbrain/imperative/basic_operators.h @@ -5,6 +5,7 @@ #include "megbrain/imperative/op_def.h" #include "megbrain/imperative/operator.h" +#include "megbrain/imperative/utils/data_format.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/value_shape.h" @@ -82,9 +83,12 @@ private: CompNode m_device; DType m_dtype; ValueShape m_shape; + Format m_format; public: - CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape); + CreateTensor( + Kind kind, CompNode device, DType dtype, ValueShape shape, + Format format = Format::Type::DEFAULT); CreateTensor(Kind kind, CompNode device, TensorLayout layout); /** @@ -99,6 +103,7 @@ public: CompNode device() const { return m_device; } DType dtype() const { return m_dtype; } ValueShape shape() const { return m_shape; } + Format format() const { return m_format; } std::string to_string() const override; }; @@ -157,6 +162,11 @@ public: std::string to_string() const override; }; +class GetFormat final : public OperatorImpl { +public: + std::string to_string() const override { return "GetFormat{}"; } +}; + class GetVarVal final : public OperatorImpl { public: std::string to_string() const override; diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 3aa0d3d16..ad4adae56 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -3,6 +3,7 @@ #include #include +#include "megbrain/imperative/utils/data_format.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/value_shape.h" #include "megbrain/imperative/value.h" @@ -148,6 +149,13 @@ public: std::string to_string() const override; }; +class FormatValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override { return Format::to_string(); } +}; + class 
StringValue final : public PrimitiveValue { public: using PrimitiveValue::PrimitiveValue; diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h new file mode 100644 index 000000000..81fd83bfa --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/format.h @@ -0,0 +1,70 @@ +#pragma once + +#include "megbrain/imperative/basic_values.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/utils/data_format.h" + +namespace mgb::imperative { + +class FormattedTensorValue final : public ValueImpl { +private: + ValueRef m_value; + Format m_format; + +public: + FormattedTensorValue(ValueRef value, Format format) + : m_value(value), m_format(format) {} + + std::string to_string() const override { + return ssprintf( + "FormattedTensorValue{value=%s, format=%s}", + m_value.to_string().c_str(), m_format.to_string().c_str()); + } + + ValueRef value() const { return m_value; } + + const Format& format() const { return m_format; } + + TypedValueRef as(const Format::Type& target) const; + TypedValueRef to( + const Format::Type& target, const std::string& scope = "") const; + + void clear() override { + m_value = {}; + m_format = {}; + } + + void on_watch() override { m_value.watch(); } + + void on_unwatch() override { m_value.unwatch(); } +}; + +/** + * \brief tracks the data format (DEFAULT, NCHW or NHWC) attached to tensors + * + * Tensors seen by this transformation are wrapped in FormattedTensorValue, which + * stores a Format next to the underlying value. With auto_convert enabled, ops on + * NHWC tensors are adapted so that shapes and values are still reported as NCHW. 
+ */ +class FormatTransformation final : public Transformation { +private: + bool m_auto_convert = false; + +public: + std::vector apply_transformation( + const Operator& op, Span inputs) override; + + ValueRef unwrap(ValueRef value) override { + mgb_assert(!value.is()); + return value; + } + + std::string name() const override { + return ssprintf("FormatTransformation{auto_convert=%d}", m_auto_convert); + } + void set_auto_convert(bool enabled) { m_auto_convert = enabled; } + bool get_auto_convert() const { return m_auto_convert; } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/data_format.h b/imperative/src/include/megbrain/imperative/utils/data_format.h new file mode 100644 index 000000000..f8e664777 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/data_format.h @@ -0,0 +1,56 @@ +#pragma once + +#include "megbrain/tensor.h" + +namespace mgb::imperative { + +/** + * \brief like TensorFormats, but only including common formats and DEFAULT. + * + */ +class Format { +public: + enum class Type { + DEFAULT = 0, + NCHW = 1, ///< [N, C, H, W] + NHWC = 2, ///< [N, H, W, C] + }; + std::string to_string() const { + switch (m_type) { + case Type::DEFAULT: + return "default"; + case Type::NCHW: + return "nchw"; + case Type::NHWC: + return "nhwc"; + default: + mgb_throw(MegBrainError, "bad format type"); + } + } + Format() : m_type(Type::DEFAULT) {} + Format(std::string str) { + if (str == "default") { + m_type = Type::DEFAULT; + } else if (str == "nchw") { + m_type = Type::NCHW; + } else if (str == "nhwc") { + m_type = Type::NHWC; + } else { + mgb_throw( + MegBrainError, + "Invalid format type." 
+ " Only support \"default\", \"nchw\" and \"nhwc\""); + } + } + Format(Type type) : m_type(type) {} + Type type() const { return m_type; } + bool operator==(const Format& b) const { return m_type == b.type(); } + bool operator==(const Format::Type& b) const { return m_type == b; } + bool operator!=(const Format& b) const { return m_type != b.type(); } + bool operator!=(const Format::Type& b) const { return m_type != b; } + +private: + Type m_type = Type::DEFAULT; +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index 44970e86b..ecf63b48d 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -31,6 +31,7 @@ class HostValue; class DeviceValue; class ShapeValue; class DTypeValue; +class FormatValue; class CompNodeValue; class StringValue; class NodeValue; @@ -219,6 +220,7 @@ public: TypedValueRef device() const; TypedValueRef shape() const; TypedValueRef dtype() const; + TypedValueRef format() const; TypedValueRef name() const; bool is_scalar() const; @@ -431,9 +433,11 @@ inline const TypedValueRef& ValueRef::cast_ref(const Type& type) inline void ValueRef::on_cast_failure(const IType& type) const { // if this is ErrorValue, rethrow directly storage()->try_rethrow(); - mgb_assert( - storage()->type() != type, "expect type %s, got %s", type.name().c_str(), - to_string().c_str()); + if (storage()->type() != type) { + mgb_throw( + MegBrainError, "Unable to cast ValueRef: expect type %s, got %s", + type.name().c_str(), to_string().c_str()); + } } /** diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index a58276a09..c6964f335 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape( bias_c = inp_shape[2][channel_idx]; mgb_assert( inp_c == scale_c && inp_c == bias_c, - "inconsistent 
channel size, input chennel: %zu, scale channel: %zu, bias " + "inconsistent channel size, input channel: %zu, scale channel: %zu, bias " "channel: %zu", inp_c, scale_c, bias_c); -- GitLab