diff --git a/dnn/src/common/batch_normalization.cpp b/dnn/src/common/batch_normalization.cpp
index ec7fa6561e6226744c23b32d1f26adcbbaae3b21..7f963dace8c360532024223cbcac6aea1084bcce 100644
--- a/dnn/src/common/batch_normalization.cpp
+++ b/dnn/src/common/batch_normalization.cpp
@@ -28,6 +28,16 @@ void BNForward::check_exec(
         const TensorLayout& variance, const TensorLayout& batch_mean,
         const TensorLayout& batch_inv_variance, const TensorLayout& dst,
         size_t workspace_in_bytes, size_t reserve_in_bytes) {
+    // move some python-side asserts into dnn to reduce the assert overhead
+    megdnn_assert(
+            src.ndim == 4,
+            "ndim of the input tensor for batch_norm should be 4, but got %zu",
+            src.ndim);
+    megdnn_assert(bn_scale.ndim == 4, "expect 4, got %zu\n", bn_scale.ndim);
+    megdnn_assert(bn_bias.ndim == 4, "expect 4, got %zu\n", bn_bias.ndim);
+    megdnn_assert_eq_layout(bn_scale, bn_bias);
+    megdnn_assert_eq_layout(batch_mean, batch_inv_variance);
+
     megdnn_assert_contiguous(src);
     megdnn_assert_eq_layout(src, dst);
     megdnn_assert_eq_layout(bn_scale, bn_bias);
diff --git a/imperative/python/megengine/amp/autocast.py b/imperative/python/megengine/amp/autocast.py
index 01b98e8f128de1e8204c89862799481850e2083e..ffbd4282b390fe54d4915f9a54a330ef11062a59 100644
--- a/imperative/python/megengine/amp/autocast.py
+++ b/imperative/python/megengine/amp/autocast.py
@@ -58,16 +58,19 @@ class autocast:
         self._origin_low = None
 
     def __enter__(self):
-        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
-        self._origin_high = amp._high_prec_dtype
-        amp._high_prec_dtype = self.high_prec_dtype
-        self._origin_low = amp._low_prec_dtype
-        amp._low_prec_dtype = self.low_prec_dtype
+        self._origin_enabled = amp._enabled
+        self._origin_high = amp._get_amp_high_prec_dtype()
+        self._origin_low = amp._get_amp_low_prec_dtype()
+        amp._enabled = self.enabled
+        amp._set_amp_dtype_autocast(self.enabled)
+        amp._set_amp_high_prec_dtype(self.high_prec_dtype)
+        amp._set_amp_low_prec_dtype(self.low_prec_dtype)
 
     def __exit__(self, *args):
         amp._enabled = self._origin_enabled
-        amp._high_prec_dtype = self._origin_high
-        amp._low_prec_dtype = self._origin_low
+        amp._set_amp_dtype_autocast(self._origin_enabled)
+        amp._set_amp_high_prec_dtype(self._origin_high)
+        amp._set_amp_low_prec_dtype(self._origin_low)
 
     def __call__(self, func):
         @functools.wraps(func)
diff --git a/imperative/python/megengine/core/tensor/amp.py b/imperative/python/megengine/core/tensor/amp.py
index cb6ae5b095581588aa05d0e0b8bfac2afacf5189..1c0d9e5bb96586bf98a45de6fb9313ceb4d90b84 100644
--- a/imperative/python/megengine/core/tensor/amp.py
+++ b/imperative/python/megengine/core/tensor/amp.py
@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from .._imperative_rt.core2 import (
+    _get_amp_dtype_autocast,
+    _get_amp_high_prec_dtype,
+    _get_amp_low_prec_dtype,
+    _set_amp_dtype_autocast,
+    _set_amp_high_prec_dtype,
+    _set_amp_low_prec_dtype,
+)
+
 _enabled = False
-_high_prec_dtype = "float32"
-_low_prec_dtype = "float16"
+_set_amp_dtype_autocast(_enabled)
 
 
 @property
@@ -28,6 +37,7 @@ def enabled(mod):
 def enabled(mod, enabled: bool):
     global _enabled
     _enabled = enabled
+    _set_amp_dtype_autocast(_enabled)
 
 
 @property
@@ -42,13 +52,12 @@ def high_prec_dtype(mod):
         import megengine as mge
         mge.amp.high_prec_dtype = "float32"
     """
-    return _high_prec_dtype
+    return _get_amp_high_prec_dtype()
 
 
 @high_prec_dtype.setter
 def high_prec_dtype(mod, dtype: str):
-    global _high_prec_dtype
-    _high_prec_dtype = dtype
+    _set_amp_high_prec_dtype(dtype)
 
 
 @property
@@ -63,10 +72,9 @@ def low_prec_dtype(mod):
         import megengine as mge
         mge.amp.low_prec_dtype = "float16"
     """
-    return _low_prec_dtype
+    return _get_amp_low_prec_dtype()
 
 
 @low_prec_dtype.setter
 def low_prec_dtype(mod, dtype: str):
-    global _low_prec_dtype
-    _low_prec_dtype = dtype
+    _set_amp_low_prec_dtype(dtype)
diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index bbb1f4b26d1cda1435b6ffd0df62ce7ac9e4e21a..9072f8e686a8e622ff4182c0ec6b7f5d8bf94471 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -25,7 +25,6 @@ from .utils import (
     astensor1d,
     astype,
     cast_tensors,
-    convert_inputs,
     make_shape_tuple,
     subgraph,
 )
@@ -40,38 +39,6 @@ def _elwise_apply(args, mode):
 
 
 def _elwise(*args, mode):
-    args = convert_inputs(*args)
-    if (
-        mode
-        in (
-            _ElwMod.TRUE_DIV,
-            _ElwMod.EXP,
-            _ElwMod.POW,
-            _ElwMod.LOG,
-            _ElwMod.EXPM1,
-            _ElwMod.LOG1P,
-            _ElwMod.ACOS,
-            _ElwMod.ASIN,
-            _ElwMod.ATAN2,
-            _ElwMod.COS,
-            _ElwMod.SIN,
-            _ElwMod.LOG_SUM_EXP,
-        )
-        and (
-            amp._enabled
-            or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
-        )
-        or mode in (_ElwMod.TANH,)
-        and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
-    ):
-        # autocast to FP32 to maintain precision
-        # or to avoid op's not supporting all int args
-        args = cast_tensors(*args, promote=True)
-
-    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
-        args[0].dtype, np.integer
-    ):
-        return args[0]
     return _elwise_apply(args, mode)
 
 
@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        if mode == "mean":
-            data = data.astype("float32")
-        elif self.dtype == np.bool_:
-            data = data.astype("int32")
         if axis is None:
             assert not keepdims, "can not set axis=None and keepdims=True"
             result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
@@ -526,9 +489,6 @@ def _reduce(mode):
             if not keepdims:
                 result = _remove_axis(result, axis)
 
-        if self.dtype == np.bool_:
-            if mode in ["min", "max"]:
-                result = result.astype("bool")
         return result
 
     return f
diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 97199cde98e0df94b753b6cd7244ac6b9a8af2be..43c4bad55899b8bcb04632396326e0d7e051523f 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -16,6 +16,8 @@ from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
     SymbolVar,
     Tensor,
+    _get_convert_inputs,
+    _set_convert_inputs,
     apply,
     dtype_promotion,
     get_device,
@@ -27,15 +29,13 @@ from .._wrap import as_device
 from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
-from .amp import _high_prec_dtype, _low_prec_dtype
+from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
 
-_enable_convert_inputs = True
-
 
 def get_convert_inputs():
     r"""get the current state of `_enable_convert_inputs`"""
-    return _enable_convert_inputs
+    return _get_convert_inputs()
 
 
 def set_convert_inputs(flag):
@@ -44,10 +44,7 @@ def set_convert_inputs(flag):
     `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for internal use only,
     and should be removed when the tensor-like system is refactored.
     """
-    global _enable_convert_inputs
-    backup = _enable_convert_inputs
-    _enable_convert_inputs = flag
-    return backup
+    return _set_convert_inputs(flag)
 
 
 def concatenate(inputs, axis=0, *, device=None):
@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None):
 
 
 def convert_inputs(*args, device=None):
-    if not _enable_convert_inputs:
+    if not _get_convert_inputs():
         return args
 
     dtype = dtype_promotion(args)
@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None):
 
 def cast_tensors(*args, promote=False):
     if promote:
-        dtype = _high_prec_dtype
+        dtype = _get_amp_high_prec_dtype()
    else:
-        dtype = _low_prec_dtype
+        dtype = _get_amp_low_prec_dtype()
     return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
 
 
diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py
index 1b018e47afc3b7c90b258dcc569b4577be8f19f4..2088d378973a576f33ea88ce7c3ad96f2b9f5209 100644
--- a/imperative/python/megengine/functional/elemwise.py
+++ b/imperative/python/megengine/functional/elemwise.py
@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
+from .tensor_cache import get_scalar_one
 
 __all__ = [
     "abs",
@@ -359,7 +360,11 @@ def asin(x):
 
 def atan(x):
     r"""Element-wise `inverse tangent`."""
-    return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)
+    return _elwise(
+        x,
+        get_scalar_one("float32", x.device if isinstance(x, Tensor) else None),
+        mode=Elemwise.Mode.ATAN2,
+    )
 
 
 def atan2(y, x):
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index a5d06995001a22a1d1afbecfac4b20c995f480a8..7ca4395dc75c2ac83a0d86af561e5ecda705fb17 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -253,15 +253,6 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    if amp._enabled:
-        compute_mode = "float32"
-        inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        dtype = dtype_promotion(inp, weight)
-        if inp.dtype != dtype:
-            inp = inp.astype(dtype)
-        if weight.dtype != dtype:
-            weight = weight.astype(dtype)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -1328,29 +1319,32 @@ def batch_norm(
         inplace: whether to update ``running_mean`` and ``running_var``
            inplace or return new tensors. Default: True
     """
-    if inp.ndim != 4:
-        raise NotImplementedError("batch_norm for ndim != 4")
-
-    if param_dim == "dim_1c11":
-        C = inp.shape[1]
-        pshape = (1, C, 1, 1)
-    elif param_dim == "dim_111c":
-        C = inp.shape[3]
-        pshape = (1, 1, 1, C)
-    else:
-        raise ValueError("Invalid param_dim {}".format(param_dim))
 
     def make_full_if_none(x, value):
+        x_ndim = None if x is None else x.ndim
+        # in the general case, x is returned here directly
+        if x_ndim is not None and x_ndim != 1:
+            return x
+
+        if param_dim == "dim_1c11":
+            C = inp.shape[1]
+            pshape = (1, C, 1, 1)
+        elif param_dim == "dim_111c":
+            C = inp.shape[3]
+            pshape = (1, 1, 1, C)
+        else:
+            raise ValueError("Invalid param_dim {}".format(param_dim))
+
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
-        elif x.ndim == 1:
+        else:
+            assert x_ndim == 1
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
-        return x
 
     has_mean = running_mean is not None
     has_var = running_var is not None
@@ -1359,16 +1353,6 @@ def batch_norm(
         assert has_mean, "running_mean must be provided in inference mode"
         assert has_var, "running_var must be provided in inference mode"
 
-    if has_mean and running_mean.ndim != 4:
-        raise ValueError
-    if has_var and running_var.ndim != 4:
-        raise ValueError
-
-    if amp._enabled:
-        inp = inp.astype("float16")
-        weight, bias, running_mean, running_var = cast_tensors(
-            weight, bias, running_mean, running_var, promote=True
-        )
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
 
diff --git a/imperative/python/megengine/functional/tensor_cache.py b/imperative/python/megengine/functional/tensor_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..582be4ad458915668e90ef778afe33681b50aed5
--- /dev/null
+++ b/imperative/python/megengine/functional/tensor_cache.py
@@ -0,0 +1,34 @@
+from ..core.ops.special import Const
+from ..jit.tracing import is_tracing
+
+small_tensor_cache = {}
+
+
+def _get_scalar_tensor_with_value(value, dtype=None, device=None):
+    global small_tensor_cache
+    if is_tracing():
+        (ret,) = Const(value, dtype=dtype, device=device)()
+    else:
+        cache_key = (value, dtype, device)
+        if cache_key not in small_tensor_cache:
+            (ret,) = Const(value, dtype=dtype, device=device)()
+            small_tensor_cache[cache_key] = ret
+        else:
+            ret = small_tensor_cache[cache_key]
+    return ret
+
+
+def get_scalar_zero(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(0, dtype, device)
+
+
+def get_scalar_zero_point_five(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(0.5, dtype, device)
+
+
+def get_scalar_one(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(1, dtype, device)
+
+
+def get_scalar_two(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(2, dtype, device)
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index a9d3495c6f5fe8625046ff7defb7c82776832c26..7847ccff2d2c40e18fefcf0e8ba74c794215fc03 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"
"megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/transformations/scalar.h" @@ -59,16 +60,19 @@ struct SymbolVarContext { TransformationContext context; std::shared_ptr symbol_tsf; std::shared_ptr scalar_tsf; + std::shared_ptr dtype_promote_tsf; SymbolVarContext(cg::ComputingGraph* graph) { symbol_tsf = std::make_shared(graph); scalar_tsf = std::make_shared(); + dtype_promote_tsf = std::make_shared(); Transformation::swap_context(context); } void init() { symbol_tsf->register_at(Transformation::top()); scalar_tsf->register_at(Transformation::top()); + dtype_promote_tsf->register_at(Transformation::top()); } ValueRef symvar2val(py::handle py_symbol_var) { @@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d) #undef REGISTE_APPLY_FUNC +PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs); +CompNode _get_device(PyObject* const* args, size_t nargs); + PyObject* py_apply( PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { try { @@ -133,19 +140,59 @@ PyObject* py_apply( auto op = py::handle(py_op).cast>(); SmallVector tensors(nargs); - bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) && - py::isinstance(py::handle(args[0])); - if (is_symbol_var) { + SmallVector is_symbol_var(nargs, false); + ComputingGraph* cg = nullptr; + for (size_t i = 0; i < nargs; ++i) { + if ((!TensorWrapper::try_cast(args[i])) && + py::isinstance(py::handle(args[i]))) { + is_symbol_var[i] = true; + ComputingGraph* cur_cg = + py::handle(args[i]).cast()->m_node->owner_graph(); + if (cg == nullptr) { + cg = cur_cg; + } else { + mgb_assert(cg == cur_cg); + } + } + } + + mgb::CompNode target_cn; + mgb::DType target_dtype; + + auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef { + if (!target_dtype.valid()) { + target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs)); + target_cn = _get_device(args, nargs); + } + HostTensorND ht(target_cn); + ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); + if (PyArray_Check(args[i])) { // non scaler + return imperative::apply( + CreateTensor(CreateTensor::Const, target_cn, ht.layout()), + HostStorage::make(ht.storage()))[0]; + } else { // scaler + return imperative::apply( + CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}), + HostStorage::make(ht.storage()))[0]; + } + }; + + if (cg != nullptr) { // swap to a special context to reuse scalar handle - SymbolVarContext context( - py::handle(args[0]).cast()->m_node->owner_graph()); + size_t symbol_var_idx = 8; + SymbolVarContext context(cg); context.init(); for (size_t i = 0; i < nargs; ++i) { - tensors[i] = context.symvar2val(args[i]); + if (is_symbol_var[i]) { + symbol_var_idx = i; + tensors[i] = context.symvar2val(args[i]); + } else { + tensors[i] = convert_pyinput_to_tensor(i); + } } auto outputs = imperative::apply(*op, tensors); auto ret = pybind11::tuple(outputs.size()); - auto typeobj = py::handle(args[0]).get_type(); + auto typeobj = py::handle(args[symbol_var_idx]).get_type(); for (size_t i = 0; i < outputs.size(); ++i) { ret[i] = context.val2symvar(typeobj, outputs[i]); } @@ -156,13 +203,7 @@ PyObject* py_apply( if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { tensors[i] = tw->m_tensor->data(); } else { - PyErr_SetString( - PyExc_TypeError, - ssprintf( - "op %s expect type Tensor as inputs, got %s actually", - op->make_name().c_str(), Py_TYPE(args[i])->tp_name) - .c_str()); - return nullptr; + tensors[i] = convert_pyinput_to_tensor(i); } } @@ -616,6 +657,8 @@ void init_tensor(py::module m) 
                     std::shared_ptr(channel, [](Channel*) {})));
     transformations.register_at(
             std::make_shared());
+    transformations.register_at(
+            std::make_shared());
 
     static py::exception py_async_error(
             m, "AsyncError", PyExc_RuntimeError);
@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) {
 
     m.def("reset_stats", [] { imperative::Stats::reset(); });
 
+    m.def("_get_convert_inputs",
+          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
+    m.def("_set_convert_inputs", [](bool flag) -> bool {
+        bool ret = DTypePromoteCfg::convert_input_enabled;
+        DTypePromoteCfg::convert_input_enabled = flag;
+        return ret;
+    });
+    m.def("_get_amp_dtype_autocast",
+          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
+    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
+        bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
+        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
+        return ret;
+    });
+
+    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
+        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
+                                : DTypePromoteCfg::amp_low_prec_dtype;
+        mgb_assert(target.category() == DTypeCategory::FLOAT);
+        std::string ret = target.name();
+        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
+        return ret;
+    };
+
+    static auto set_amp_prec_dtype = [](bool is_high,
+                                        std::string dtype_name) -> std::string {
+        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
+                                : DTypePromoteCfg::amp_low_prec_dtype;
+        std::string ret = target.name();
+
+        if (dtype_name == "float32") {
+            target = dtype::Float32();
+        } else if (dtype_name == "float16") {
+            target = dtype::Float16();
+        } else if (dtype_name == "bfloat16") {
+            target = dtype::BFloat16();
+        } else {
+            mgb_assert(
+                    false, "cast dtype of amp should be float, but got %s\n",
+                    dtype_name.c_str());
+        }
+
+        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
+        return ret;
+    };
+
+    m.def("_get_amp_high_prec_dtype",
+          []() -> std::string { return get_amp_prec_dtype(true); });
+    m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
+        return set_amp_prec_dtype(true, dtype_name);
+    });
+    m.def("_get_amp_low_prec_dtype",
+          []() -> std::string { return get_amp_prec_dtype(false); });
+    m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
+        return set_amp_prec_dtype(false, dtype_name);
+    });
+
     py::register_exception(m, "TraceError");
 }
diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index 37e66c9fe480c5cf304be7fb4f3c3c718559baaa..84bd6c4997f68686099854dce94179ae70a5bbe7 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -26,12 +26,13 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         Grad,
+        DTypePromote,
         Scalar,
         Trace,
         Eval,
     };
 
-    std::array>, 5> segments;
+    std::array>, 6> segments;
 
     template 
     void register_at(std::shared_ptr transformation) {
diff --git a/imperative/python/test/unit/amp/test_autocast.py b/imperative/python/test/unit/amp/test_autocast.py
index 0aff6aa0464108f99d4909880983e722af57df10..1ed81639ded70cfb3f762d6d0090b4667d306e6c 100644
--- a/imperative/python/test/unit/amp/test_autocast.py
+++ b/imperative/python/test/unit/amp/test_autocast.py
@@ -14,20 +14,20 @@ def test_grad_scaler():
         assert amp.enabled == enabled
         assert origin_amp._enabled == enabled
         assert amp.low_prec_dtype == low
-        assert origin_amp._low_prec_dtype == low
+        assert origin_amp._get_amp_low_prec_dtype() == low
         assert amp.high_prec_dtype == high
-        assert origin_amp._high_prec_dtype == high
+        assert origin_amp._get_amp_high_prec_dtype() == high
 
     origin_enabled = amp.enabled
     origin_high = amp.high_prec_dtype
     origin_low = amp.low_prec_dtype
-    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
-        check(True, "low", "high")
+    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
+        check(True, "float16", "float32")
     check(origin_enabled, origin_low, origin_high)
     amp.enabled = True
-    amp.high_prec_dtype = "high"
-    amp.low_prec_dtype = "low"
-    check(True, "low", "high")
+    amp.high_prec_dtype = "float32"
+    amp.low_prec_dtype = "float16"
+    check(True, "float16", "float32")
     amp.enabled = origin_enabled
     amp.high_prec_dtype = origin_high
     amp.low_prec_dtype = origin_low
diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bc4d9f2ddac4ab093c6aedd2437ee98e0d4af3dd
--- /dev/null
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -0,0 +1,251 @@
+#include "megbrain/imperative/transformations/dtype_promote.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative {
+
+bool DTypePromoteCfg::convert_input_enabled = true;
+bool DTypePromoteCfg::amp_dtype_autocast_enabled = false;
+DType DTypePromoteCfg::amp_high_prec_dtype = dtype::Float32();
+DType DTypePromoteCfg::amp_low_prec_dtype = dtype::Float16();
+
+namespace {
+// TODO: ScalarRule and DTypePromoteRule should be unified
+using DTypePromoteRule = std::function)>;
+static std::unordered_map dtype_promotion_rules;
+
+template 
+void register_dtype_promote_rule(const DTypePromoteRule& rule) {
+    dtype_promotion_rules[T::typeinfo()] = [rule](const OpDef& def,
+                                                  Span inputs) {
+        return rule(def.cast_final_safe(), inputs);
+    };
+}
+
+bool is_quantized_dtype(const DType& dtype) {
+    return dtype.category() == DTypeCategory::QUANTIZED;
+}
+
+bool is_all_integer(const SmallVector& dtypes) {
+    for (size_t i = 0; i < dtypes.size(); ++i) {
+        if (dtypes[i].category() != DTypeCategory::INT) {
+            return false;
+        }
+    }
+    return true;
+}
+
+SmallVector get_value_dtypes(const Span inputs) {
+    SmallVector dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+    return dtypes;
+}
+
+mgb::DType get_promoted_dtype(const SmallVector& dtypes) {
+    if (dtypes.size() == 0) {
+        mgb_assert(false, "there is no input for operator, dtype promote failed");
+    }
+    mgb::DType ret = dtypes[0];
+    for (size_t i = 1; i < dtypes.size(); ++i) {
+        ret = mgb::dtype_promotion(ret, dtypes[i]);
+    }
+    return ret;
+}
+
+ValueRefList elemwise_rule(const OpDef& op, Span inputs) {
+    auto&& elem_op = op.cast_final_safe();
+
+    SmallVector dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    // TODO: we can save the dtypes of inputs here and perform TypeCvt at the end of
+    // this function, rather than performing TypeCvt eagerly. But for compatibility we
+    // implement this function with a process similar to the python version and
+    // perform TypeCvt here, so TypeCvt may be performed several times in this function
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    static std::unordered_set cast_case1 = {
+            Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP,
+            Elemwise::Mode::POW,      Elemwise::Mode::LOG,
+            Elemwise::Mode::EXPM1,    Elemwise::Mode::LOG1P,
+            Elemwise::Mode::ACOS,     Elemwise::Mode::ASIN,
+            Elemwise::Mode::ATAN2,    Elemwise::Mode::COS,
+            Elemwise::Mode::SIN,      Elemwise::Mode::LOG_SUM_EXP,
+    };
+
+    static std::unordered_set cast_case2 = {
+            Elemwise::Mode::TANH,
+    };
+
+    auto cast_to_high_prec = [&]() {
+        for (size_t i = 0; i < dtypes.size(); ++i) {
+            if (dtypes[i] != DTypePromoteCfg::amp_high_prec_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
+                        converted[i])[0];
+                dtypes[i] = DTypePromoteCfg::amp_high_prec_dtype;
+            }
+        }
+    };
+
+    if (cast_case1.find(elem_op.mode) != cast_case1.end()) {
+        if (DTypePromoteCfg::amp_dtype_autocast_enabled || is_all_integer(dtypes)) {
+            cast_to_high_prec();
+        }
+    }
+
+    if (cast_case2.find(elem_op.mode) != cast_case2.end()) {
+        if (is_all_integer(dtypes)) {
+            cast_to_high_prec();
+        }
+    }
+
+    static std::unordered_set cast_case3 = {
+            Elemwise::Mode::CEIL, Elemwise::Mode::FLOOR, Elemwise::Mode::ROUND};
+
+    if (cast_case3.find(elem_op.mode) != cast_case3.end()) {
+        if (dtypes[0].category() == DTypeCategory::INT) {
+            return converted;
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
+ValueRefList reduce_rule(const OpDef& op, Span inputs) {
+    auto&& reduce_op = op.cast_final_safe();
+    DType org_dtype = *(inputs[0].dtype());
+    DType target_dtype = org_dtype;
+
+    ValueRefList converted(inputs.begin(), inputs.end());
+
+    if (reduce_op.mode == Reduce::Mode::MEAN) {
+        target_dtype = dtype::Float32();
+    } else if (org_dtype.category() == DTypeCategory::BOOL) {
+        target_dtype = dtype::Int32();
+    }
+
+    if (target_dtype != org_dtype) {
+        converted[0] =
+                imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[0])[0];
+    }
+
+    ValueRefList ret = imperative::apply(op, converted);
+
+    if (org_dtype.category() == DTypeCategory::BOOL) {
+        if (reduce_op.mode == Reduce::Mode::MIN ||
+            reduce_op.mode == Reduce::Mode::MAX) {
+            ret[0] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(dtype::Bool())), ret[0])[0];
+        }
+    }
+    return ret;
+}
+
+ValueRefList convolution_rule(const OpDef& op, Span inputs) {
+    auto&& conv_op = const_cast(op.cast_final_safe());
+    SmallVector dtypes = get_value_dtypes(inputs);
+    mgb::DType target_dtype;
+
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        conv_op.compute_mode = Convolution::ComputeMode::FLOAT32;
+        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
+    } else {
+        target_dtype = get_promoted_dtype(dtypes);
+    }
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
+ValueRefList batch_norm_rule(const OpDef& op, Span inputs) {
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        mgb_assert(inputs.size() > 0);
+
+        ValueRefList converted(inputs.size());
+        converted[0] = imperative::apply(
+                ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0];
+
+        for (size_t i = 1; i < inputs.size(); ++i) {
+            DType idtype = *(inputs[i].dtype());
+            if (idtype != DTypePromoteCfg::amp_high_prec_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
+                        inputs[i])[0];
+            } else {
+                converted[i] = inputs[i];
+            }
+        }
+
+        return imperative::apply(op, converted);
+    }
+
+    return imperative::apply(op, inputs);
+}
+
+struct DTypePromoteRuleRegistry {
+    DTypePromoteRuleRegistry() {
+        register_dtype_promote_rule(elemwise_rule);
+        register_dtype_promote_rule(reduce_rule);
+        register_dtype_promote_rule(convolution_rule);
+        register_dtype_promote_rule(batch_norm_rule);
+    }
+} register_helper;
+
+}  // namespace
+
+ValueRefList DTypePromoteTransformation::apply_transformation(
+        const Operator& op, Span inputs) {
+    if (auto apply_op = op.as()) {
+        auto iter = dtype_promotion_rules.find(apply_op->op().dyn_typeinfo());
+        if (iter != dtype_promotion_rules.end()) {
+            return iter->second(apply_op->op(), inputs);
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
+    return imperative::apply(op, inputs);
+}
+
+ValueRef DTypePromoteTransformation::unwrap(ValueRef value) {
+    return value;
+}
+
+std::string DTypePromoteTransformation::name() const {
+    return "DTypePromoteTransformation";
+}
+
+void DTypePromoteTransformation::on_register() {
+    // printf("DTypePromoteTransformation has been registered\n");
+}
+
+void DTypePromoteTransformation::on_unregister() noexcept {
+    // printf("DTypePromoteTransformation has been unregistered\n");
+}
+
+}  // namespace mgb::imperative
\ No newline at end of file
diff --git a/imperative/src/include/megbrain/imperative/transformations/dtype_promote.h b/imperative/src/include/megbrain/imperative/transformations/dtype_promote.h
new file mode 100644
index 0000000000000000000000000000000000000000..b5dd78ec8c86be1ad89108d954907574be92f1b8
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/dtype_promote.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative {
+
+class DTypePromoteTransformation final : public Transformation {
+private:
+public:
+    ValueRefList apply_transformation(
+            const Operator& op, Span inputs) override;
+    ValueRef unwrap(ValueRef value) override;
+    std::string name() const override;
+    void on_register() override;
+    void on_unregister() noexcept override;
+};
+
+struct DTypePromoteCfg {
+    static bool convert_input_enabled;
+    static bool amp_dtype_autocast_enabled;
+    static DType amp_high_prec_dtype;
+    static DType amp_low_prec_dtype;
+};
+
+} // namespace mgb::imperative
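
Illustrative usage, not part of the patch: a minimal Python sketch of the behaviour this diff moves into the C++ DTypePromoteCfg, mirroring the updated test_autocast.py; it assumes amp starts in its default disabled state.

    from megengine import amp

    # amp dtype state now lives in DTypePromoteCfg on the C++ side and is reached
    # through the _get_amp_*/_set_amp_* bindings registered in tensor.cpp; the
    # public module-level API is unchanged.
    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
        assert amp.enabled
        assert amp.low_prec_dtype == "float16"
        assert amp.high_prec_dtype == "float32"

    # the previous configuration is restored on exit (assuming amp was disabled before)
    assert not amp.enabled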