Commit 3c3fc6f3 authored by Megvii Engine Team

refactor(imperative): move python code of elemwise/reduce/conv2d/bn to c++

GitOrigin-RevId: 01b532439243aa2e7d40f150fcaa26fded0e4f27
Parent 84466261
......@@ -28,6 +28,16 @@ void BNForward::check_exec(
const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance, const TensorLayout& dst,
size_t workspace_in_bytes, size_t reserve_in_bytes) {
// some Python-side asserts are moved into MegDNN to reduce the assert overhead
megdnn_assert(
src.ndim == 4,
"ndim of the input tensor for batch_norm should be 4, but you give %zu",
src.ndim);
megdnn_assert(bn_scale.ndim == 4, "expected ndim 4, got %zu\n", bn_scale.ndim);
megdnn_assert(bn_bias.ndim == 4, "expected ndim 4, got %zu\n", bn_bias.ndim);
megdnn_assert_eq_layout(bn_scale, bn_bias);
megdnn_assert_eq_layout(batch_mean, batch_inv_variance);
megdnn_assert_contiguous(src);
megdnn_assert_eq_layout(src, dst);
......
......@@ -58,16 +58,19 @@ class autocast:
self._origin_low = None
def __enter__(self):
self._origin_enabled, amp._enabled = amp._enabled, self.enabled
self._origin_high = amp._high_prec_dtype
amp._high_prec_dtype = self.high_prec_dtype
self._origin_low = amp._low_prec_dtype
amp._low_prec_dtype = self.low_prec_dtype
self._origin_enabled = amp._enabled
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._enabled = self.enabled
amp._set_amp_dtype_autocast(self.enabled)
amp._set_amp_high_prec_dtype(self.high_prec_dtype)
amp._set_amp_low_prec_dtype(self.low_prec_dtype)
def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._high_prec_dtype = self._origin_high
amp._low_prec_dtype = self._origin_low
amp._set_amp_dtype_autocast(self._origin_enabled)
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)
def __call__(self, func):
@functools.wraps(func)
......
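With this change, autocast saves and restores the AMP state through the new C++ getters/setters instead of mutating Python globals. A minimal usage sketch (standard MegEngine APIs; the user-visible behavior is assumed unchanged by the move):

    import megengine as mge
    import megengine.functional as F

    x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])
    w = mge.tensor([[1.0], [1.0]])

    # inside the context, the enabled flag and low/high precision dtypes live
    # in the C++ DTypePromoteCfg; __exit__ restores the previous values
    with mge.amp.autocast(enabled=True, low_prec_dtype="float16", high_prec_dtype="float32"):
        y = F.matmul(x, w)

    print(mge.amp.enabled)         # back to the value it had before the block
    print(mge.amp.low_prec_dtype)  # likewise restored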
......@@ -5,9 +5,18 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .._imperative_rt.core2 import (
_get_amp_dtype_autocast,
_get_amp_high_prec_dtype,
_get_amp_low_prec_dtype,
_set_amp_dtype_autocast,
_set_amp_high_prec_dtype,
_set_amp_low_prec_dtype,
)
_enabled = False
_high_prec_dtype = "float32"
_low_prec_dtype = "float16"
_set_amp_dtype_autocast(_enabled)
@property
......@@ -28,6 +37,7 @@ def enabled(mod):
def enabled(mod, enabled: bool):
global _enabled
_enabled = enabled
_set_amp_dtype_autocast(_enabled)
@property
......@@ -42,13 +52,12 @@ def high_prec_dtype(mod):
import megengine as mge
mge.amp.high_prec_dtype = "float32"
"""
return _high_prec_dtype
return _get_amp_high_prec_dtype()
@high_prec_dtype.setter
def high_prec_dtype(mod, dtype: str):
global _high_prec_dtype
_high_prec_dtype = dtype
_set_amp_high_prec_dtype(dtype)
@property
......@@ -63,10 +72,9 @@ def low_prec_dtype(mod):
import megengine as mge
mge.amp.low_prec_dtype = "float16"
"""
return _low_prec_dtype
return _get_amp_low_prec_dtype()
@low_prec_dtype.setter
def low_prec_dtype(mod, dtype: str):
global _low_prec_dtype
_low_prec_dtype = dtype
_set_amp_low_prec_dtype(dtype)
......@@ -25,7 +25,6 @@ from .utils import (
astensor1d,
astype,
cast_tensors,
convert_inputs,
make_shape_tuple,
subgraph,
)
......@@ -40,38 +39,6 @@ def _elwise_apply(args, mode):
def _elwise(*args, mode):
args = convert_inputs(*args)
if (
mode
in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
)
and (
amp._enabled
or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
)
or mode in (_ElwMod.TANH,)
and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
# autocast to FP32 to maintain precision
# or to avoid ops that do not support all-integer args
args = cast_tensors(*args, promote=True)
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
return _elwise_apply(args, mode)
......@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
if mode == "mean":
data = data.astype("float32")
elif self.dtype == np.bool_:
data = data.astype("int32")
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
......@@ -526,9 +489,6 @@ def _reduce(mode):
if not keepdims:
result = _remove_axis(result, axis)
if self.dtype == np.bool_:
if mode in ["min", "max"]:
result = result.astype("bool")
return result
return f
......
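The dtype handling deleted here (mean computed in float32, bool inputs reduced as int32, and bool restored for min/max results) now lives in the C++ reduce_rule added later in this commit. A sketch of the behavior this is assumed to preserve:

    import numpy as np
    import megengine as mge

    i = mge.tensor(np.arange(6, dtype="int32"))
    b = mge.tensor(np.array([True, False, True]))

    print(i.mean().dtype)  # float32: mean promotes integer input
    print(b.sum().dtype)   # int32: bool input is reduced as int32
    print(b.max().dtype)   # bool: min/max results are cast back to bool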
......@@ -16,6 +16,8 @@ from .._imperative_rt import make_const
from .._imperative_rt.core2 import (
SymbolVar,
Tensor,
_get_convert_inputs,
_set_convert_inputs,
apply,
dtype_promotion,
get_device,
......@@ -27,15 +29,13 @@ from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
_enable_convert_inputs = True
def get_convert_inputs():
r"""get the curerent state of `_enable_convert_inputs`"""
return _enable_convert_inputs
return _get_convert_inputs()
def set_convert_inputs(flag):
......@@ -44,10 +44,7 @@ def set_convert_inputs(flag):
`_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
internal use only, and should be removed when the tensor-like system is refactored.
"""
global _enable_convert_inputs
backup = _enable_convert_inputs
_enable_convert_inputs = flag
return backup
return _set_convert_inputs(flag)
def concatenate(inputs, axis=0, *, device=None):
......@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None):
def convert_inputs(*args, device=None):
if not _enable_convert_inputs:
if not _get_convert_inputs():
return args
dtype = dtype_promotion(args)
......@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None):
def cast_tensors(*args, promote=False):
if promote:
dtype = _high_prec_dtype
dtype = _get_amp_high_prec_dtype()
else:
dtype = _low_prec_dtype
dtype = _get_amp_low_prec_dtype()
return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
......
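get_convert_inputs/set_convert_inputs are now thin wrappers over DTypePromoteCfg::convert_input_enabled; the setter still returns the previous value so callers can restore it. A sketch of the internal toggle (module path assumed from this file's location):

    from megengine.core.tensor.utils import get_convert_inputs, set_convert_inputs

    backup = set_convert_inputs(False)   # disable implicit input conversion, keep the old value
    assert get_convert_inputs() is False

    # ... code that requires inputs to keep their exact dtypes ...

    set_convert_inputs(backup)           # restore the previous setting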
......@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import convert_inputs
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from .tensor_cache import get_scalar_one
__all__ = [
"abs",
......@@ -359,7 +360,11 @@ def asin(x):
def atan(x):
r"""Element-wise `inverse tangent`."""
return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)
return _elwise(
x,
get_scalar_one("float32", x.device if isinstance(x, Tensor) else None),
mode=Elemwise.Mode.ATAN2,
)
def atan2(y, x):
......
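atan is computed as atan2(x, 1); the constant 1 is now taken from the per-device scalar cache instead of being rebuilt on every call. A sketch of the equivalence, assuming standard functional APIs:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.array([0.0, 1.0, -1.0], dtype="float32"))
    ones = mge.tensor(np.ones(3, dtype="float32"))

    y1 = F.atan(x)
    y2 = F.atan2(x, ones)  # element-wise atan2(x, 1) == atan(x)

    np.testing.assert_allclose(y1.numpy(), y2.numpy(), rtol=1e-6)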
......@@ -253,15 +253,6 @@ def conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -1328,29 +1319,32 @@ def batch_norm(
inplace: whether to update ``running_mean`` and ``running_var``
inplace or return new tensors. Default: True
"""
if inp.ndim != 4:
raise NotImplementedError("batch_norm for ndim != 4")
if param_dim == "dim_1c11":
C = inp.shape[1]
pshape = (1, C, 1, 1)
elif param_dim == "dim_111c":
C = inp.shape[3]
pshape = (1, 1, 1, C)
else:
raise ValueError("Invalid param_dim {}".format(param_dim))
def make_full_if_none(x, value):
x_ndim = None if x is None else x.ndim
# in the general case, x is returned here directly
if x_ndim is not None and x_ndim != 1:
return x
if param_dim == "dim_1c11":
C = inp.shape[1]
pshape = (1, C, 1, 1)
elif param_dim == "dim_111c":
C = inp.shape[3]
pshape = (1, 1, 1, C)
else:
raise ValueError("Invalid param_dim {}".format(param_dim))
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape)
return result
elif x.ndim == 1:
else:
assert x_ndim == 1
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Reshape(), x, shape)
return result
return x
has_mean = running_mean is not None
has_var = running_var is not None
......@@ -1359,16 +1353,6 @@ def batch_norm(
assert has_mean, "running_mean must be provided in inference mode"
assert has_var, "running_var must be provided in inference mode"
if has_mean and running_mean.ndim != 4:
raise ValueError
if has_var and running_var.ndim != 4:
raise ValueError
if amp._enabled:
inp = inp.astype("float16")
weight, bias, running_mean, running_var = cast_tensors(
weight, bias, running_mean, running_var, promote=True
)
weight = make_full_if_none(weight, 1)
bias = make_full_if_none(bias, 0)
......
from ..core.ops.special import Const
from ..jit.tracing import is_tracing
small_tensor_cache = {}
def _get_scalar_tensor_with_value(value, dtype=None, device=None):
global small_tensor_cache
if is_tracing():
(ret,) = Const(value, dtype=dtype, device=device)()
else:
cache_key = (value, dtype, device)
if cache_key not in small_tensor_cache:
(ret,) = Const(value, dtype=dtype, device=device)()
small_tensor_cache[cache_key] = ret
else:
ret = small_tensor_cache[cache_key]
return ret
def get_scalar_zero(dtype=None, device=None):
return _get_scalar_tensor_with_value(0, dtype, device)
def get_scalar_zero_point_five(dtype=None, device=None):
return _get_scalar_tensor_with_value(0.5, dtype, device)
def get_scalar_one(dtype=None, device=None):
return _get_scalar_tensor_with_value(1, dtype, device)
def get_scalar_two(dtype=None, device=None):
return _get_scalar_tensor_with_value(2, dtype, device)
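This new helper memoizes small constant scalars keyed by (value, dtype, device) so they are created once and reused; under tracing it falls back to Const so the constant is still recorded in the trace. A usage sketch (module path assumed to be megengine.functional.tensor_cache, matching the import in elemwise.py above):

    from megengine.functional.tensor_cache import get_scalar_one

    a = get_scalar_one("float32", None)
    b = get_scalar_one("float32", None)

    # outside of tracing, the same cached tensor object is returned
    assert a is b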
......@@ -15,6 +15,7 @@
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
......@@ -59,16 +60,19 @@ struct SymbolVarContext {
TransformationContext context;
std::shared_ptr<SymbolTransformation> symbol_tsf;
std::shared_ptr<ScalarTransformation> scalar_tsf;
std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
SymbolVarContext(cg::ComputingGraph* graph) {
symbol_tsf = std::make_shared<SymbolTransformation>(graph);
scalar_tsf = std::make_shared<ScalarTransformation>();
dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
Transformation::swap_context(context);
}
void init() {
symbol_tsf->register_at(Transformation::top());
scalar_tsf->register_at(Transformation::top());
dtype_promote_tsf->register_at(Transformation::top());
}
ValueRef symvar2val(py::handle py_symbol_var) {
......@@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d)
#undef REGISTE_APPLY_FUNC
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);
PyObject* py_apply(
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
try {
......@@ -133,19 +140,59 @@ PyObject* py_apply(
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
SmallVector<ValueRef, 8> tensors(nargs);
bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) &&
py::isinstance<PySymbolVar>(py::handle(args[0]));
if (is_symbol_var) {
SmallVector<bool, 8> is_symbol_var(nargs, false);
ComputingGraph* cg = nullptr;
for (size_t i = 0; i < nargs; ++i) {
if ((!TensorWrapper::try_cast(args[i])) &&
py::isinstance<PySymbolVar>(py::handle(args[i]))) {
is_symbol_var[i] = true;
ComputingGraph* cur_cg =
py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
if (cg == nullptr) {
cg = cur_cg;
} else {
mgb_assert(cg == cur_cg);
}
}
}
mgb::CompNode target_cn;
mgb::DType target_dtype;
auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
if (!target_dtype.valid()) {
target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
target_cn = _get_device(args, nargs);
}
HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
if (PyArray_Check(args[i])) { // non-scalar
return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
HostStorage::make(ht.storage()))[0];
} else { // scalar
return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
HostStorage::make(ht.storage()))[0];
}
};
if (cg != nullptr) {
// swap to a special context to reuse scalar handle
SymbolVarContext context(
py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
size_t symbol_var_idx = 8;
SymbolVarContext context(cg);
context.init();
for (size_t i = 0; i < nargs; ++i) {
tensors[i] = context.symvar2val(args[i]);
if (is_symbol_var[i]) {
symbol_var_idx = i;
tensors[i] = context.symvar2val(args[i]);
} else {
tensors[i] = convert_pyinput_to_tensor(i);
}
}
auto outputs = imperative::apply(*op, tensors);
auto ret = pybind11::tuple(outputs.size());
auto typeobj = py::handle(args[0]).get_type();
auto typeobj = py::handle(args[symbol_var_idx]).get_type();
for (size_t i = 0; i < outputs.size(); ++i) {
ret[i] = context.val2symvar(typeobj, outputs[i]);
}
......@@ -156,13 +203,7 @@ PyObject* py_apply(
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
tensors[i] = tw->m_tensor->data();
} else {
PyErr_SetString(
PyExc_TypeError,
ssprintf(
"op %s expect type Tensor as inputs, got %s actually",
op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
.c_str());
return nullptr;
tensors[i] = convert_pyinput_to_tensor(i);
}
}
......@@ -616,6 +657,8 @@ void init_tensor(py::module m) {
std::shared_ptr<Channel>(channel, [](Channel*) {})));
transformations.register_at<Segment::Scalar>(
std::make_shared<ScalarTransformation>());
transformations.register_at<Segment::DTypePromote>(
std::make_shared<DTypePromoteTransformation>());
static py::exception<interpreter::AsyncError> py_async_error(
m, "AsyncError", PyExc_RuntimeError);
......@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) {
m.def("reset_stats", [] { imperative::Stats::reset(); });
m.def("_get_convert_inputs",
[]() -> bool { return DTypePromoteCfg::convert_input_enabled; });
m.def("_set_convert_inputs", [](bool flag) -> bool {
bool ret = DTypePromoteCfg::convert_input_enabled;
DTypePromoteCfg::convert_input_enabled = flag;
return ret;
});
m.def("_get_amp_dtype_autocast",
[]() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
return ret;
});
static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
: DTypePromoteCfg::amp_low_prec_dtype;
mgb_assert(target.category() == DTypeCategory::FLOAT);
std::string ret = target.name();
transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
return ret;
};
static auto set_amp_prec_dtype = [](bool is_high,
std::string dtype_name) -> std::string {
DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
: DTypePromoteCfg::amp_low_prec_dtype;
std::string ret = target.name();
if (dtype_name == "float32") {
target = dtype::Float32();
} else if (dtype_name == "float16") {
target = dtype::Float16();
} else if (dtype_name == "bfloat16") {
target = dtype::BFloat16();
} else {
mgb_assert(
false, "casted type of amp should be float, but you give %s\n",
dtype_name.c_str());
}
transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
return ret;
};
m.def("_get_amp_high_prec_dtype",
[]() -> std::string { return get_amp_prec_dtype(true); });
m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
return set_amp_prec_dtype(true, dtype_name);
});
m.def("_get_amp_low_prec_dtype",
[]() -> std::string { return get_amp_prec_dtype(false); });
m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
return set_amp_prec_dtype(false, dtype_name);
});
py::register_exception<TraceError>(m, "TraceError");
}
......
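These bindings expose DTypePromoteCfg to Python; every setter returns the previous value, and the dtype setters accept only float32/float16/bfloat16. A sketch of how the amp module above drives them (import path taken from the diff):

    from megengine.core._imperative_rt.core2 import (
        _get_amp_high_prec_dtype,
        _set_amp_dtype_autocast,
        _set_amp_high_prec_dtype,
    )

    _set_amp_dtype_autocast(True)               # enable amp casting rules, returns the old flag
    prev = _set_amp_high_prec_dtype("float32")  # returns the previous dtype name
    assert _get_amp_high_prec_dtype() == "float32"
    _set_amp_high_prec_dtype(prev)              # restore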
......@@ -26,12 +26,13 @@ struct TransformationManager {
enum Segment {
ModuleTrace,
Grad,
DTypePromote,
Scalar,
Trace,
Eval,
};
std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments;
std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;
template <Segment segment>
void register_at(std::shared_ptr<Transformation> transformation) {
......
......@@ -14,20 +14,20 @@ def test_grad_scaler():
assert amp.enabled == enabled
assert origin_amp._enabled == enabled
assert amp.low_prec_dtype == low
assert origin_amp._low_prec_dtype == low
assert origin_amp._get_amp_low_prec_dtype() == low
assert amp.high_prec_dtype == high
assert origin_amp._high_prec_dtype == high
assert origin_amp._get_amp_high_prec_dtype() == high
origin_enabled = amp.enabled
origin_high = amp.high_prec_dtype
origin_low = amp.low_prec_dtype
with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
check(True, "low", "high")
with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
check(True, "float16", "float32")
check(origin_enabled, origin_low, origin_high)
amp.enabled = True
amp.high_prec_dtype = "high"
amp.low_prec_dtype = "low"
check(True, "low", "high")
amp.high_prec_dtype = "float32"
amp.low_prec_dtype = "float16"
check(True, "float16", "float32")
amp.enabled = origin_enabled
amp.high_prec_dtype = origin_high
amp.low_prec_dtype = origin_low
......
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative {
bool DTypePromoteCfg::convert_input_enabled = true;
bool DTypePromoteCfg::amp_dtype_autocast_enabled = false;
DType DTypePromoteCfg::amp_high_prec_dtype = dtype::Float32();
DType DTypePromoteCfg::amp_low_prec_dtype = dtype::Float16();
namespace {
// TODO: ScalarRule and DTypePromoteRule should be unified
using DTypePromoteRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
static std::unordered_map<Typeinfo*, DTypePromoteRule> dtype_promotion_rules;
template <typename T>
void register_dtype_promote_rule(const DTypePromoteRule& rule) {
dtype_promotion_rules[T::typeinfo()] = [rule](const OpDef& def,
Span<ValueRef> inputs) {
return rule(def.cast_final_safe<T>(), inputs);
};
}
bool is_quantized_dtype(const DType& dtype) {
return dtype.category() == DTypeCategory::QUANTIZED;
}
bool is_all_integer(const SmallVector<DType>& dtypes) {
for (size_t i = 0; i < dtypes.size(); ++i) {
if (dtypes[i].category() != DTypeCategory::INT) {
return false;
}
}
return true;
}
SmallVector<DType> get_value_dtypes(const Span<ValueRef> inputs) {
SmallVector<DType> dtypes(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
dtypes[i] = *(inputs[i].dtype());
}
return dtypes;
}
mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
if (dtypes.size() == 0) {
mgb_assert(false, "there is no input for operator, dtype promote failed");
}
mgb::DType ret = dtypes[0];
for (size_t i = 1; i < dtypes.size(); ++i) {
ret = mgb::dtype_promotion(ret, dtypes[i]);
}
return ret;
}
ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& elem_op = op.cast_final_safe<Elemwise>();
SmallVector<DType> dtypes(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
dtypes[i] = *(inputs[i].dtype());
}
ValueRefList converted(inputs.size());
mgb::DType target_dtype = get_promoted_dtype(dtypes);
// TODO: we can save the dtypes of inputs here and perform TypeCvt at the end of
// this function, rather than performing TypeCvt eagerly. But for compatibility, we
// implement this function following a similar process to the Python version and
// perform TypeCvt here, so TypeCvt may be applied several times in this function
for (size_t i = 0; i < inputs.size(); ++i) {
if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
DTypePromoteCfg::convert_input_enabled) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
dtypes[i] = target_dtype;
} else {
converted[i] = inputs[i];
}
}
static std::unordered_set<Elemwise::Mode> cast_case1 = {
Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP,
Elemwise::Mode::POW, Elemwise::Mode::LOG,
Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P,
Elemwise::Mode::ACOS, Elemwise::Mode::ASIN,
Elemwise::Mode::ATAN2, Elemwise::Mode::COS,
Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP,
};
static std::unordered_set<Elemwise::Mode> cast_case2 = {
Elemwise::Mode::TANH,
};
auto cast_to_high_prec = [&]() {
for (size_t i = 0; i < dtypes.size(); ++i) {
if (dtypes[i] != DTypePromoteCfg::amp_high_prec_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
converted[i])[0];
dtypes[i] = DTypePromoteCfg::amp_high_prec_dtype;
}
}
};
if (cast_case1.find(elem_op.mode) != cast_case1.end()) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled || is_all_integer(dtypes)) {
cast_to_high_prec();
}
}
if (cast_case2.find(elem_op.mode) != cast_case2.end()) {
if (is_all_integer(dtypes)) {
cast_to_high_prec();
}
}
static std::unordered_set<Elemwise::Mode> cast_case3 = {
Elemwise::Mode::CEIL, Elemwise::Mode::FLOOR, Elemwise::Mode::ROUND};
if (cast_case3.find(elem_op.mode) != cast_case3.end()) {
if (dtypes[0].category() == DTypeCategory::INT) {
return converted;
}
}
return imperative::apply(op, converted);
}
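elemwise_rule reimplements the promotion logic removed from _elwise above: inputs are first promoted to a common dtype, the TRUE_DIV/EXP/POW/... group is cast to the amp high-precision dtype when autocast is enabled or all inputs are integers, TANH is cast only for all-integer inputs, and CEIL/FLOOR/ROUND on integer input return it unchanged. The user-visible effect, sketched with standard APIs:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.array([1, 2, 3], dtype="int32"))
    b = mge.tensor(np.array([2, 2, 2], dtype="int32"))

    print((a / b).dtype)     # float32: TRUE_DIV on integer inputs is computed in float
    print(F.floor(a).dtype)  # int32: FLOOR on an integer tensor is returned as-is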
ValueRefList reduce_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& reduce_op = op.cast_final_safe<Reduce>();
DType org_dtype = *(inputs[0].dtype());
DType target_dtype = org_dtype;
ValueRefList converted(inputs.begin(), inputs.end());
if (reduce_op.mode == Reduce::Mode::MEAN) {
target_dtype = dtype::Float32();
} else if (org_dtype.category() == DTypeCategory::BOOL) {
target_dtype = dtype::Int32();
}
if (target_dtype != org_dtype) {
converted[0] =
imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[0])[0];
}
ValueRefList ret = imperative::apply(op, converted);
if (org_dtype.category() == DTypeCategory::BOOL) {
if (reduce_op.mode == Reduce::Mode::MIN ||
reduce_op.mode == Reduce::Mode::MAX) {
ret[0] = imperative::apply(
ApplyOp(*TypeCvt::make(dtype::Bool())), ret[0])[0];
}
}
return ret;
}
ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& conv_op = const_cast<Convolution&>(op.cast_final_safe<Convolution>());
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype;
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
conv_op.compute_mode = Convolution::ComputeMode::FLOAT32;
target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
} else {
target_dtype = get_promoted_dtype(dtypes);
}
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
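convolution_rule replaces the branch removed from conv2d above: with autocast enabled it forces ComputeMode::FLOAT32 and casts inputs to the amp low-precision dtype, otherwise it promotes input and weight to a common dtype. A sketch of the user-visible behavior (fp16 convolution may require a CUDA device):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    weight = mge.tensor(np.random.randn(4, 3, 3, 3).astype("float32"))

    with mge.amp.autocast(enabled=True):
        out = F.conv2d(inp, weight)  # inputs cast to float16, accumulated in float32

    print(out.dtype)  # expected: float16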
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
mgb_assert(inputs.size() > 0);
ValueRefList converted(inputs.size());
converted[0] = imperative::apply(
ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0];
for (size_t i = 1; i < inputs.size(); ++i) {
DType idtype = *(inputs[i].dtype());
if (idtype != DTypePromoteCfg::amp_high_prec_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
return imperative::apply(op, inputs);
}
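batch_norm_rule mirrors the amp branch removed from batch_norm above: under autocast the input is cast to float16 while the remaining parameters and statistics are kept in the amp high-precision dtype; without autocast all inputs pass through unchanged. A sketch of the corresponding call (fp16 batch norm may require a CUDA device):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(2, 3, 4, 4).astype("float32"))
    mean = mge.tensor(np.zeros((1, 3, 1, 1), dtype="float32"))
    var = mge.tensor(np.ones((1, 3, 1, 1), dtype="float32"))

    with mge.amp.autocast(enabled=True):
        out = F.batch_norm(inp, mean, var, training=False)  # input in float16, stats in float32

    print(out.dtype)  # expected: float16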
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
}
} register_helper;
} // namespace
ValueRefList DTypePromoteTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = dtype_promotion_rules.find(apply_op->op().dyn_typeinfo());
if (iter != dtype_promotion_rules.end()) {
return iter->second(apply_op->op(), inputs);
} else {
return imperative::apply(op, inputs);
}
}
return imperative::apply(op, inputs);
}
ValueRef DTypePromoteTransformation::unwrap(ValueRef value) {
return value;
}
std::string DTypePromoteTransformation::name() const {
return "DTypePromoteTransformation";
}
void DTypePromoteTransformation::on_register() {
// printf("DTypePromoteTransformation has been registered\n");
}
void DTypePromoteTransformation::on_unregister() noexcept {
// printf("DTypePromoteTransformation has been unregistered\n");
}
} // namespace mgb::imperative
\ No newline at end of file
#pragma once
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/value.h"
namespace mgb::imperative {
class DTypePromoteTransformation final : public Transformation {
private:
public:
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override;
std::string name() const override;
void on_register() override;
void on_unregister() noexcept override;
};
struct DTypePromoteCfg {
static bool convert_input_enabled;
static bool amp_dtype_autocast_enabled;
static DType amp_high_prec_dtype;
static DType amp_low_prec_dtype;
};
} // namespace mgb::imperative