Commit fd41302c authored by Megvii Engine Team

feat(imperative/amp): add set_format

GitOrigin-RevId: 91de6f49de7dc334cbe17f2a29e11a8f40ee79d6
Parent: fc633ce4
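The change adds a writable format property to Tensor, backed by a new _set_format binding and a SetFormat operator, so a tensor can be retagged as NHWC in place instead of being rebuilt from numpy. A minimal usage sketch (shapes are arbitrary; the format= constructor keyword is the one exercised by the test further down):

    import numpy as np
    import megengine as mge

    x = mge.Tensor(np.random.rand(1, 2, 3, 4).astype("float32"), format="nchw")
    print(x.format)    # "nchw"
    x.format = "nhwc"  # the setter added by this commit, routed to SetFormat
    print(x.format)    # "nhwc"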
@@ -29,10 +29,13 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
         if inplace:
+            # reset will destroy backward grad
             data = x.numpy().transpose(*pattern)
             x[...] = Tensor(data, format="nhwc")
         else:
-            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+            # use mge interface to maintain grad
+            x = F.transpose(x, pattern)
+            x.format = "nhwc"
     return x
......
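The non-inplace branch matters for training: rebuilding the tensor from a numpy array produces a fresh tensor with no autograd history, while F.transpose stays inside MegEngine's dispatch, so the gradient path survives and the format can then be assigned separately. A rough sketch of the preserved gradient flow, using the standard GradManager API (shapes are arbitrary):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    x = mge.Tensor(np.random.rand(1, 2, 3, 3).astype("float32"))
    gm = GradManager()
    gm.attach([x])
    with gm:
        y = F.transpose(x, (0, 2, 3, 1))  # grad-preserving path, as in the new else branch
        gm.backward(y.sum())
    print(x.grad.shape)  # gradient reaches x; a Tensor(x.numpy()...) rebuild would cut this link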
@@ -245,6 +245,8 @@ def conv2d(
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    with _config._override(auto_format_convert=False):
+        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
......
 import numpy as np
+import megengine as mge
 import megengine.functional as F
 from megengine import Parameter
@@ -34,6 +35,7 @@ class GroupNorm(Module):
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, self.num_groups, -1)
@@ -44,7 +46,9 @@ class GroupNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
@@ -81,6 +85,7 @@ class InstanceNorm(Module):
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, C, -1)
         mean = x.mean(axis=2, keepdims=True)
@@ -90,7 +95,9 @@ class InstanceNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
......
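GroupNorm and InstanceNorm still compute through NCHW-style reshapes, so when the input arrived as NHWC the result is converted back at the end with mge.amp.convert_tensor_format (marked FIXME until this becomes a builtin op). A small sketch of the intended round trip, assuming the usual module constructor signatures; whether the format tag actually survives end to end depends on the amp machinery:

    import numpy as np
    import megengine as mge
    import megengine.module as M

    gn = M.GroupNorm(num_groups=2, num_channels=4)
    x = mge.Tensor(np.random.rand(1, 4, 8, 8).astype("float32"))
    x = mge.amp.convert_tensor_format(x, inplace=False)  # NHWC input, as in the first hunk
    y = gn(x)
    print(y.format)  # the trailing conversion is meant to keep this at "nhwc"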
@@ -122,7 +122,11 @@ class Tensor(_Tensor, ArrayMethodMixin):
     @property
     def format(self) -> str:
-        return super().format
+        return super().format()
+
+    @format.setter
+    def format(self, format):
+        super()._set_format(format)

     @property
     def qparams(self):
......
@@ -584,6 +584,12 @@ void TensorWrapper::set_module_trace_info(PyObject* obj) {
     module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
 }

+void TensorWrapper::_set_format(PyObject* dest) {
+    auto py_dest = py::reinterpret_borrow<py::object>(dest);
+    auto format = py_dest.cast<std::string>();
+    m_tensor->set_format(format);
+}
+
 void TensorWrapper::_set_name(PyObject* dest) {
     auto py_dest = py::reinterpret_borrow<py::object>(dest);
     auto name = py_dest.cast<std::string>();
@@ -812,7 +818,7 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::shape>("shape")
         .def_getset<&TensorWrapper::dtype>("dtype")
         .def_getset<&TensorWrapper::device>("device")
-        .def_getset<&TensorWrapper::format>("format")
+        .def<&TensorWrapper::format>("format")
         .def<&TensorWrapper::reset>("_reset")
         .def<&TensorWrapper::isscalar>("_isscalar")
         .def<&TensorWrapper::detach>("detach")
@@ -820,6 +826,7 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
         .def<&TensorWrapper::_drop>("_drop")
         .def<&TensorWrapper::_detail>("_detail")
+        .def<&TensorWrapper::_set_format>("_set_format")
         .def<&TensorWrapper::_set_name>("_set_name")
         .def<&TensorWrapper::_watch>("_watch")
         .def<&TensorWrapper::_var>("var")
......
@@ -59,6 +59,11 @@ public:
         return *shape;
     }
     inline Format format() { return *data().format(); }
+    inline void set_format(std::string format) {
+        if (!format.empty()) {
+            m_data = imperative::apply(SetFormat(format), m_data)[0];
+        }
+    }
     inline HostValue::ref_t numpy() { return data().numpy(); }
     inline void reset(ValueRef value) {
         m_data = value;
@@ -130,6 +135,7 @@ public:
     PyObject* copied();
     PyObject* module_trace_info();
     void set_module_trace_info(PyObject*);
+    void _set_format(PyObject*);
     void _set_name(PyObject*);
     PyObject* _detail();
     PyObject* _var();
......
@@ -31,6 +31,9 @@ def test_basic():
     b[...] = tensor(data, format="nchw")
     assert b.format == "nchw"
+    # set tensor's format
+    b.format = "nhwc"
+    assert b.format == "nhwc"


 def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data)
......
@@ -105,9 +105,16 @@ std::string IsScalar::to_string() const {
     return "IsScalar";
 }

+std::string GetFormat::to_string() const {
+    return "GetFormat{}";
+}
+
+std::string SetFormat::to_string() const {
+    return ssprintf("SetFormat{format=%s}", m_format.to_string().c_str());
+}
+
 std::string GetVarVal::to_string() const {
     return "GetVarVal";
 }

 }  // namespace imperative
 }  // namespace mgb
@@ -57,15 +57,15 @@ inline ValueRefList FormatTransformation::unwrap_inputs(
 }

 inline ValueRef FormatTransformation::wrap_output(
-        const ValueRef& output, FT type) const {
-    return m_value_type.make(output, type);
+        const ValueRef& output, Format format) const {
+    return m_value_type.make(output, format);
 }

 inline ValueRefList FormatTransformation::wrap_outputs(
-        const ValueRefList& outputs, FT type) const {
+        const ValueRefList& outputs, Format format) const {
     ValueRefList wrapped_outputs(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
-        wrapped_outputs[i] = wrap_output(outputs[i], type);
+        wrapped_outputs[i] = wrap_output(outputs[i], format);
     }
     return wrapped_outputs;
 }
@@ -241,7 +241,7 @@ ValueRefList subtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
     auto outputs = imperative::apply(
@@ -264,7 +264,7 @@ ValueRefList setsubtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     // value has been broadcasted to src's fake NCHW shape.
     auto& value = inputs[1].cast(t.value_type());
@@ -330,7 +330,7 @@ ValueRefList identity_rule_helper(
     // mgb_assert(inputs.size() == 1);
     auto& src = inputs[0].cast(t.value_type());
     return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
+            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
 }

 ValueRefList batchnorm_rule(
@@ -467,7 +467,7 @@ ValueRefList FormatTransformation::apply_transformation(
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
-        return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
+        return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
         if (!input.is(m_value_type)) {
@@ -500,12 +500,16 @@ ValueRefList FormatTransformation::apply_transformation(
                     op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
+    } else if (auto* _op = op.as<SetFormat>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
+        return {m_value_type.make(inp_ref->value(), _op->format())};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
             auto&& format = inp_ref->format();
             return wrap_outputs(
-                    imperative::apply(op, unwrap_inputs(inputs)), format.type());
+                    imperative::apply(op, unwrap_inputs(inputs)), format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -521,13 +525,13 @@ ValueRefList FormatTransformation::apply_transformation(
         GenericFunction new_callback =
                 [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
             auto wrapped_inputs = SmallVector<ValueRef>{
-                    this->value_type().make(inputs_.item(), format.type())};
+                    this->value_type().make(inputs_.item(), format)};
             auto ret = callback(wrapped_inputs);
             return ret;
         };
         auto&& outputs = imperative::apply(
                 op, inp_ref->value(), FunctionValue::make(new_callback));
-        return wrap_outputs(outputs, format.type());
+        return wrap_outputs(outputs, format);
     } else {
         mgb_log_warn(
                 "Not FormattedTensorValue input for AttachGrad op: %s, %s",
@@ -549,7 +553,7 @@ ValueRefList FormatTransformation::apply_transformation(
         for (size_t i = 0; i < nr_outputs; ++i) {
             if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
                 wrapped_outputs[i] =
-                        m_value_type.make(outputs[i], output_ref->format().type());
+                        m_value_type.make(outputs[i], output_ref->format());
             } else {
                 mgb_log_warn(
                         "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
......
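Worth noting: the SetFormat branch above only remakes the FormattedTensorValue with the new format tag; it does not transpose or copy any data. That is why convert_tensor_format first moves the data with F.transpose and only then assigns the format. Roughly (assuming the transpose pattern used there is (0, 2, 3, 1)):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.Tensor(np.random.rand(1, 3, 4, 4).astype("float32"))
    x = F.transpose(x, (0, 2, 3, 1))  # physical layout change, NCHW order -> NHWC order
    x.format = "nhwc"                 # metadata only; handled by the SetFormat branch, no copy
    print(x.format, x.shape)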
@@ -164,7 +164,19 @@ public:
 class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
 public:
-    std::string to_string() const override { return "GetFormat{}"; }
+    std::string to_string() const override;
+};
+
+class SetFormat final : public OperatorImpl<SetFormat, Operator::IdentityLike> {
+private:
+    Format m_format;
+
+public:
+    SetFormat(std::string format) : m_format(format) {}
+
+    Format format() const { return m_format; }
+
+    std::string to_string() const override;
 };

 class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
......
@@ -26,6 +26,8 @@ public:
     const Format& format() const { return m_format; }

+    void set_format(Format format) { m_format = format; }
+
     void clear() override {
         m_value = {};
         m_format = {};
@@ -65,10 +67,10 @@ public:
     inline ValueRef unwrap_input(const ValueRef& input) const;
     inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
     inline ValueRef wrap_output(
-            const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
+            const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs,
-            Format::Type type = Format::Type::DEFAULT) const;
+            Format format = Format::Type::DEFAULT) const;
     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;
......