提交 7e9aa742 编写于 作者: M Megvii Engine Team

feat(imperative/amp): enable auto_convert_format by default

GitOrigin-RevId: 71ae311fed63b33642574056041068dd904be723
上级 fc0f4546
......@@ -1017,10 +1017,14 @@ void init_tensor(py::module m) {
using namespace std::placeholders;
self.compiled = std::make_shared<CompiledTransformation>(
*self.trace_result, self.record_input_shapes);
self.compiled->set_value_comparator(
std::bind(&Trace::compare_value, this, _1, _2));
self.options_visitor(py::cast(&self.compiled->options()));
self.compiled->compile();
self.compiled->set_value_comparator(
std::bind(&Trace::compare_value, this, _1, _2));
self.options_visitor(py::cast(&self.compiled->options()));
try {
    self.compiled->compile();
} catch (const std::exception& e) {
    // The log macros take a printf-style format string, so the exception
    // text must be passed as an argument: a '%' inside e.what() would
    // otherwise be interpreted as a conversion specifier.
    // NOTE(review): a failed compile() is only logged and execution
    // continues with self.compiled still set -- confirm this is intended.
    mgb_log_error("%s", e.what());
}
}
// register transformations
if (self.compiled) {
......
......@@ -5,6 +5,7 @@ import megengine as mge
import megengine.functional as F
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
def test_basic():
......@@ -30,26 +31,29 @@ def test_basic():
assert b.format == "nchw"
def _compare_nchw_nhwc(data, func):
def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data, format="nchw")
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
out1 = func(x1)
with mge.config._override(auto_format_convert=True):
out2 = func(x2)
out2 = func(x2)
np.testing.assert_almost_equal(out1, out2, decimal=5)
def test_dimshuffle():
@pytest.mark.parametrize("is_symbolic", [None])
def test_dimshuffle(is_symbolic):
def func(x):
out = F.transpose(x, [2, 3, 0, 1])
assert out.format == "default"
return out.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_reshape():
@pytest.mark.parametrize("is_symbolic", [None])
def test_reshape(is_symbolic):
# maintain NHWC format
def func(x):
out = F.reshape(x, (1, 2, 6, 2))
......@@ -58,7 +62,7 @@ def test_reshape():
return out.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
# not maintain NHWC format
def func2(x):
......@@ -66,18 +70,20 @@ def test_reshape():
assert out.format == "default"
return out.numpy()
_compare_nchw_nhwc(data, func2)
_compare_nchw_nhwc(data, func2, is_symbolic)
def test_flatten():
@pytest.mark.parametrize("is_symbolic", [None])
def test_flatten(is_symbolic):
def func(x):
return F.flatten(x).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_broadcast():
@pytest.mark.parametrize("is_symbolic", [None])
def test_broadcast(is_symbolic):
# maintain NHWC format
def func(x):
out = F.broadcast_to(x, (4, 3, 2, 3))
......@@ -86,7 +92,7 @@ def test_broadcast():
return out.numpy()
data = np.arange(0, 24).reshape((4, 3, 2, 1))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
# not maintain NHWC format
def func2(x):
......@@ -94,30 +100,32 @@ def test_broadcast():
assert out.format == "default"
return out.numpy()
_compare_nchw_nhwc(data, func2)
_compare_nchw_nhwc(data, func2, is_symbolic)
@pytest.mark.skip("repeat cannot maintain format yet")
def test_repeat():
@pytest.mark.parametrize("is_symbolic", [None])
def test_repeat(is_symbolic):
def func(x):
rst = F.repeat(x, 3, axis=1)
assert rst.format == x.format
return rst.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_getshape():
@pytest.mark.parametrize("is_symbolic", [None])
def test_getshape(is_symbolic):
def func(x):
return x.shape
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.skip("symbolic shape is not supported yet")
def test_get_symbolic_shape():
def test_get_symbolic_shape(is_symbolic):
from megengine.core._trace_option import set_symbolic_shape
origin_opt = set_symbolic_shape(True)
......@@ -126,77 +134,84 @@ def test_get_symbolic_shape():
return x.shape.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
set_symbolic_shape(origin_opt)
def test_getvalue():
@pytest.mark.parametrize("is_symbolic", [None])
def test_getvalue(is_symbolic):
def func(x):
return x.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_get_set_subtensor():
@pytest.mark.parametrize("is_symbolic", [None])
def test_get_set_subtensor(is_symbolic):
def get_subtensor(x):
return x[:, :1, :2, :3].numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, get_subtensor)
_compare_nchw_nhwc(data, get_subtensor, is_symbolic)
def set_subtensor(x):
x[:, :1, :2, :3] = 0
return x.numpy()
_compare_nchw_nhwc(data, set_subtensor)
_compare_nchw_nhwc(data, set_subtensor, is_symbolic)
def test_get_set_advanced_indexing():
@pytest.mark.parametrize("is_symbolic", [None])
def test_get_set_advanced_indexing(is_symbolic):
def get_advanced_indexing(x):
x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy()
return x
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, get_advanced_indexing)
_compare_nchw_nhwc(data, get_advanced_indexing, is_symbolic)
def set_advanced_indexing(x):
x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0
return x.numpy()
_compare_nchw_nhwc(data, set_advanced_indexing)
_compare_nchw_nhwc(data, set_advanced_indexing, is_symbolic)
def test_typecvt():
@pytest.mark.parametrize("is_symbolic", [None])
def test_typecvt(is_symbolic):
def typecvt(x):
return x.astype("float16").numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, typecvt)
_compare_nchw_nhwc(data, typecvt, is_symbolic)
def test_elemwise():
@pytest.mark.parametrize("is_symbolic", [None])
def test_elemwise(is_symbolic):
def elemwise(x):
return (x * 2 + x / 2).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, elemwise)
_compare_nchw_nhwc(data, elemwise, is_symbolic)
def test_concat():
@pytest.mark.parametrize("is_symbolic", [None])
def test_concat(is_symbolic):
def func(x):
rst = F.concat([x / 2, x * 2], axis=1)
assert rst.format == x.format
return rst.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.parametrize(
"mode", ["bilinear", "nearest"],
)
def test_interpolate(mode):
@pytest.mark.parametrize("is_symbolic", [None])
def test_interpolate(mode, is_symbolic):
def func(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
......@@ -208,10 +223,11 @@ def test_interpolate(mode):
# NHWC interpolate only supports 1 or 3 channels
data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_conv2d():
@pytest.mark.parametrize("is_symbolic", [None])
def test_conv2d(is_symbolic):
def conv2d(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
......@@ -226,10 +242,11 @@ def test_conv2d():
return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, conv2d)
_compare_nchw_nhwc(data, conv2d, is_symbolic)
def test_group_conv2d():
@pytest.mark.parametrize("is_symbolic", [None])
def test_group_conv2d(is_symbolic):
def conv2d(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
......@@ -247,10 +264,11 @@ def test_group_conv2d():
).numpy()
data = np.arange(0, 48).reshape((1, 4, 3, 4))
_compare_nchw_nhwc(data, conv2d)
_compare_nchw_nhwc(data, conv2d, is_symbolic)
def test_bn():
@pytest.mark.parametrize("is_symbolic", [None])
def test_bn(is_symbolic):
def func(x):
if x.format == "nhwc":
with mge.config._override(bn_format="dim_111c"):
......@@ -279,14 +297,15 @@ def test_bn():
)[0].numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.parametrize(
"pooling",
[F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d],
)
def test_pooling2d(pooling):
@pytest.mark.parametrize("is_symbolic", [None])
def test_pooling2d(pooling, is_symbolic):
def func(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
......@@ -297,18 +316,25 @@ def test_pooling2d(pooling):
return pooling(x.astype("float32"), 2).numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func)
_compare_nchw_nhwc(data, func, is_symbolic)
def test_backward():
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward(is_symbolic):
data = np.arange(0, 24).reshape((1, 2, 3, 4))
x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
gm = GradManager().attach([w, b])
def func(x, w, b):
return F.conv2d(x, w, b)
with gm:
with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
x = F.conv2d(x, w, b)
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
x = func(x, w, b)
# TODO: fix manually convert to NHWC, usually used in detection head
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
gm.backward(x)
......
......@@ -7,53 +7,62 @@ namespace imperative {
using FT = Format::Type;
TypedValueRef<FormattedTensorValue> FormattedTensorValue::as(const FT& target) const {
return FormattedTensorValue::make(m_value, target);
// Wraps the value held by `tensor` in a new FormattedTensorValue tagged with
// `target`. No operator is applied here; only the format tag differs.
TypedValueRef<FormattedTensorValue> FormatTransformation::as(
        const FormattedTensorValue& tensor, const FT& target) const {
    auto&& underlying = tensor.value();
    return m_value_type.make(underlying, target);
}
TypedValueRef<FormattedTensorValue> FormattedTensorValue::to(
const FT& target, const std::string& scope) const {
TypedValueRef<FormattedTensorValue> FormatTransformation::to(
const FormattedTensorValue& tensor, const FT& target,
const std::string& scope) const {
std::vector<int32_t> pattern;
if (m_format == FT::NHWC && target == FT::NCHW) {
if (tensor.format() == FT::NHWC && target == FT::NCHW) {
pattern = {0, 3, 1, 2};
} else if (m_format == FT::NCHW && target == FT::NHWC) {
} else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
} else {
mgb_throw(
MegBrainError, "Unsupport format conversion from %s to %s",
m_format.to_string().c_str(), Format(target).to_string().c_str());
tensor.format().to_string().c_str(),
Format(target).to_string().c_str());
}
auto output = imperative::apply(
*Dimshuffle::make(pattern, scope), std::vector<ValueRef>{m_value})[0];
return FormattedTensorValue::make(output, target);
*Dimshuffle::make(pattern, scope),
SmallVector<ValueRef>{tensor.value()})[0];
return m_value_type.make(output, target);
}
namespace {
ValueRef unwrap_input(const ValueRef& input) {
if (auto format_input = input.as_ref<FormattedTensorValue>()) {
// Strips the format wrapper from a single input.
// Returns the wrapped value when `input` is a value of this transformation's
// FormattedTensorValue type; any other value is returned unchanged.
inline ValueRef FormatTransformation::unwrap_input(const ValueRef& input) const {
    auto formatted = input.as_ref(m_value_type);
    if (!formatted) {
        return input;
    }
    return formatted->value();
}
std::vector<ValueRef> unwrap_inputs(const Span<ValueRef>& inputs) {
std::vector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
unwrapped_inputs.push_back(unwrap_input(input));
// Applies unwrap_input() to every element of `inputs` and returns the results
// in the same order, in a list of the same length.
inline ValueRefList FormatTransformation::unwrap_inputs(
        const Span<ValueRef>& inputs) const {
    ValueRefList plain(inputs.size());
    size_t pos = 0;
    for (auto&& input : inputs) {
        plain[pos] = unwrap_input(input);
        ++pos;
    }
    return plain;
}
std::vector<ValueRef> wrap_outputs(
const std::vector<ValueRef>& outputs, FT type = FT::DEFAULT) {
std::vector<ValueRef> wrapped_outputs;
for (auto&& output : outputs) {
wrapped_outputs.push_back(FormattedTensorValue::make(output, type));
// Wraps `output` in a FormattedTensorValue of this transformation's value
// type, tagged with format `type`.
inline ValueRef FormatTransformation::wrap_output(
        const ValueRef& output, FT type) const {
    auto formatted = m_value_type.make(output, type);
    return formatted;
}
// Applies wrap_output() with the same `type` to every element of `outputs`
// and returns the wrapped values in the same order.
inline ValueRefList FormatTransformation::wrap_outputs(
        const ValueRefList& outputs, FT type) const {
    const size_t count = outputs.size();
    ValueRefList wrapped(count);
    for (size_t pos = 0; pos != count; ++pos) {
        wrapped[pos] = wrap_output(outputs[pos], type);
    }
    return wrapped;
}
namespace {
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
mgb_assert(shape.ndim == 4);
......@@ -64,20 +73,21 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
return out;
}
using FormatRule = std::function<std::vector<ValueRef>(
const OpDef&, Span<ValueRef>&, const bool&)>;
using FormatRule = std::function<ValueRefList(
const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;
template <typename T>
void register_format_rule(
std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>&, const bool&)) {
void register_format_rule(ValueRefList (*rule)(
const T&, Span<ValueRef>&, const bool&, const FormatTransformation&)) {
format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs,
const bool& auto_convert) {
return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert);
const bool& auto_convert,
const FormatTransformation& t) {
return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert, t);
};
}
auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
inline auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
mgb_assert(pattern.size() == 4);
auto nhwc_pattern = pattern;
for (size_t idx = 0; idx < 4; ++idx) {
......@@ -93,19 +103,20 @@ auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
return nhwc_pattern;
}
std::vector<ValueRef> dimshuffle_rule(
const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert) {
ValueRefList dimshuffle_rule(
const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
auto& src = inputs[0].cast(t.value_type());
// Only support converting pattern from NCHW to NHWC currently.
if (auto_convert && src.format() == FT::NHWC) {
auto pattern = convert_nchw2nhwc_pattern(op.pattern);
// dimshuffle will not maintain NHWC Format
return wrap_outputs(imperative::apply(
return t.wrap_outputs(imperative::apply(
*Dimshuffle::make(std::move(pattern), op.scope()),
unwrap_inputs(inputs)));
t.unwrap_inputs(inputs)));
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
......@@ -125,53 +136,55 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
return nhwc_shape_input;
}
std::vector<ValueRef> reshape_rule(
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert) {
ValueRefList reshape_rule(
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() == 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) {
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
return wrap_outputs(outputs, FT::NHWC);
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = src.to(FT::NCHW, op.scope())->value();
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
return wrap_outputs(outputs);
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
}
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
std::vector<ValueRef> broadcast_rule(
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert) {
ValueRefList broadcast_rule(
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() == 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) {
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd();
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape});
return wrap_outputs(outputs, FT::NHWC);
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = src.to(FT::NCHW, op.scope())->value();
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])});
return wrap_outputs(outputs);
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
}
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
bool is_reduce_ndim_idx_items(
inline bool is_reduce_ndim_idx_items(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
const Span<ValueRef>& inputs) {
for (auto i = 0; i < items.size(); ++i) {
......@@ -184,7 +197,7 @@ bool is_reduce_ndim_idx_items(
return false;
}
auto convert_nchw2nhwc_idx_items(
inline auto convert_nchw2nhwc_idx_items(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
auto nhwc_items = items;
for (auto i = 0; i < nhwc_items.size(); ++i) {
......@@ -199,51 +212,55 @@ auto convert_nchw2nhwc_idx_items(
}
template <typename T>
std::vector<ValueRef> subtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
ValueRefList subtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
auto& src = inputs[0].cast(t.value_type());
bool is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[1], &inputs[inputs.size() - 1]});
if (!is_reduce_ndim) {
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {FormattedTensorValue::make(
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0],
src.format().type())};
}
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
*T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs));
return wrap_outputs(outputs, FT::NHWC);
*T::make(std::move(nhwc_items), op.scope()), t.unwrap_inputs(inputs));
return t.wrap_outputs(outputs, FT::NHWC);
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
template <typename T>
std::vector<ValueRef> setsubtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) {
ValueRefList setsubtensor_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() >= 2);
auto& src = inputs[0].cast<FormattedTensorValue>();
auto& src = inputs[0].cast(t.value_type());
bool is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[2], &inputs[inputs.size() - 1]});
if (!is_reduce_ndim) {
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {FormattedTensorValue::make(
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())};
return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0],
src.format().type())};
}
// value has been broadcasted to src's fake NCHW shape.
auto& value = inputs[1].cast<FormattedTensorValue>();
auto& value = inputs[1].cast(t.value_type());
auto& format = value.format();
auto nhwc_inputs = std::vector<ValueRef>(inputs.size());
auto nhwc_inputs = ValueRefList(inputs.size());
if (format == FT::DEFAULT || format == FT::NCHW) {
// value for setsubtensor should transpose to match shape.
auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC);
auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC);
// make new inputs for setsubtensor
nhwc_inputs[0] = src.value();
nhwc_inputs[1] = nhwc_value->value();
for (auto i = 2; i < inputs.size(); ++i) {
nhwc_inputs[i] = inputs[i].as_ref<FormattedTensorValue>()->value();
nhwc_inputs[i] = t.unwrap_input(inputs[i]);
}
} else if (format != FT::NHWC) {
mgb_throw(
......@@ -253,15 +270,15 @@ std::vector<ValueRef> setsubtensor_rule(
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
*T::make(std::move(nhwc_items), op.scope()), nhwc_inputs);
return wrap_outputs(outputs, FT::NHWC);
return t.wrap_outputs(outputs, FT::NHWC);
}
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
FT get_inputs_format(Span<ValueRef>& inputs) {
inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
FT format(FT::DEFAULT);
for (auto& inp : inputs) {
auto& inp_format = inp.cast<FormattedTensorValue>().format();
auto& inp_format = inp.cast(t.value_type()).format();
if (inp_format != FT::DEFAULT) {
mgb_assert(format == FT::DEFAULT || inp_format == format);
format = inp_format.type();
......@@ -270,11 +287,12 @@ FT get_inputs_format(Span<ValueRef>& inputs) {
return format;
}
std::vector<ValueRef> concat_rule(
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert) {
FT format = get_inputs_format(inputs);
ValueRefList concat_rule(
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
FT format = get_inputs_format(inputs, t);
if (!(format == FT::NHWC && auto_convert)) {
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}
// TODO: handle 5D NHWC Tensor from group conv
auto axis = op.axis;
......@@ -283,25 +301,26 @@ std::vector<ValueRef> concat_rule(
} else if (axis == 1) {
axis = 3;
}
return wrap_outputs(
return t.wrap_outputs(
imperative::apply(
*Concat::make(axis, op.comp_node, op.scope()),
unwrap_inputs(inputs)),
t.unwrap_inputs(inputs)),
format);
}
std::vector<ValueRef> elemwise_rule(
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert) {
FT format = get_inputs_format(inputs);
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
// Format rule for Elemwise: runs the op on the unwrapped inputs and tags every
// output with the format that get_inputs_format() reports for the inputs.
// `auto_convert` is part of the common rule signature but is not consulted here.
ValueRefList elemwise_rule(
        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    const FT out_format = get_inputs_format(inputs, t);
    auto raw_outputs = imperative::apply(op, t.unwrap_inputs(inputs));
    return t.wrap_outputs(raw_outputs, out_format);
}
std::vector<ValueRef> identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs) {
ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast<FormattedTensorValue>();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), src.format().type());
auto& src = inputs[0].cast(t.value_type());
return t.wrap_outputs(
imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
}
// clang-format off
......@@ -318,10 +337,11 @@ std::vector<ValueRef> identity_rule_helper(
cb(Identity)
// clang-format on
#define CREATE_IDENTITY_OP_RULE(op) \
std::vector<ValueRef> op##_rule( \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \
return identity_rule_helper(_op, inputs); \
#define CREATE_IDENTITY_OP_RULE(op) \
ValueRefList op##_rule( \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
#undef CREATE_IDENTITY_OP_RULE
......@@ -344,22 +364,26 @@ struct FormatRuleRegistry {
#undef REGISTER_IDENTITY_OP_RULE
} // namespace
std::vector<ValueRef> FormatTransformation::apply_transformation(
ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
if (iter != format_rules.end()) {
return iter->second(apply_op->op(), inputs, m_auto_convert);
return iter->second(apply_op->op(), inputs, m_auto_convert, *this);
} else {
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto format = create_tensor->format();
return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)};
return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
} else if (auto* get_attr = op.as<GetAttr>()) {
auto* src = inputs.as_array<1>()[0].as<FormattedTensorValue>();
if (!m_auto_convert || !src || src->format() != FT::NHWC) {
auto&& input = inputs.item();
if (!input.is(m_value_type)) {
return imperative::apply(op, input);
}
auto& src = input.cast(m_value_type);
if (!(m_auto_convert && src.format() == FT::NHWC)) {
return imperative::apply(op, unwrap_inputs(inputs));
}
switch (get_attr->attr()) {
......@@ -369,16 +393,16 @@ std::vector<ValueRef> FormatTransformation::apply_transformation(
return {ShapeValue::make(shape)};
}
case GetAttr::Value: {
auto nchw_src = unwrap_input(src->to(FT::NCHW, ""));
return imperative::apply(op, std::vector<ValueRef>{nchw_src});
auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
}
default:
return imperative::apply(op, unwrap_inputs(inputs));
}
} else if (op.is<GetFormat>()) {
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
bool is_formatted_tensor = inputs.item().is(m_value_type);
if (is_formatted_tensor) {
return {FormatValue::make(inputs[0].cast<FormattedTensorValue>().format())};
return {FormatValue::make(inputs[0].cast(m_value_type).format())};
} else {
mgb_log_warn(
"Not FormattedTensorValue input for GetFormat op: %s",
......@@ -386,9 +410,9 @@ std::vector<ValueRef> FormatTransformation::apply_transformation(
return {FormatValue::make(FT::DEFAULT)};
}
} else if (op.is<Operator::IdentityLike>()) {
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>();
bool is_formatted_tensor = inputs.item().is(m_value_type);
if (is_formatted_tensor) {
auto& format = inputs[0].cast<FormattedTensorValue>().format();
auto&& format = inputs[0].cast(m_value_type).format();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), format.type());
} else {
......
......@@ -7,7 +7,7 @@
namespace mgb::imperative {
class FormattedTensorValue final : public ValueImpl<FormattedTensorValue> {
class FormattedTensorValue final : public ObjectValue<FormattedTensorValue> {
private:
ValueRef m_value;
Format m_format;
......@@ -26,10 +26,6 @@ public:
const Format& format() const { return m_format; }
TypedValueRef<FormattedTensorValue> as(const Format::Type& target) const;
TypedValueRef<FormattedTensorValue> to(
const Format::Type& target, const std::string& scope = "") const;
void clear() override {
m_value = {};
m_format = {};
......@@ -40,23 +36,18 @@ public:
void on_unwatch() override { m_value.unwatch(); }
};
/**
* \brief simulates scalar because megbrain graph system don't support scalar
*
* Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'.
* This transformation simulates scalars with a flag. If a value is ScalarValue, it is
* scalar, vice versa. So there is not scalar down this layer.
*/
class FormatTransformation final : public Transformation {
private:
bool m_auto_convert = false;
// enable auto_convert by default to be easier to use.
bool m_auto_convert = true;
ObjectType<FormattedTensorValue> m_value_type{"FormattedTensorValue"};
public:
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<FormattedTensorValue>());
mgb_assert(!value.is(m_value_type));
return value;
}
......@@ -65,6 +56,22 @@ public:
}
void set_auto_convert(bool enabled) { m_auto_convert = enabled; }
bool get_auto_convert() const { return m_auto_convert; }
const Type<FormattedTensorValue>& value_type() const { return m_value_type; }
inline ValueRef unwrap_input(const ValueRef& input) const;
inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
inline ValueRef wrap_output(
const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
inline ValueRefList wrap_outputs(
const ValueRefList& outputs,
Format::Type type = Format::Type::DEFAULT) const;
TypedValueRef<FormattedTensorValue> as(
const FormattedTensorValue&, const Format::Type& target) const;
TypedValueRef<FormattedTensorValue> to(
const FormattedTensorValue&, const Format::Type& target,
const std::string& scope = "") const;
};
} // namespace mgb::imperative
......@@ -67,6 +67,7 @@ template <typename T>
class Type : public IType {
protected:
Type(std::string name) : IType(std::move(name)) {}
Type(IType&& type) : IType(std::move(type)) {}
// TODO: each type owns an allocator
public:
......@@ -104,6 +105,7 @@ template <typename T>
// Type<T> specialization point for object values; it only forwards its
// constructor arguments to Type<T>.
class ObjectType : public Type<T> {
public:
    // Creates the type from its name. `name` is taken by value, so it is
    // moved into the base rather than copied a second time.
    ObjectType(std::string name) : Type<T>(std::move(name)) {}
    // Builds the type by moving from an existing IType.
    ObjectType(IType&& type) : Type<T>(std::move(type)) {}
};
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册