提交 533fb5bf 编写于 作者: M Megvii Engine Team

feat(imperative): support formatted tensor and add special op rules

GitOrigin-RevId: 77ff909f2371f768442fb103ec7038832f9310f6
上级 4aa79c45
...@@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush) ...@@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush)
# subpackages # subpackages
import megengine.amp import megengine.amp
import megengine.autodiff import megengine.autodiff
import megengine.config
import megengine.data import megengine.data
import megengine.distributed import megengine.distributed
import megengine.dtr import megengine.dtr
......
...@@ -2,7 +2,13 @@ ...@@ -2,7 +2,13 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option from ._imperative_rt.core2 import (
_clear_algorithm_cache,
get_auto_format_convert,
get_option,
set_auto_format_convert,
set_option,
)
__compute_mode = "default" __compute_mode = "default"
__conv_format = "default" __conv_format = "default"
...@@ -153,20 +159,41 @@ def _conv_format(mod, format: str): ...@@ -153,20 +159,41 @@ def _conv_format(mod, format: str):
__conv_format = format __conv_format = format
@property
def _auto_format_convert(mod):
r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
The default value is False, which means no convert.
Examples:
.. code-block::
import megengine as mge
mge.config._auto_format_convert = True
"""
return get_auto_format_convert()
@_auto_format_convert.setter
def _auto_format_convert(mod, option: bool):
set_auto_format_convert(option)
def _reset_execution_config( def _reset_execution_config(
benchmark_kernel=None, benchmark_kernel=None,
deterministic_kernel=None, deterministic_kernel=None,
async_level=None, async_level=None,
compute_mode=None, compute_mode=None,
conv_format=None, conv_format=None,
auto_format_convert=None,
): ):
global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
orig_flags = ( orig_flags = (
_benchmark_kernel, _benchmark_kernel,
_deterministic_kernel, _deterministic_kernel,
get_option("async_level"), get_option("async_level"),
__compute_mode, __compute_mode,
__conv_format, __conv_format,
get_auto_format_convert(),
) )
if benchmark_kernel is not None: if benchmark_kernel is not None:
_benchmark_kernel = benchmark_kernel _benchmark_kernel = benchmark_kernel
...@@ -178,6 +205,8 @@ def _reset_execution_config( ...@@ -178,6 +205,8 @@ def _reset_execution_config(
__compute_mode = compute_mode __compute_mode = compute_mode
if conv_format is not None: if conv_format is not None:
__conv_format = conv_format __conv_format = conv_format
if auto_format_convert is not None:
set_auto_format_convert(auto_format_convert)
return orig_flags return orig_flags
...@@ -189,6 +218,7 @@ def _override( ...@@ -189,6 +218,7 @@ def _override(
async_level=None, async_level=None,
compute_mode=None, compute_mode=None,
conv_format=None, conv_format=None,
auto_format_convert=None,
): ):
r"""A context manager that users can opt in by attaching the decorator to set r"""A context manager that users can opt in by attaching the decorator to set
the config of the global variable. the config of the global variable.
...@@ -204,11 +234,17 @@ def _override( ...@@ -204,11 +234,17 @@ def _override(
async_level=2, async_level=2,
compute_mode="float32", compute_mode="float32",
conv_format="NHWC", conv_format="NHWC",
auto_format_convert=True,
) )
def train(): def train():
""" """
orig_flags = _reset_execution_config( orig_flags = _reset_execution_config(
benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format, benchmark_kernel,
deterministic_kernel,
async_level,
compute_mode,
conv_format,
auto_format_convert,
) )
try: try:
yield yield
......
...@@ -564,7 +564,6 @@ def interpolate( ...@@ -564,7 +564,6 @@ def interpolate(
if inp.dtype == np.float16: if inp.dtype == np.float16:
inp = inp.astype("float32") inp = inp.astype("float32")
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
assert conv_format == "NCHW", "Currently resize only support NCHW mode"
op = builtin.Resize(imode=mode_map[mode], format=conv_format) op = builtin.Resize(imode=mode_map[mode], format=conv_format)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape) (ret,) = apply(op, inp, shape)
......
...@@ -4,6 +4,7 @@ from typing import Union ...@@ -4,6 +4,7 @@ from typing import Union
import numpy as np import numpy as np
from .core._imperative_rt import CompNode from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import FormatType
from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply, set_py_tensor_type from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape from .core._trace_option import use_symbolic_shape
...@@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`. is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`.
no_cache: Whether cache it for memory sharing. no_cache: Whether cache it for memory sharing.
name: Used to improve convenience in graph operation on dumped model. name: Used to improve convenience in graph operation on dumped model.
format: Used to indicate which memory format Tensor uses. It will not affect actual memory order or stride,
but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc".
.. note:: .. note::
...@@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
is_const: bool = False, is_const: bool = False,
no_cache: bool = False, no_cache: bool = False,
name: str = None, name: str = None,
format: str = "default",
): ):
if name is None: if name is None:
name = "" name = ""
...@@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.""" r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`."""
return super().dtype return super().dtype
@property
def format(self) -> str:
return super().format
@property @property
def qparams(self): def qparams(self):
r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.""" r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`."""
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "megbrain/imperative/transformations/dim_expansion.h" #include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h" #include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h" #include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h" #include "megbrain/imperative/transformations/symbol.h"
...@@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) { ...@@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) {
// name // name
case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1; case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
} }
case 'f':
// format
return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1;
} }
// clang-format on // clang-format on
return -1; return -1;
...@@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
{"is_const", []() -> py::object { return py::bool_(false); }}, {"is_const", []() -> py::object { return py::bool_(false); }},
{"no_cache", []() -> py::object { return py::bool_(false); }}, {"no_cache", []() -> py::object { return py::bool_(false); }},
{"name", []() -> py::object { return py::none(); }}, {"name", []() -> py::object { return py::none(); }},
{"format", []() -> py::object { return py::none(); }},
}, },
name2idx}; name2idx};
py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
...@@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
} else { } else {
tup = parse_args(tup, descs); tup = parse_args(tup, descs);
} }
mgb_assert(tup.size() == 6); mgb_assert(tup.size() == 7);
if (auto* t = try_cast(tup[0].ptr())) { if (auto* t = try_cast(tup[0].ptr())) {
m_tensor = t->m_tensor->copy(); m_tensor = t->m_tensor->copy();
} else { } else {
auto data = tup[0]; auto data = tup[0];
DType dtype = tup[1].cast<DType>(); DType dtype = tup[1].cast<DType>();
CompNode cn = as_comp_node(tup[2]);
bool is_const = tup[3].cast<bool>(); bool is_const = tup[3].cast<bool>();
bool no_cache = tup[4].cast<bool>(); bool no_cache = tup[4].cast<bool>();
std::string name; std::string name;
if (!tup[5].is_none()) { if (!tup[5].is_none()) {
name = tup[5].cast<std::string>(); name = tup[5].cast<std::string>();
} }
CompNode cn = as_comp_node(tup[2]); Format format;
if (!tup[6].is_none()) {
format = tup[6].cast<std::string>();
}
{ {
CreateTensor::Kind kind = is_const ? CreateTensor::Const CreateTensor::Kind kind = is_const ? CreateTensor::Const
...@@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
} else { } else {
auto&& hval = pyobj2hval(data, cn, dtype); auto&& hval = pyobj2hval(data, cn, dtype);
val = imperative::apply( val = imperative::apply(
CreateTensor(kind, cn, hval.dtype, hval.shape), CreateTensor(kind, cn, hval.dtype, hval.shape, format),
hval.storage)[0]; hval.storage)[0];
} }
m_tensor.emplace(val); m_tensor.emplace(val);
...@@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() { ...@@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() {
return py::cast(m_tensor->comp_node()).release().ptr(); return py::cast(m_tensor->comp_node()).release().ptr();
} }
PyObject* TensorWrapper::format() {
return py::cast(m_tensor->format().to_string()).release().ptr();
}
PyObject* TensorWrapper::numpy() { PyObject* TensorWrapper::numpy() {
auto hv = m_tensor->numpy(); auto hv = m_tensor->numpy();
if (!hv) { if (!hv) {
...@@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp); ...@@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp);
void init_tensor(py::module m) { void init_tensor(py::module m) {
imperative::Tensor::static_initialize(); imperative::Tensor::static_initialize();
// Transformations
static auto& transformations = TransformationManager::get_instance(); static auto& transformations = TransformationManager::get_instance();
using Segment = TransformationManager::Segment; using Segment = TransformationManager::Segment;
...@@ -755,6 +769,9 @@ void init_tensor(py::module m) { ...@@ -755,6 +769,9 @@ void init_tensor(py::module m) {
.register_at<Segment::DimExpansion>( .register_at<Segment::DimExpansion>(
std::make_shared<DimExpansionTransformation>()) std::make_shared<DimExpansionTransformation>())
.release()); .release());
auto format_trans = std::make_shared<FormatTransformation>();
MGB_MARK_USED_VAR(
transformations.register_at<Segment::Format>(format_trans).release());
static py::exception<interpreter::AsyncError> py_async_error( static py::exception<interpreter::AsyncError> py_async_error(
m, "AsyncError", PyExc_RuntimeError); m, "AsyncError", PyExc_RuntimeError);
...@@ -788,12 +805,14 @@ void init_tensor(py::module m) { ...@@ -788,12 +805,14 @@ void init_tensor(py::module m) {
} }
}); });
// Tensor
auto* tensor_type = auto* tensor_type =
TensorWrapper::wrap_t::type() TensorWrapper::wrap_t::type()
.def<&TensorWrapper::numpy>("numpy") .def<&TensorWrapper::numpy>("numpy")
.def_getset<&TensorWrapper::shape>("shape") .def_getset<&TensorWrapper::shape>("shape")
.def_getset<&TensorWrapper::dtype>("dtype") .def_getset<&TensorWrapper::dtype>("dtype")
.def_getset<&TensorWrapper::device>("device") .def_getset<&TensorWrapper::device>("device")
.def_getset<&TensorWrapper::format>("format")
.def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("_isscalar") .def<&TensorWrapper::isscalar>("_isscalar")
.def<&TensorWrapper::detach>("detach") .def<&TensorWrapper::detach>("detach")
...@@ -812,6 +831,11 @@ void init_tensor(py::module m) { ...@@ -812,6 +831,11 @@ void init_tensor(py::module m) {
if (!tensor_type) if (!tensor_type)
throw py::error_already_set(); throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type); py::setattr(m, "Tensor", tensor_type);
py::enum_<Format::Type>(m, "FormatType")
.value("DEFAULT", Format::Type::DEFAULT)
.value("NCHW", Format::Type::NCHW)
.value("NHWC", Format::Type::NHWC)
.export_values();
py::class_<TensorWeakRef>(m, "TensorWeakRef") py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>()) .def(py::init<const TensorWrapper&>())
...@@ -911,6 +935,7 @@ void init_tensor(py::module m) { ...@@ -911,6 +935,7 @@ void init_tensor(py::module m) {
sync_py_task_q(); sync_py_task_q();
}); });
// GradTransformation
py::handle grad_key_type = py::handle grad_key_type =
GradKeyWrapper::wrap_t::type() GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::attach>("attach")
...@@ -1203,6 +1228,7 @@ void init_tensor(py::module m) { ...@@ -1203,6 +1228,7 @@ void init_tensor(py::module m) {
return wrapped_outputs; return wrapped_outputs;
}); });
// ModuleTraceTransformation
static py::function module_trace_hook; static py::function module_trace_hook;
static auto get_module_trace = [] { static auto get_module_trace = [] {
...@@ -1309,6 +1335,12 @@ void init_tensor(py::module m) { ...@@ -1309,6 +1335,12 @@ void init_tensor(py::module m) {
m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });
// FormatTransformation
m.def("set_auto_format_convert",
[format_trans](bool enabled) { format_trans->set_auto_convert(enabled); });
m.def("get_auto_format_convert",
[format_trans]() { return format_trans->get_auto_convert(); });
py::register_exception<TraceError>(m, "TraceError"); py::register_exception<TraceError>(m, "TraceError");
} }
......
#pragma once #pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include <variant>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <variant>
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -57,6 +58,7 @@ public: ...@@ -57,6 +58,7 @@ public:
} }
return *shape; return *shape;
} }
inline Format format() { return *data().format(); }
inline HostValue::ref_t numpy() { return data().numpy(); } inline HostValue::ref_t numpy() { return data().numpy(); }
inline void reset(ValueRef value) { inline void reset(ValueRef value) {
m_data = value; m_data = value;
...@@ -116,6 +118,7 @@ public: ...@@ -116,6 +118,7 @@ public:
PyObject* shape(); PyObject* shape();
PyObject* dtype(); PyObject* dtype();
PyObject* device(); PyObject* device();
PyObject* format();
PyObject* numpy(); PyObject* numpy();
void reset(PyObject*); void reset(PyObject*);
PyObject* detach(); PyObject* detach();
......
...@@ -19,6 +19,7 @@ public: ...@@ -19,6 +19,7 @@ public:
DTypePromote, DTypePromote,
DimExpansion, DimExpansion,
Grad, Grad,
Format,
Scalar, Scalar,
Symbol, Symbol,
Trace, Trace,
......
...@@ -2,7 +2,7 @@ from megengine import amp ...@@ -2,7 +2,7 @@ from megengine import amp
from megengine.core.tensor import amp as origin_amp from megengine.core.tensor import amp as origin_amp
def test_grad_scaler(): def test_autocast():
def check(enabled, low, high): def check(enabled, low, high):
assert amp.enabled == enabled assert amp.enabled == enabled
assert origin_amp._enabled == enabled assert origin_amp._enabled == enabled
......
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine import tensor
from megengine.autodiff import GradManager
def test_basic():
a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
assert a.format == "nhwc"
b = tensor(a)
assert b.format == "nhwc"
# TODO: fix Tensor init bug for another Tensor
# c = tensor(a, format="nchw")
# assert c.format == "nchw"
def _compare_nchw_nhwc(data, func):
x1 = tensor(data, format="nchw")
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
out1 = func(x1)
with mge.config._override(auto_format_convert=True):
out2 = func(x2)
np.testing.assert_equal(out1, out2)
def test_dimshuffle():
def func(x):
out = F.transpose(x, [2, 3, 0, 1])
assert out.format == "default"
return out.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
def test_reshape():
# maintain NHWC format
def func(x):
out = F.reshape(x, (1, 2, 6, 2))
if x.format == "nhwc":
assert out.format == "nhwc"
return out.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
# not maintain NHWC format
def func2(x):
out = F.reshape(x, (1, 24))
assert out.format == "default"
return out.numpy()
_compare_nchw_nhwc(data, func2)
def test_flatten():
def func(x):
return F.flatten(x).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
def test_broadcast():
# maintain NHWC format
def func(x):
out = F.broadcast_to(x, (4, 3, 2, 3))
if x.format == "nhwc":
assert out.format == "nhwc"
return out.numpy()
data = np.arange(0, 24).reshape((4, 3, 2, 1))
_compare_nchw_nhwc(data, func)
# not maintain NHWC format
def func2(x):
out = F.broadcast_to(x, (3, 4, 3, 2, 1))
assert out.format == "default"
return out.numpy()
_compare_nchw_nhwc(data, func2)
@pytest.mark.skip("repeat cannot maintain format yet")
def test_repeat():
def func(x):
rst = F.repeat(x, 3, axis=1)
assert rst.format == x.format
return rst.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
def test_getshape():
def func(x):
return x.shape
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
@pytest.mark.skip("symbolic shape is not supported yet")
def test_get_symbolic_shape():
from megengine.core._trace_option import set_symbolic_shape
origin_opt = set_symbolic_shape(True)
def func(x):
return x.shape.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
set_symbolic_shape(origin_opt)
def test_getvalue():
def func(x):
return x.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
def test_get_set_subtensor():
def get_subtensor(x):
return x[:, :1, :2, :3].numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, get_subtensor)
def set_subtensor(x):
x[:, :1, :2, :3] = 0
return x.numpy()
_compare_nchw_nhwc(data, set_subtensor)
def test_get_set_advanced_indexing():
def get_advanced_indexing(x):
x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy()
return x
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, get_advanced_indexing)
def set_advanced_indexing(x):
x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0
return x.numpy()
_compare_nchw_nhwc(data, set_advanced_indexing)
def test_typecvt():
def typecvt(x):
return x.astype("float16").numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, typecvt)
def test_elemwise():
def elemwise(x):
return (x * 2 + x / 2).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, elemwise)
def test_concat():
def func(x):
rst = F.concat([x / 2, x * 2], axis=1)
assert rst.format == x.format
return rst.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
@pytest.mark.parametrize(
"mode", ["bilinear", "nearest"],
)
def test_interpolate(mode):
def func(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
rst = F.vision.interpolate(x, scale_factor=3, mode=mode)
assert rst.format == "nhwc"
return rst.numpy()
else:
return F.vision.interpolate(x, scale_factor=3, mode=mode).numpy()
# NHWC interpolate only suppoted channel is 1 or 3
data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
_compare_nchw_nhwc(data, func)
def test_conv2d():
def conv2d(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = F.conv2d(
x,
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"),
)
assert x.format == "nhwc"
return x.numpy()
else:
return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, conv2d)
def test_group_conv2d():
def conv2d(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = F.conv2d(
x,
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"),
groups=2,
)
assert x.format == "nhwc"
return x.numpy()
else:
return F.conv2d(
x, F.ones((2, 2, 2, 1, 1)), F.ones((1, 4, 1, 1)), groups=2
).numpy()
data = np.arange(0, 48).reshape((1, 4, 3, 4))
_compare_nchw_nhwc(data, conv2d)
def test_bn():
def func(x):
if x.format == "nhwc":
with mge.config._override(bn_format="dim_111c"):
oups = F.batch_norm(
x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
training=True,
inplace=False,
)
assert oups[0].format == "nhwc", "y's format is wrong"
assert oups[1].format == "nhwc", "running_mean's format is wrong"
assert oups[2].format == "nhwc", "running_var's format is wrong"
return oups[0].numpy()
else:
return F.batch_norm(
x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 2, 1, 1))),
running_var=mge.tensor(np.ones((1, 2, 1, 1))),
weight=mge.tensor(np.ones((1, 2, 1, 1))),
bias=mge.tensor(np.ones((1, 2, 1, 1))),
training=True,
inplace=False,
)[0].numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
@pytest.mark.parametrize(
"pooling",
[F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d],
)
def test_pooling2d(pooling):
def func(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = pooling(x.astype("float32"), 2)
assert x.format == "nhwc"
return x.numpy()
else:
return pooling(x.astype("float32"), 2).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
def test_backward():
data = np.arange(0, 24).reshape((1, 2, 3, 4))
x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
gm = GradManager().attach([w, b])
with gm:
with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
x = F.conv2d(x, w, b)
gm.backward(x)
# TODO: backward grad has no format yet
np.testing.assert_equal(
w.grad.numpy(),
np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
)
np.testing.assert_equal(
b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
)
...@@ -33,14 +33,20 @@ std::string GetAttr::to_string() const { ...@@ -33,14 +33,20 @@ std::string GetAttr::to_string() const {
return ssprintf("GetAttr{attr=%s}", attr_name); return ssprintf("GetAttr{attr=%s}", attr_name);
} }
CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape) CreateTensor::CreateTensor(
: m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {} Kind kind, CompNode device, DType dtype, ValueShape shape, Format format)
: m_kind(kind),
m_device(device),
m_dtype(dtype),
m_shape(shape),
m_format(format) {}
CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
: m_kind(kind), : m_kind(kind),
m_device(device), m_device(device),
m_dtype(layout.dtype), m_dtype(layout.dtype),
m_shape(ValueShape::from(layout)) { m_shape(ValueShape::from(layout)),
m_format(Format::Type::DEFAULT) {
mgb_assert( mgb_assert(
layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
} }
...@@ -74,8 +80,9 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args { ...@@ -74,8 +80,9 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
std::string CreateTensor::to_string() const { std::string CreateTensor::to_string() const {
return ssprintf( return ssprintf(
"CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind, "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}",
m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str()); (int)m_kind, m_device.to_string().c_str(), m_dtype.name(),
m_shape.to_string().c_str(), m_format.to_string().c_str());
} }
std::string DTRCommand::to_string() const { std::string DTRCommand::to_string() const {
......
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb {
namespace imperative {
using FT = Format::Type;
TypedValueRef<FormattedTensorValue> FormattedTensorValue::as(const FT& target) const {
return FormattedTensorValue::make(m_value, target);
}
TypedValueRef<FormattedTensorValue> FormattedTensorValue::to(
const FT& target, const std::string& scope) const {
std::vector<int32_t> pattern;
if (m_format == FT::NHWC && target == FT::NCHW) {
pattern = {0, 3, 1, 2};
} else if (m_format == FT::NCHW && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
} else {
mgb_throw(
MegBrainError, "Unsupport format conversion from %s to %s",
m_format.to_string().c_str(), Format(target).to_string().c_str());
}
auto output = imperative::apply(
*Dimshuffle::make(pattern, scope), std::vector<ValueRef>{m_value})[0];
return FormattedTensorValue::make(output, target);
}
namespace {
ValueRef unwrap_input(const ValueRef& input) {
if (auto format_input = input.as_ref<FormattedTensorValue>()) {
return format_input->value();
} else {
return input;
}
}
std::vector<ValueRef> unwrap_inputs(const Span<ValueRef>& inputs) {
std::vector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
unwrapped_inputs.push_back(unwrap_input(input));
}
return unwrapped_inputs;
}
std::vector<ValueRef> wrap_outputs(
const std::vector<ValueRef>& outputs, FT type = FT::DEFAULT) {
std::vector<ValueRef> wrapped_outputs;
for (auto&& output : outputs) {
wrapped_outputs.push_back(FormattedTensorValue::make(output, type));
}
return wrapped_outputs;
}
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
mgb_assert(shape.ndim == 4);
auto out = ValueShape(shape);
out[3] = shape[2];
out[2] = shape[1];
out[1] = shape[3];
return out;
}
using FormatRule = std::function<std::vector<ValueRef>(
const OpDef&, Span<ValueRef>&, const bool&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;
template <typename T>
void register_format_rule(
std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>&, const bool&)) {
format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs,
const bool& auto_convert) {
return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert);
};
}
auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
mgb_assert(pattern.size() == 4);
auto nhwc_pattern = pattern;
for (size_t idx = 0; idx < 4; ++idx) {
auto dim = pattern[idx];
if (dim == 1) {
nhwc_pattern[idx] = 3;
} else if (dim == 2) {
nhwc_pattern[idx] = 1;
} else if (dim == 3) {
nhwc_pattern[idx] = 2;
}
}
return nhwc_pattern;
}
std::vector<ValueRef> dimshuffle_rule(
const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert) {
mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
// Only support converting pattern from NCHW to NHWC currently.
if (auto_convert && src.format() == FT::NHWC) {
auto pattern = convert_nchw2nhwc_pattern(op.pattern);
// dimshuffle will not maintain NHWC Format
return wrap_outputs(imperative::apply(
*Dimshuffle::make(std::move(pattern), op.scope()),
unwrap_inputs(inputs)));
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
mgb_assert(shape.layout().total_nr_elems() == 4);
auto* old_ptr = shape.ptr<dt_int32>();
auto cn = shape.comp_node();
auto layout = shape.layout();
auto nhwc_shape = HostTensorND(cn, layout);
auto* new_ptr = nhwc_shape.ptr<dt_int32>();
new_ptr[0] = old_ptr[0];
new_ptr[1] = old_ptr[2];
new_ptr[2] = old_ptr[3];
new_ptr[3] = old_ptr[1];
auto hv = HostStorage::make(nhwc_shape.storage());
auto nhwc_shape_input =
imperative::apply(CreateTensor(CreateTensor::Const, cn, layout), hv)[0];
return nhwc_shape_input;
}
std::vector<ValueRef> reshape_rule(
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert) {
mgb_assert(inputs.size() == 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
if (auto_convert && src.format() == FT::NHWC) {
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
return wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = src.to(FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
return wrap_outputs(outputs);
}
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
std::vector<ValueRef> broadcast_rule(
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert) {
mgb_assert(inputs.size() == 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
if (auto_convert && src.format() == FT::NHWC) {
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
return wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = src.to(FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
return wrap_outputs(outputs);
}
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
bool is_reduce_ndim_idx_items(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
const Span<ValueRef>& inputs) {
for (auto i = 0; i < items.size(); ++i) {
auto&& [axis, begin, end, step, idx] = items[i];
if (idx) {
// if inputs[i] contains more than one value, ndim will not be reduced.
return inputs[i].is_scalar();
}
}
return false;
}
auto convert_nchw2nhwc_idx_items(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
auto nhwc_items = items;
for (auto i = 0; i < nhwc_items.size(); ++i) {
auto&& [axis, begin, end, step, idx] = nhwc_items[i];
if (axis == 2 || axis == 3) {
nhwc_items[i] = {axis - 1, begin, end, step, idx};
} else if (axis == 1) {
nhwc_items[i] = {3, begin, end, step, idx};
}
}
return nhwc_items;
}
template <typename T>
std::vector<ValueRef> subtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
bool is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[1], &inputs[inputs.size() - 1]});
if (!is_reduce_ndim) {
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {FormattedTensorValue::make(
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
}
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
*T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs));
return wrap_outputs(outputs, FT::NHWC);
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
template <typename T>
std::vector<ValueRef> setsubtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
mgb_assert(inputs.size() >= 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
bool is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[2], &inputs[inputs.size() - 1]});
if (!is_reduce_ndim) {
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {FormattedTensorValue::make(
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
}
// value has been broadcasted to src's fake NCHW shape.
auto& value = inputs[1].cast<FormattedTensorValue>();
auto& format = value.format();
auto nhwc_inputs = std::vector<ValueRef>(inputs.size());
if (format == FT::DEFAULT || format == FT::NCHW) {
// value for setsubtensor should transpose to match shape.
auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC);
// make new inputs for setsubtensor
nhwc_inputs[0] = src.value();
nhwc_inputs[1] = nhwc_value->value();
for (auto i = 2; i < inputs.size(); ++i) {
nhwc_inputs[i] = inputs[i].as_ref<FormattedTensorValue>()->value();
}
} else if (format != FT::NHWC) {
mgb_throw(
MegBrainError, "Unsupported format(%s) of value for setsubtensor.",
format.to_string().c_str());
}
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
*T::make(std::move(nhwc_items), op.scope()), nhwc_inputs);
return wrap_outputs(outputs, FT::NHWC);
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
FT get_inputs_format(Span<ValueRef>& inputs) {
FT format(FT::DEFAULT);
for (auto& inp : inputs) {
auto& inp_format = inp.cast<FormattedTensorValue>().format();
if (inp_format != FT::DEFAULT) {
mgb_assert(format == FT::DEFAULT || inp_format == format);
format = inp_format.type();
}
}
return format;
}
std::vector<ValueRef> concat_rule(
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert) {
FT format = get_inputs_format(inputs);
if (!(format == FT::NHWC && auto_convert)) {
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
}
// TODO: handle 5D NHWC Tensor from group conv
auto axis = op.axis;
if (axis == 2 || axis == 3) {
axis = axis - 1;
} else if (axis == 1) {
axis = 3;
}
return wrap_outputs(
imperative::apply(
*Concat::make(axis, op.comp_node, op.scope()),
unwrap_inputs(inputs)),
format);
}
std::vector<ValueRef> elemwise_rule(
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert) {
FT format = get_inputs_format(inputs);
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
}
std::vector<ValueRef> identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs) {
// mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), src.format().type());
}
// clang-format off
#define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Pooling) \
cb(AdaptivePooling) \
cb(Dropout) \
cb(Convolution) \
cb(BatchNorm) \
cb(Resize) \
cb(Identity)
// clang-format on
#define CREATE_IDENTITY_OP_RULE(op) \
std::vector<ValueRef> op##_rule( \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \
return identity_rule_helper(_op, inputs); \
}
FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
#undef CREATE_IDENTITY_OP_RULE
#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule);
struct FormatRuleRegistry {
FormatRuleRegistry() {
register_format_rule(dimshuffle_rule);
register_format_rule(reshape_rule);
register_format_rule(broadcast_rule);
register_format_rule(subtensor_rule<Subtensor>);
register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule);
register_format_rule(elemwise_rule);
FOREACH_IDENTITY_OP(REGISTER_IDENTITY_OP_RULE)
}
} _;
#undef REGISTER_IDENTITY_OP_RULE
} // namespace
std::vector<ValueRef> FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
if (iter != format_rules.end()) {
return iter->second(apply_op->op(), inputs, m_auto_convert);
} else {
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto format = create_tensor->format();
return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)};
} else if (auto* get_attr = op.as<GetAttr>()) {
auto* src = inputs.as_array<1>()[0].as<FormattedTensorValue>();
if (!m_auto_convert || !src || src->format() != FT::NHWC) {
return imperative::apply(op, unwrap_inputs(inputs));
}
switch (get_attr->attr()) {
case GetAttr::Shape: {
auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
auto shape = convert_nhwc2nchw_shape(output.cast<ShapeValue>());
return {ShapeValue::make(shape)};
}
case GetAttr::Value: {
auto nchw_src = unwrap_input(src->to(FT::NCHW, ""));
return imperative::apply(op, std::vector<ValueRef>{nchw_src});
}
default:
return imperative::apply(op, unwrap_inputs(inputs));
}
} else if (op.is<GetFormat>()) {
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
if (is_formatted_tensor) {
return {FormatValue::make(inputs[0].cast<FormattedTensorValue>().format())};
} else {
mgb_log_warn(
"Not FormattedTensorValue input for GetFormat op: %s",
inputs[0].to_string().c_str());
return {FormatValue::make(FT::DEFAULT)};
}
} else if (op.is<Operator::IdentityLike>()) {
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
if (is_formatted_tensor) {
auto& format = inputs[0].cast<FormattedTensorValue>().format();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), format.type());
} else {
mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s",
inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else {
return imperative::apply(op, unwrap_inputs(inputs));
}
};
} // namespace imperative
} // namespace mgb
...@@ -58,6 +58,10 @@ TypedValueRef<DTypeValue> ValueRef::dtype() const { ...@@ -58,6 +58,10 @@ TypedValueRef<DTypeValue> ValueRef::dtype() const {
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>(); return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>();
} }
TypedValueRef<FormatValue> ValueRef::format() const {
return imperative::apply(GetFormat(), *this)[0].as_ref<FormatValue>();
}
TypedValueRef<StringValue> ValueRef::name() const { TypedValueRef<StringValue> ValueRef::name() const {
return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>(); return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>();
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/operator.h" #include "megbrain/imperative/operator.h"
#include "megbrain/imperative/utils/data_format.h"
#include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h" #include "megbrain/imperative/utils/value_shape.h"
...@@ -82,9 +83,12 @@ private: ...@@ -82,9 +83,12 @@ private:
CompNode m_device; CompNode m_device;
DType m_dtype; DType m_dtype;
ValueShape m_shape; ValueShape m_shape;
Format m_format;
public: public:
CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape); CreateTensor(
Kind kind, CompNode device, DType dtype, ValueShape shape,
Format format = Format::Type::DEFAULT);
CreateTensor(Kind kind, CompNode device, TensorLayout layout); CreateTensor(Kind kind, CompNode device, TensorLayout layout);
/** /**
...@@ -99,6 +103,7 @@ public: ...@@ -99,6 +103,7 @@ public:
CompNode device() const { return m_device; } CompNode device() const { return m_device; }
DType dtype() const { return m_dtype; } DType dtype() const { return m_dtype; }
ValueShape shape() const { return m_shape; } ValueShape shape() const { return m_shape; }
Format format() const { return m_format; }
std::string to_string() const override; std::string to_string() const override;
}; };
...@@ -157,6 +162,11 @@ public: ...@@ -157,6 +162,11 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
public:
std::string to_string() const override { return "GetFormat{}"; }
};
class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> { class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
public: public:
std::string to_string() const override; std::string to_string() const override;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <future> #include <future>
#include <iomanip> #include <iomanip>
#include "megbrain/imperative/utils/data_format.h"
#include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h" #include "megbrain/imperative/utils/value_shape.h"
#include "megbrain/imperative/value.h" #include "megbrain/imperative/value.h"
...@@ -148,6 +149,13 @@ public: ...@@ -148,6 +149,13 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
class FormatValue final : public PrimitiveValue<FormatValue, Format> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override { return Format::to_string(); }
};
class StringValue final : public PrimitiveValue<StringValue, std::string> { class StringValue final : public PrimitiveValue<StringValue, std::string> {
public: public:
using PrimitiveValue::PrimitiveValue; using PrimitiveValue::PrimitiveValue;
......
#pragma once
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/data_format.h"
namespace mgb::imperative {
class FormattedTensorValue final : public ValueImpl<FormattedTensorValue> {
private:
ValueRef m_value;
Format m_format;
public:
FormattedTensorValue(ValueRef value, Format format)
: m_value(value), m_format(format) {}
std::string to_string() const override {
return ssprintf(
"FormattedTensorValue{value=%s, format=%s}",
m_value.to_string().c_str(), m_format.to_string().c_str());
}
ValueRef value() const { return m_value; }
const Format& format() const { return m_format; }
TypedValueRef<FormattedTensorValue> as(const Format::Type& target) const;
TypedValueRef<FormattedTensorValue> to(
const Format::Type& target, const std::string& scope = "") const;
void clear() override {
m_value = {};
m_format = {};
}
void on_watch() override { m_value.watch(); }
void on_unwatch() override { m_value.unwatch(); }
};
/**
* \brief simulates scalar because megbrain graph system don't support scalar
*
* Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'.
* This transformation simulates scalars with a flag. If a value is ScalarValue, it is
* scalar, vice versa. So there is not scalar down this layer.
*/
class FormatTransformation final : public Transformation {
private:
bool m_auto_convert = false;
public:
std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<FormattedTensorValue>());
return value;
}
std::string name() const override {
return ssprintf("FormatTransformation{auto_convert=%d}", m_auto_convert);
}
void set_auto_convert(bool enabled) { m_auto_convert = enabled; }
bool get_auto_convert() const { return m_auto_convert; }
};
} // namespace mgb::imperative
#pragma once
#include "megbrain/tensor.h"
namespace mgb::imperative {
/**
* \brief like TensorFormats, but only including common formats and DEFAULT.
*
*/
class Format {
public:
enum class Type {
DEFAULT = 0,
NCHW = 1, ///< [N, C, H, W]
NHWC = 2, ///< [N, H, W, C]
};
std::string to_string() const {
switch (m_type) {
case Type::DEFAULT:
return "default";
case Type::NCHW:
return "nchw";
case Type::NHWC:
return "nhwc";
default:
mgb_throw(MegBrainError, "bad format type");
}
}
Format() : m_type(Type::DEFAULT) {}
Format(std::string str) {
if (str == "default") {
m_type = Type::DEFAULT;
} else if (str == "nchw") {
m_type = Type::NCHW;
} else if (str == "nhwc") {
m_type = Type::NHWC;
} else {
mgb_throw(
MegBrainError,
"Invalid format type."
" Only support \"default\", \"nchw\" and \"nhwc\"");
}
}
Format(Type type) : m_type(type) {}
Type type() const { return m_type; }
bool operator==(const Format& b) const { return m_type == b.type(); }
bool operator==(const Format::Type& b) const { return m_type == b; }
bool operator!=(const Format& b) const { return m_type != b.type(); }
bool operator!=(const Format::Type& b) const { return m_type != b; }
private:
Type m_type = Type::DEFAULT;
};
} // namespace mgb::imperative
...@@ -31,6 +31,7 @@ class HostValue; ...@@ -31,6 +31,7 @@ class HostValue;
class DeviceValue; class DeviceValue;
class ShapeValue; class ShapeValue;
class DTypeValue; class DTypeValue;
class FormatValue;
class CompNodeValue; class CompNodeValue;
class StringValue; class StringValue;
class NodeValue; class NodeValue;
...@@ -219,6 +220,7 @@ public: ...@@ -219,6 +220,7 @@ public:
TypedValueRef<CompNodeValue> device() const; TypedValueRef<CompNodeValue> device() const;
TypedValueRef<ShapeValue> shape() const; TypedValueRef<ShapeValue> shape() const;
TypedValueRef<DTypeValue> dtype() const; TypedValueRef<DTypeValue> dtype() const;
TypedValueRef<FormatValue> format() const;
TypedValueRef<StringValue> name() const; TypedValueRef<StringValue> name() const;
bool is_scalar() const; bool is_scalar() const;
...@@ -431,9 +433,11 @@ inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type) ...@@ -431,9 +433,11 @@ inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type)
inline void ValueRef::on_cast_failure(const IType& type) const { inline void ValueRef::on_cast_failure(const IType& type) const {
// if this is ErrorValue, rethrow directly // if this is ErrorValue, rethrow directly
storage()->try_rethrow(); storage()->try_rethrow();
mgb_assert( if (storage()->type() != type) {
storage()->type() != type, "expect type %s, got %s", type.name().c_str(), mgb_throw(
to_string().c_str()); MegBrainError, "Unable to cast ValueRef: expect type %s, got %s",
type.name().c_str(), to_string().c_str());
}
} }
/** /**
......
...@@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape( ...@@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape(
bias_c = inp_shape[2][channel_idx]; bias_c = inp_shape[2][channel_idx];
mgb_assert( mgb_assert(
inp_c == scale_c && inp_c == bias_c, inp_c == scale_c && inp_c == bias_c,
"inconsistent channel size, input chennel: %zu, scale channel: %zu, bias " "inconsistent channel size, input channel: %zu, scale channel: %zu, bias "
"channel: %zu", "channel: %zu",
inp_c, scale_c, bias_c); inp_c, scale_c, bias_c);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册