From fd41302cc19b6ff71157276b2f7d03cf3c7ea09c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 25 Mar 2022 11:01:21 +0800
Subject: [PATCH] feat(imperative/amp): add set_format

GitOrigin-RevId: 91de6f49de7dc334cbe17f2a29e11a8f40ee79d6
---
 .../python/megengine/amp/convert_format.py    |  5 +++-
 .../python/megengine/module/normalization.py  | 11 ++++++--
 imperative/python/megengine/tensor.py         |  6 +++-
 imperative/python/src/tensor.cpp              |  9 +++++-
 imperative/python/src/tensor.h                |  6 ++++
 .../test/unit/core/test_formatted_tensor.py   |  3 ++
 imperative/src/impl/basic_operators.cpp       |  9 +++++-
 .../src/impl/transformations/format.cpp       | 28 +++++++++++--------
 .../megbrain/imperative/basic_operators.h     | 14 +++++++++-
 .../imperative/transformations/format.h       |  6 ++--
 10 files changed, 76 insertions(+), 21 deletions(-)

diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py
index 3eca860af..2c202c46b 100644
--- a/imperative/python/megengine/amp/convert_format.py
+++ b/imperative/python/megengine/amp/convert_format.py
@@ -29,10 +29,13 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
         if inplace:
+            # reset will destroy backward grad
            data = x.numpy().transpose(*pattern)
            x[...] = Tensor(data, format="nhwc")
         else:
-            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+            # use mge interface to maintain grad
+            x = F.transpose(x, pattern)
+            x.format = "nhwc"
     return x

diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py
index 2f6652aa1..ae6aa72be 100644
--- a/imperative/python/megengine/module/normalization.py
+++ b/imperative/python/megengine/module/normalization.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+import megengine as mge
 import megengine.functional as F
 from megengine import Parameter
@@ -34,6 +35,7 @@ class GroupNorm(Module):
 
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
 
         x = x.reshape(N, self.num_groups, -1)
@@ -44,7 +46,9 @@ class GroupNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x
 
     def _module_info_string(self) -> str:
@@ -81,6 +85,7 @@ class InstanceNorm(Module):
 
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, C, -1)
         mean = x.mean(axis=2, keepdims=True)
@@ -90,7 +95,9 @@ class InstanceNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x
 
     def _module_info_string(self) -> str:
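A note on the convert_format.py hunk above: the non-inplace branch now goes through F.transpose, so the permutation is recorded by autodiff, whereas the inplace branch's numpy() round trip is not. A minimal sketch of the difference, not part of the patch (the shape and random data are illustrative):

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.Tensor(np.random.randn(1, 2, 4, 4).astype("float32"))  # NCHW input
gm = GradManager().attach([x])
with gm:
    # non-inplace path: F.transpose plus the new format setter, so the
    # recorded history (and therefore the gradient) stays intact
    y = mge.amp.convert_tensor_format(x, inplace=False)
    gm.backward(y.sum())
print(x.grad is not None)  # expected: True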
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 0aec35a51..781258670 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -122,7 +122,11 @@ class Tensor(_Tensor, ArrayMethodMixin):
 
     @property
     def format(self) -> str:
-        return super().format
+        return super().format()
+
+    @format.setter
+    def format(self, format):
+        super()._set_format(format)
 
     @property
     def qparams(self):

diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 83b6b40ca..9a3c7ab3c 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -584,6 +584,12 @@ void TensorWrapper::set_module_trace_info(PyObject* obj) {
     module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
 }
 
+void TensorWrapper::_set_format(PyObject* dest) {
+    auto py_dest = py::reinterpret_borrow<py::object>(dest);
+    auto format = py_dest.cast<std::string>();
+    m_tensor->set_format(format);
+}
+
 void TensorWrapper::_set_name(PyObject* dest) {
     auto py_dest = py::reinterpret_borrow<py::object>(dest);
     auto name = py_dest.cast<std::string>();
@@ -812,7 +818,7 @@ void init_tensor(py::module m) {
             .def_getset<&TensorWrapper::shape>("shape")
             .def_getset<&TensorWrapper::dtype>("dtype")
             .def_getset<&TensorWrapper::device>("device")
-            .def_getset<&TensorWrapper::format>("format")
+            .def<&TensorWrapper::format>("format")
             .def<&TensorWrapper::reset>("_reset")
             .def<&TensorWrapper::isscalar>("_isscalar")
             .def<&TensorWrapper::detach>("detach")
@@ -820,6 +826,7 @@ void init_tensor(py::module m) {
             .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
             .def<&TensorWrapper::_drop>("_drop")
             .def<&TensorWrapper::_detail>("_detail")
+            .def<&TensorWrapper::_set_format>("_set_format")
             .def<&TensorWrapper::_set_name>("_set_name")
             .def<&TensorWrapper::_watch>("_watch")
             .def<&TensorWrapper::_var>("var")

diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 7e243631e..1f849f2b4 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -59,6 +59,11 @@ public:
         return *shape;
     }
     inline Format format() { return *data().format(); }
+    inline void set_format(std::string format) {
+        if (!format.empty()) {
+            m_data = imperative::apply(SetFormat(format), m_data)[0];
+        }
+    }
     inline HostValue::ref_t numpy() { return data().numpy(); }
     inline void reset(ValueRef value) {
         m_data = value;
@@ -130,6 +135,7 @@ public:
     PyObject* copied();
     PyObject* module_trace_info();
     void set_module_trace_info(PyObject*);
+    void _set_format(PyObject*);
     void _set_name(PyObject*);
     PyObject* _detail();
     PyObject* _var();

diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index d1f3d4cfb..4732a4f0e 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -31,6 +31,9 @@ def test_basic():
     b[...] = tensor(data, format="nchw")
     assert b.format == "nchw"
+    # set tensor's format
+    b.format = "nhwc"
+    assert b.format == "nhwc"
 
 
 def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data)
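The property setter added in tensor.py is exercised by test_basic above; the same flow as a hypothetical standalone snippet (data and shape are illustrative):

import numpy as np
from megengine import tensor

data = np.arange(24, dtype="float32").reshape((1, 2, 3, 4))
b = tensor(data, format="nchw")
assert b.format == "nchw"
b.format = "nhwc"          # property setter -> _set_format -> SetFormat op
assert b.format == "nhwc"  # only the format tag is rewrapped; the stored
                           # data is not permuted (see the SetFormat branch
                           # in format.cpp below)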
diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp
index a91c228f2..52f48b64c 100644
--- a/imperative/src/impl/basic_operators.cpp
+++ b/imperative/src/impl/basic_operators.cpp
@@ -105,9 +105,16 @@ std::string IsScalar::to_string() const {
     return "IsScalar";
 }
 
+std::string GetFormat::to_string() const {
+    return "GetFormat{}";
+}
+
+std::string SetFormat::to_string() const {
+    return ssprintf("SetFormat{format=%s}", m_format.to_string().c_str());
+}
+
 std::string GetVarVal::to_string() const {
     return "GetVarVal";
 }
-
 } // namespace imperative
 } // namespace mgb

diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index 54de6cb53..72ce22415 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -57,15 +57,15 @@ inline ValueRefList FormatTransformation::unwrap_inputs(
 }
 
 inline ValueRef FormatTransformation::wrap_output(
-        const ValueRef& output, FT type) const {
-    return m_value_type.make(output, type);
+        const ValueRef& output, Format format) const {
+    return m_value_type.make(output, format);
 }
 
 inline ValueRefList FormatTransformation::wrap_outputs(
-        const ValueRefList& outputs, FT type) const {
+        const ValueRefList& outputs, Format format) const {
     ValueRefList wrapped_outputs(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
-        wrapped_outputs[i] = wrap_output(outputs[i], type);
+        wrapped_outputs[i] = wrap_output(outputs[i], format);
     }
     return wrapped_outputs;
 }
@@ -241,7 +241,7 @@ ValueRefList subtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
     auto outputs = imperative::apply(
@@ -264,7 +264,7 @@ ValueRefList setsubtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     // value has been broadcasted to src's fake NCHW shape.
     auto& value = inputs[1].cast(t.value_type());
@@ -330,7 +330,7 @@ ValueRefList identity_rule_helper(
     // mgb_assert(inputs.size() == 1);
     auto& src = inputs[0].cast(t.value_type());
     return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
+            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
 }
 
 ValueRefList batchnorm_rule(
@@ -467,7 +467,7 @@ ValueRefList FormatTransformation::apply_transformation(
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
-        return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
+        return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
         if (!input.is(m_value_type)) {
@@ -500,12 +500,16 @@ ValueRefList FormatTransformation::apply_transformation(
                     op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
+    } else if (auto* _op = op.as<SetFormat>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
+        return {m_value_type.make(inp_ref->value(), _op->format())};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
             auto&& format = inp_ref->format();
             return wrap_outputs(
-                    imperative::apply(op, unwrap_inputs(inputs)), format.type());
+                    imperative::apply(op, unwrap_inputs(inputs)), format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -521,13 +525,13 @@ ValueRefList FormatTransformation::apply_transformation(
             GenericFunction new_callback =
                     [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                 auto wrapped_inputs = SmallVector<ValueRef>{
-                        this->value_type().make(inputs_.item(), format.type())};
+                        this->value_type().make(inputs_.item(), format)};
                 auto ret = callback(wrapped_inputs);
                 return ret;
             };
             auto&& outputs = imperative::apply(
                     op, inp_ref->value(), FunctionValue::make(new_callback));
-            return wrap_outputs(outputs, format.type());
+            return wrap_outputs(outputs, format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for AttachGrad op: %s, %s",
@@ -549,7 +553,7 @@ ValueRefList FormatTransformation::apply_transformation(
         for (size_t i = 0; i < nr_outputs; ++i) {
             if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
                 wrapped_outputs[i] =
-                        m_value_type.make(outputs[i], output_ref->format().type());
+                        m_value_type.make(outputs[i], output_ref->format());
             } else {
                 mgb_log_warn(
                         "Not FormattedTensorValue outputs for SetGrad op: %s, %s",

diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h
index 4d35c746f..5d7858579 100644
--- a/imperative/src/include/megbrain/imperative/basic_operators.h
+++ b/imperative/src/include/megbrain/imperative/basic_operators.h
@@ -164,7 +164,19 @@ public:
 
 class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
 public:
-    std::string to_string() const override { return "GetFormat{}"; }
+    std::string to_string() const override;
+};
+
+class SetFormat final : public OperatorImpl<SetFormat, Operator::IdentityLike> {
+private:
+    Format m_format;
+
+public:
+    SetFormat(std::string format) : m_format(format) {}
+
+    Format format() const { return m_format; }
+
+    std::string to_string() const override;
 };
 
 class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {

diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h
index 90427ee2b..95a18e1c8 100644
--- a/imperative/src/include/megbrain/imperative/transformations/format.h
+++ b/imperative/src/include/megbrain/imperative/transformations/format.h
@@ -26,6 +26,8 @@ public:
 
     const Format& format() const { return m_format; }
 
+    void set_format(Format format) { m_format = format; }
+
     void clear() override {
         m_value = {};
         m_format = {};
@@ -65,10 +67,10 @@ public:
     inline ValueRef unwrap_input(const ValueRef& input) const;
     inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
     inline ValueRef wrap_output(
-            const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
+            const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs,
-            Format::Type type = Format::Type::DEFAULT) const;
+            Format format = Format::Type::DEFAULT) const;
 
     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;
-- 
GitLab
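Closing sketch of the intended semantics: SetFormat retags the FormattedTensorValue without touching the data, and the Format (rather than Format::Type) plumbing above lets identity-like ops hand the full format object through to their outputs. Which ops are covered is an assumption here; the patch's tests only pin down the getter and setter:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.ones((1, 2, 3, 4), dtype="float32"), format="nhwc")
y = F.relu(x)    # identity-like op: the output should inherit "nhwc"
print(y.format)  # expected: "nhwc"

x.format = "nchw"  # retag only; no layout conversion is triggered
print(x.format)    # "nchw"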