diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 40852d7ae847e1fa59e011925fa18cace9452fb1..83b6b40ca56e25264b353e30c025e7153a3102ce 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -1017,10 +1017,14 @@ void init_tensor(py::module m) {
             using namespace std::placeholders;
             self.compiled = std::make_shared<CompiledTransformation>(
                     *self.trace_result, self.record_input_shapes);
-            self.compiled->set_value_comparator(
-                    std::bind(&Trace::compare_value, this, _1, _2));
-            self.options_visitor(py::cast(&self.compiled->options()));
-            self.compiled->compile();
+            self.compiled->set_value_comparator(
+                    std::bind(&Trace::compare_value, this, _1, _2));
+            self.options_visitor(py::cast(&self.compiled->options()));
+            try {
+                self.compiled->compile();
+            } catch (const std::exception& e) {
+                mgb_log_error(e.what());
+            }
         }
         // register transformations
         if (self.compiled) {
diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index 6c71fe0a962ffc1b0192ae2c2ae6a2942c0eba18..f6d9bdc411070574de6fc31485ce456d583e3e39 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -5,6 +5,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.jit import trace
 
 
 def test_basic():
@@ -30,26 +31,29 @@ def test_basic():
     assert b.format == "nchw"
 
 
-def _compare_nchw_nhwc(data, func):
+def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data, format="nchw")
     x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
+    if is_symbolic is not None:
+        func = trace(func, symbolic=is_symbolic)
     out1 = func(x1)
-    with mge.config._override(auto_format_convert=True):
-        out2 = func(x2)
+    out2 = func(x2)
     np.testing.assert_almost_equal(out1, out2, decimal=5)
 
 
-def test_dimshuffle():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_dimshuffle(is_symbolic):
     def func(x):
         out = F.transpose(x, [2, 3, 0, 1])
         assert out.format == "default"
         return out.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_reshape():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_reshape(is_symbolic):
     # maintain NHWC format
     def func(x):
         out = F.reshape(x, (1, 2, 6, 2))
@@ -58,7 +62,7 @@ def test_reshape():
         return out.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
     # not maintain NHWC format
     def func2(x):
@@ -66,18 +70,20 @@ def test_reshape():
         assert out.format == "default"
         return out.numpy()
 
-    _compare_nchw_nhwc(data, func2)
+    _compare_nchw_nhwc(data, func2, is_symbolic)
 
 
-def test_flatten():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_flatten(is_symbolic):
     def func(x):
         return F.flatten(x).numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_broadcast():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_broadcast(is_symbolic):
     # maintain NHWC format
     def func(x):
         out = F.broadcast_to(x, (4, 3, 2, 3))
@@ -86,7 +92,7 @@ def test_broadcast():
         return out.numpy()
 
     data = np.arange(0, 24).reshape((4, 3, 2, 1))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
     # not maintain NHWC format
     def func2(x):
@@ -94,30 +100,32 @@ def test_broadcast():
         assert out.format == "default"
         return out.numpy()
 
-    _compare_nchw_nhwc(data, func2)
+    _compare_nchw_nhwc(data, func2, is_symbolic)
 
 
 @pytest.mark.skip("repeat cannot maintain format yet")
-def test_repeat():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_repeat(is_symbolic):
     def func(x):
         rst = F.repeat(x, 3, axis=1)
         assert rst.format == x.format
         return rst.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_getshape():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_getshape(is_symbolic):
     def func(x):
         return x.shape
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
 @pytest.mark.skip("symbolic shape is not supported yet")
-def test_get_symbolic_shape():
+def test_get_symbolic_shape(is_symbolic):
     from megengine.core._trace_option import set_symbolic_shape
 
     origin_opt = set_symbolic_shape(True)
@@ -126,77 +134,84 @@ def test_get_symbolic_shape():
         return x.shape.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
     set_symbolic_shape(origin_opt)
 
 
-def test_getvalue():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_getvalue(is_symbolic):
     def func(x):
         return x.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_get_set_subtensor():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_get_set_subtensor(is_symbolic):
    def get_subtensor(x):
        return x[:, :1, :2, :3].numpy()

    data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, get_subtensor)
+    _compare_nchw_nhwc(data, get_subtensor, is_symbolic)
 
     def set_subtensor(x):
         x[:, :1, :2, :3] = 0
         return x.numpy()
 
-    _compare_nchw_nhwc(data, set_subtensor)
+    _compare_nchw_nhwc(data, set_subtensor, is_symbolic)
 
 
-def test_get_set_advanced_indexing():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_get_set_advanced_indexing(is_symbolic):
     def get_advanced_indexing(x):
         x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy()
         return x
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, get_advanced_indexing)
+    _compare_nchw_nhwc(data, get_advanced_indexing, is_symbolic)
 
     def set_advanced_indexing(x):
         x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0
         return x.numpy()
 
-    _compare_nchw_nhwc(data, set_advanced_indexing)
+    _compare_nchw_nhwc(data, set_advanced_indexing, is_symbolic)
 
 
-def test_typecvt():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_typecvt(is_symbolic):
     def typecvt(x):
         return x.astype("float16").numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, typecvt)
+    _compare_nchw_nhwc(data, typecvt, is_symbolic)
 
 
-def test_elemwise():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_elemwise(is_symbolic):
     def elemwise(x):
         return (x * 2 + x / 2).numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, elemwise)
+    _compare_nchw_nhwc(data, elemwise, is_symbolic)
 
 
-def test_concat():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_concat(is_symbolic):
     def func(x):
         rst = F.concat([x / 2, x * 2], axis=1)
         assert rst.format == x.format
         return rst.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
 @pytest.mark.parametrize(
     "mode", ["bilinear", "nearest"],
 )
-def test_interpolate(mode):
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_interpolate(mode, is_symbolic):
     def func(x):
         if x.format == "nhwc":
             with mge.config._override(conv_format="NHWC"):
@@ -208,10 +223,11 @@ def test_interpolate(mode):
 
     # NHWC interpolate only suppoted channel is 1 or 3
     data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_conv2d():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_conv2d(is_symbolic):
     def conv2d(x):
         if x.format == "nhwc":
             with mge.config._override(conv_format="NHWC"):
@@ -226,10 +242,11 @@ def test_conv2d():
         return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, conv2d)
+    _compare_nchw_nhwc(data, conv2d, is_symbolic)
 
 
-def test_group_conv2d():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_group_conv2d(is_symbolic):
     def conv2d(x):
         if x.format == "nhwc":
             with mge.config._override(conv_format="NHWC"):
@@ -247,10 +264,11 @@ def test_group_conv2d():
         ).numpy()
 
     data = np.arange(0, 48).reshape((1, 4, 3, 4))
-    _compare_nchw_nhwc(data, conv2d)
+    _compare_nchw_nhwc(data, conv2d, is_symbolic)
 
 
-def test_bn():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_bn(is_symbolic):
     def func(x):
         if x.format == "nhwc":
             with mge.config._override(bn_format="dim_111c"):
@@ -279,14 +297,15 @@ def test_bn():
             )[0].numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
 @pytest.mark.parametrize(
     "pooling",
     [F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d],
 )
-def test_pooling2d(pooling):
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_pooling2d(pooling, is_symbolic):
     def func(x):
         if x.format == "nhwc":
             with mge.config._override(conv_format="NHWC"):
@@ -297,18 +316,25 @@ def test_pooling2d(pooling):
         return pooling(x.astype("float32"), 2).numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    _compare_nchw_nhwc(data, func)
+    _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-def test_backward():
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward(is_symbolic):
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
     x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
     w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
     b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
     gm = GradManager().attach([w, b])
+
+    def func(x, w, b):
+        return F.conv2d(x, w, b)
+
     with gm:
         with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
-            x = F.conv2d(x, w, b)
+            if is_symbolic is not None:
+                func = trace(func, symbolic=is_symbolic)
+            x = func(x, w, b)
             # TODO: fix manually convert to NHWC, usually used in detection head
             # x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
             gm.backward(x)
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index fa431665ef101c2a7782994cfe115fdafc0d46ab..bdb4a77aea8f510ce137fed2b7c08963e355efc6 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -7,53 +7,62 @@ namespace imperative {
 
 using FT = Format::Type;
 
-TypedValueRef<FormattedTensorValue> FormattedTensorValue::as(const FT& target) const {
-    return FormattedTensorValue::make(m_value, target);
+TypedValueRef<FormattedTensorValue> FormatTransformation::as(
+        const FormattedTensorValue& tensor, const FT& target) const {
+    return m_value_type.make(tensor.value(), target);
 }
 
-TypedValueRef<FormattedTensorValue> FormattedTensorValue::to(
-        const FT& target, const std::string& scope) const {
+TypedValueRef<FormattedTensorValue> FormatTransformation::to(
+        const FormattedTensorValue& tensor, const FT& target,
+        const std::string& scope) const {
     std::vector<int32_t> pattern;
-    if (m_format == FT::NHWC && target == FT::NCHW) {
+    if (tensor.format() == FT::NHWC && target == FT::NCHW) {
         pattern = {0, 3, 1, 2};
-    } else if (m_format == FT::NCHW && target == FT::NHWC) {
+    } else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
         pattern = {0, 2, 3, 1};
     } else {
         mgb_throw(
                 MegBrainError, "Unsupport format conversion from %s to %s",
-                m_format.to_string().c_str(), Format(target).to_string().c_str());
+                tensor.format().to_string().c_str(),
+                Format(target).to_string().c_str());
     }
     auto output = imperative::apply(
-            *Dimshuffle::make(pattern, scope), std::vector<ValueRef>{m_value})[0];
-    return FormattedTensorValue::make(output, target);
+            *Dimshuffle::make(pattern, scope),
+            SmallVector<ValueRef>{tensor.value()})[0];
+    return m_value_type.make(output, target);
 }
 
-namespace {
-
-ValueRef unwrap_input(const ValueRef& input) {
-    if (auto format_input = input.as_ref<FormattedTensorValue>()) {
+inline ValueRef FormatTransformation::unwrap_input(const ValueRef& input) const {
+    if (auto format_input = input.as_ref(m_value_type)) {
         return format_input->value();
     } else {
         return input;
     }
 }
 
-std::vector<ValueRef> unwrap_inputs(const Span<ValueRef>& inputs) {
-    std::vector<ValueRef> unwrapped_inputs;
-    for (auto&& input : inputs) {
-        unwrapped_inputs.push_back(unwrap_input(input));
+inline ValueRefList FormatTransformation::unwrap_inputs(
+        const Span<ValueRef>& inputs) const {
+    ValueRefList unwrapped_inputs(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        unwrapped_inputs[i] = unwrap_input(inputs[i]);
     }
     return unwrapped_inputs;
 }
 
-std::vector<ValueRef> wrap_outputs(
-        const std::vector<ValueRef>& outputs, FT type = FT::DEFAULT) {
-    std::vector<ValueRef> wrapped_outputs;
-    for (auto&& output : outputs) {
-        wrapped_outputs.push_back(FormattedTensorValue::make(output, type));
+inline ValueRef FormatTransformation::wrap_output(
+        const ValueRef& output, FT type) const {
+    return m_value_type.make(output, type);
+}
+
+inline ValueRefList FormatTransformation::wrap_outputs(
+        const ValueRefList& outputs, FT type) const {
+    ValueRefList wrapped_outputs(outputs.size());
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        wrapped_outputs[i] = wrap_output(outputs[i], type);
     }
     return wrapped_outputs;
 }
+namespace {
 
 ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
     mgb_assert(shape.ndim == 4);
@@ -64,20 +73,21 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
     return out;
 }
 
-using FormatRule = std::function<std::vector<ValueRef>(
-        const OpDef&, Span<ValueRef>&, const bool&)>;
+using FormatRule = std::function<ValueRefList(
+        const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
 static std::unordered_map<Typeinfo*, FormatRule> format_rules;
 
 template <typename T>
-void register_format_rule(
-        std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>&, const bool&)) {
+void register_format_rule(ValueRefList (*rule)(
+        const T&, Span<ValueRef>&, const bool&, const FormatTransformation&)) {
     format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs,
-                                         const bool& auto_convert) {
-        return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert);
+                                         const bool& auto_convert,
+                                         const FormatTransformation& t) {
+        return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert, t);
     };
 }
 
-auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
+inline auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
     mgb_assert(pattern.size() == 4);
     auto nhwc_pattern = pattern;
     for (size_t idx = 0; idx < 4; ++idx) {
@@ -93,19 +103,20 @@ auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
     return nhwc_pattern;
 }
 
-std::vector<ValueRef> dimshuffle_rule(
-        const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert) {
+ValueRefList dimshuffle_rule(
+        const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
     mgb_assert(inputs.size() == 1);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
+    auto& src = inputs[0].cast(t.value_type());
     // Only support converting pattern from NCHW to NHWC currently.
     if (auto_convert && src.format() == FT::NHWC) {
         auto pattern = convert_nchw2nhwc_pattern(op.pattern);
         // dimshuffle will not maintain NHWC Format
-        return wrap_outputs(imperative::apply(
+        return t.wrap_outputs(imperative::apply(
                 *Dimshuffle::make(std::move(pattern), op.scope()),
-                unwrap_inputs(inputs)));
+                t.unwrap_inputs(inputs)));
     }
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
 }
 
 ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
@@ -125,53 +136,55 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
     return nhwc_shape_input;
 }
 
-std::vector<ValueRef> reshape_rule(
-        const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert) {
+ValueRefList reshape_rule(
+        const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
     mgb_assert(inputs.size() == 2);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
+    auto& src = inputs[0].cast(t.value_type());
     if (auto_convert && src.format() == FT::NHWC) {
-        auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
+        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
         if (shape.layout().total_nr_elems() == 4) {
             // output is still NHWC format
             auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
             auto outputs = imperative::apply(
-                    op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
-            return wrap_outputs(outputs, FT::NHWC);
+                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
+            return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = src.to(FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
-                    op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
-            return wrap_outputs(outputs);
+                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
+            return t.wrap_outputs(outputs);
         }
     }
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
 }
 
-std::vector<ValueRef> broadcast_rule(
-        const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert) {
+ValueRefList broadcast_rule(
+        const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
     mgb_assert(inputs.size() == 2);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
+    auto& src = inputs[0].cast(t.value_type());
     if (auto_convert && src.format() == FT::NHWC) {
-        auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
+        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
         if (shape.layout().total_nr_elems() == 4) {
             // output is still NHWC format
             auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
             auto outputs = imperative::apply(
-                    op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
-            return wrap_outputs(outputs, FT::NHWC);
+                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
+            return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = src.to(FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
-                    op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
-            return wrap_outputs(outputs);
+                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
+            return t.wrap_outputs(outputs);
         }
     }
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
 }
 
-bool is_reduce_ndim_idx_items(
+inline bool is_reduce_ndim_idx_items(
         const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
         const Span<ValueRef>& inputs) {
     for (auto i = 0; i < items.size(); ++i) {
@@ -184,7 +197,7 @@ bool is_reduce_ndim_idx_items(
     return false;
 }
 
-auto convert_nchw2nhwc_idx_items(
+inline auto convert_nchw2nhwc_idx_items(
         const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
     auto nhwc_items = items;
     for (auto i = 0; i < nhwc_items.size(); ++i) {
@@ -199,51 +212,55 @@ auto convert_nchw2nhwc_idx_items(
 }
 
 template <typename T>
-std::vector<ValueRef> subtensor_rule(
-        const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
+ValueRefList subtensor_rule(
+        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
     mgb_assert(inputs.size() >= 1);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
+    auto& src = inputs[0].cast(t.value_type());
     bool is_reduce_ndim = is_reduce_ndim_idx_items(
             op.items, {&inputs[1], &inputs[inputs.size() - 1]});
     if (!is_reduce_ndim) {
         // only support NHWC2NCHW convert, otherwise maintain src's format
         if (!(auto_convert && src.format() == FT::NHWC)) {
-            return {FormattedTensorValue::make(
-                    imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
+            return {t.wrap_output(
+                    imperative::apply(op, t.unwrap_inputs(inputs))[0],
+                    src.format().type())};
         }
         auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
         auto outputs = imperative::apply(
-                *T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs));
-        return wrap_outputs(outputs, FT::NHWC);
+                *T::make(std::move(nhwc_items), op.scope()), t.unwrap_inputs(inputs));
+        return t.wrap_outputs(outputs, FT::NHWC);
     }
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
 }
 
 template <typename T>
-std::vector<ValueRef> setsubtensor_rule(
-        const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
+ValueRefList setsubtensor_rule(
+        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
     mgb_assert(inputs.size() >= 2);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
+    auto& src = inputs[0].cast(t.value_type());
     bool is_reduce_ndim = is_reduce_ndim_idx_items(
            op.items, {&inputs[2], &inputs[inputs.size() - 1]});
     if (!is_reduce_ndim) {
         // only support NHWC2NCHW convert, otherwise maintain src's format
         if (!(auto_convert && src.format() == FT::NHWC)) {
-            return {FormattedTensorValue::make(
-                    imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
+            return {t.wrap_output(
+                    imperative::apply(op, t.unwrap_inputs(inputs))[0],
+                    src.format().type())};
        }
        // value has been broadcasted to src's fake NCHW shape.
-        auto& value = inputs[1].cast<FormattedTensorValue>();
+        auto& value = inputs[1].cast(t.value_type());
        auto& format = value.format();
-        auto nhwc_inputs = std::vector<ValueRef>(inputs.size());
+        auto nhwc_inputs = ValueRefList(inputs.size());
        if (format == FT::DEFAULT || format == FT::NCHW) {
            // value for setsubtensor should transpose to match shape.
-            auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC);
+            auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC);
            // make new inputs for setsubtensor
            nhwc_inputs[0] = src.value();
            nhwc_inputs[1] = nhwc_value->value();
            for (auto i = 2; i < inputs.size(); ++i) {
-                nhwc_inputs[i] = inputs[i].as_ref<FormattedTensorValue>()->value();
+                nhwc_inputs[i] = t.unwrap_input(inputs[i]);
            }
        } else if (format != FT::NHWC) {
            mgb_throw(
@@ -253,15 +270,15 @@ std::vector<ValueRef> setsubtensor_rule(
         auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
         auto outputs = imperative::apply(
                 *T::make(std::move(nhwc_items), op.scope()), nhwc_inputs);
-        return wrap_outputs(outputs, FT::NHWC);
+        return t.wrap_outputs(outputs, FT::NHWC);
     }
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
 }
 
-FT get_inputs_format(Span<ValueRef>& inputs) {
+inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
     FT format(FT::DEFAULT);
     for (auto& inp : inputs) {
-        auto& inp_format = inp.cast<FormattedTensorValue>().format();
+        auto& inp_format = inp.cast(t.value_type()).format();
         if (inp_format != FT::DEFAULT) {
             mgb_assert(format == FT::DEFAULT || inp_format == format);
             format = inp_format.type();
@@ -270,11 +287,12 @@ FT get_inputs_format(Span<ValueRef>& inputs) {
     return format;
 }
 
-std::vector<ValueRef> concat_rule(
-        const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert) {
-    FT format = get_inputs_format(inputs);
+ValueRefList concat_rule(
+        const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    FT format = get_inputs_format(inputs, t);
     if (!(format == FT::NHWC && auto_convert)) {
-        return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
+        return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
@@ -283,25 +301,26 @@ std::vector<ValueRef> concat_rule(
     } else if (axis == 1) {
         axis = 3;
     }
-    return wrap_outputs(
+    return t.wrap_outputs(
             imperative::apply(
                     *Concat::make(axis, op.comp_node, op.scope()),
-                    unwrap_inputs(inputs)),
+                    t.unwrap_inputs(inputs)),
             format);
 }
 
-std::vector<ValueRef> elemwise_rule(
-        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert) {
-    FT format = get_inputs_format(inputs);
-    return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
+ValueRefList elemwise_rule(
+        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    FT format = get_inputs_format(inputs, t);
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
 }
 
-std::vector<ValueRef> identity_rule_helper(
-        const OpDef& op, const Span<ValueRef>& inputs) {
+ValueRefList identity_rule_helper(
+        const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
     // mgb_assert(inputs.size() == 1);
-    auto& src = inputs[0].cast<FormattedTensorValue>();
-    return wrap_outputs(
-            imperative::apply(op, unwrap_inputs(inputs)), src.format().type());
+    auto& src = inputs[0].cast(t.value_type());
+    return t.wrap_outputs(
+            imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
 }
 
 // clang-format off
@@ -318,10 +337,11 @@ std::vector<ValueRef> identity_rule_helper(
     cb(Identity)
 // clang-format on
 
-#define CREATE_IDENTITY_OP_RULE(op)                                            \
-    std::vector<ValueRef> op##_rule(                                           \
-            const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \
-        return identity_rule_helper(_op, inputs);                              \
+#define CREATE_IDENTITY_OP_RULE(op)                                            \
+    ValueRefList op##_rule(                                                    \
+            const op& _op, Span<ValueRef>& inputs, const bool& auto_convert,   \
+            const FormatTransformation& t) {                                   \
+        return identity_rule_helper(_op, inputs, t);                           \
     }
 FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
 #undef CREATE_IDENTITY_OP_RULE
@@ -344,22 +364,26 @@ struct FormatRuleRegistry {
 #undef REGISTER_IDENTITY_OP_RULE
 }  // namespace
 
-std::vector<ValueRef> FormatTransformation::apply_transformation(
+ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     if (auto* apply_op = op.as<ApplyOp>()) {
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
         if (iter != format_rules.end()) {
-            return iter->second(apply_op->op(), inputs, m_auto_convert);
+            return iter->second(apply_op->op(), inputs, m_auto_convert, *this);
         } else {
             return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
-        return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)};
+        return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
     } else if (auto* get_attr = op.as<GetAttr>()) {
-        auto* src = inputs.as_array<1>()[0].as<FormattedTensorValue>();
-        if (!m_auto_convert || !src || src->format() != FT::NHWC) {
+        auto&& input = inputs.item();
+        if (!input.is(m_value_type)) {
+            return imperative::apply(op, input);
+        }
+        auto& src = input.cast(m_value_type);
+        if (!(m_auto_convert && src.format() == FT::NHWC)) {
             return imperative::apply(op, unwrap_inputs(inputs));
         }
         switch (get_attr->attr()) {
@@ -369,16 +393,16 @@ std::vector<ValueRef> FormatTransformation::apply_transformation(
                 return {ShapeValue::make(shape)};
             }
             case GetAttr::Value: {
-                auto nchw_src = unwrap_input(src->to(FT::NCHW, ""));
-                return imperative::apply(op, std::vector<ValueRef>{nchw_src});
+                auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
+                return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
             }
             default:
                 return imperative::apply(op, unwrap_inputs(inputs));
         }
     } else if (op.is<GetFormat>()) {
-        bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
+        bool is_formatted_tensor = inputs.item().is(m_value_type);
         if (is_formatted_tensor) {
-            return {FormatValue::make(inputs[0].cast<FormattedTensorValue>().format())};
+            return {FormatValue::make(inputs[0].cast(m_value_type).format())};
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for GetFormat op: %s",
@@ -386,9 +410,9 @@ std::vector<ValueRef> FormatTransformation::apply_transformation(
             return {FormatValue::make(FT::DEFAULT)};
         }
     } else if (op.is<SetFormat>()) {
-        bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
+        bool is_formatted_tensor = inputs.item().is(m_value_type);
         if (is_formatted_tensor) {
-            auto& format = inputs[0].cast<FormattedTensorValue>().format();
+            auto&& format = inputs[0].cast(m_value_type).format();
             return wrap_outputs(
                     imperative::apply(op, unwrap_inputs(inputs)), format.type());
         } else {
diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h
index 81fd83bfaee5c735f7ebdae988908a53326cd81b..0d1b5c5939495a035b710a35bf8270c188df9d8f 100644
--- a/imperative/src/include/megbrain/imperative/transformations/format.h
+++ b/imperative/src/include/megbrain/imperative/transformations/format.h
@@ -7,7 +7,7 @@
 
 namespace mgb::imperative {
 
-class FormattedTensorValue final : public ValueImpl<FormattedTensorValue> {
+class FormattedTensorValue final : public ObjectValue<FormattedTensorValue> {
 private:
     ValueRef m_value;
     Format m_format;
@@ -26,10 +26,6 @@ public:
 
     const Format& format() const { return m_format; }
 
-    TypedValueRef<FormattedTensorValue> as(const Format::Type& target) const;
-    TypedValueRef<FormattedTensorValue> to(
-            const Format::Type& target, const std::string& scope = "") const;
-
     void clear() override {
         m_value = {};
         m_format = {};
@@ -40,23 +36,18 @@ public:
     void on_unwatch() override { m_value.unwatch(); }
 };
 
-/**
- * \brief simulates scalar because megbrain graph system don't support scalar
- *
- * Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'.
- * This transformation simulates scalars with a flag. If a value is ScalarValue, it is
- * scalar, vice versa. So there is not scalar down this layer.
- */
 class FormatTransformation final : public Transformation {
 private:
-    bool m_auto_convert = false;
+    // enable auto_convert by default to be easier to use.
+    bool m_auto_convert = true;
+    ObjectType<FormattedTensorValue> m_value_type{"FormattedTensorValue"};
 
 public:
-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;
 
     ValueRef unwrap(ValueRef value) override {
-        mgb_assert(!value.is<FormattedTensorValue>());
+        mgb_assert(!value.is(m_value_type));
         return value;
     }
 
@@ -65,6 +56,22 @@ public:
     }
     void set_auto_convert(bool enabled) { m_auto_convert = enabled; }
     bool get_auto_convert() const { return m_auto_convert; }
+
+    const Type<FormattedTensorValue>& value_type() const { return m_value_type; }
+
+    inline ValueRef unwrap_input(const ValueRef& input) const;
+    inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
+    inline ValueRef wrap_output(
+            const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
+    inline ValueRefList wrap_outputs(
+            const ValueRefList& outputs,
+            Format::Type type = Format::Type::DEFAULT) const;
+
+    TypedValueRef<FormattedTensorValue> as(
+            const FormattedTensorValue&, const Format::Type& target) const;
+    TypedValueRef<FormattedTensorValue> to(
+            const FormattedTensorValue&, const Format::Type& target,
+            const std::string& scope = "") const;
 };
 
 }  // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h
index ecf63b48d654b3a11a77672ffc2c4b4e5b4f6223..9543aa49e3d5b0a1453f4e647cc103102b40ba8d 100644
--- a/imperative/src/include/megbrain/imperative/value.h
+++ b/imperative/src/include/megbrain/imperative/value.h
@@ -67,6 +67,7 @@ template <typename T>
 class Type : public IType {
 protected:
     Type(std::string name) : IType(std::move(name)) {}
+    Type(IType&& type) : IType(std::move(type)) {}
     // TODO: each type owns an allocator
 
 public:
@@ -104,6 +105,7 @@ template <typename T>
 class ObjectType : public Type<T> {
 public:
     ObjectType(std::string name) : Type<T>(name) {}
+    ObjectType(IType&& type) : Type<T>(std::move(type)) {}
 };
 
 /**
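
Usage note (not part of the patch): a minimal sketch of the eager-vs-traced comparison that the new is_symbolic parameter threads through _compare_nchw_nhwc above. In the diff the tests are still parametrized with [None] (eager only), so the traced calls below are an assumption about how the hook is intended to be exercised once tracing is enabled; the helper name check_nhwc_matches_nchw is hypothetical and only mirrors the test helper.

# Sketch, assuming the megengine APIs used in the diff (Tensor format support, jit.trace).
import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.jit import trace


def check_nhwc_matches_nchw(data, func, is_symbolic=None):
    # hypothetical helper mirroring _compare_nchw_nhwc from the test diff
    x_nchw = tensor(data, format="nchw")
    x_nhwc = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
    if is_symbolic is not None:
        # symbolic=True compiles a static graph; symbolic=False records imperatively
        func = trace(func, symbolic=is_symbolic)
    np.testing.assert_almost_equal(func(x_nchw), func(x_nhwc), decimal=5)


def flatten_func(x):
    return F.flatten(x).numpy()


if __name__ == "__main__":
    data = np.arange(24).reshape((1, 2, 3, 4)).astype("float32")
    check_nhwc_matches_nchw(data, flatten_func)                     # eager, as the tests run today
    check_nhwc_matches_nchw(data, flatten_func, is_symbolic=False)  # traced (assumed usage)
    check_nhwc_matches_nchw(data, flatten_func, is_symbolic=True)   # traced, symbolic (assumed usage)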