提交 38a9aa9f 编写于 作者: M Megvii Engine Team

feat(imperative/amp): add auto dimshuffle for elemwise and concat

GitOrigin-RevId: 6e3df4e064675fb1adeb85a72ea8e0b276a9660d
上级 cd263765
...@@ -193,7 +193,10 @@ def test_typecvt(is_symbolic): ...@@ -193,7 +193,10 @@ def test_typecvt(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
def test_elemwise(is_symbolic): def test_elemwise(is_symbolic):
def elemwise(x): def elemwise(x):
return (x * 2 + x / 2).numpy() tmp = F.ones((1, 2, 3, 4))
oup = x * tmp + x / 2
assert oup.format == x.format
return oup.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4)) data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, elemwise, is_symbolic) _compare_nchw_nhwc(data, elemwise, is_symbolic)
...@@ -202,7 +205,8 @@ def test_elemwise(is_symbolic): ...@@ -202,7 +205,8 @@ def test_elemwise(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
def test_concat(is_symbolic): def test_concat(is_symbolic):
def func(x): def func(x):
rst = F.concat([x / 2, x * 2], axis=1) tmp = F.ones((1, 2, 3, 4))
rst = F.concat([x / 2, tmp], axis=1)
assert rst.format == x.format assert rst.format == x.format
return rst.numpy() return rst.numpy()
......
...@@ -355,6 +355,33 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& ...@@ -355,6 +355,33 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
return format; return format;
} }
inline ValueRefList unify_nhwc_inputs(
Span<ValueRef>& inputs, std::string scope, const FormatTransformation& t) {
ValueRefList unified_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != FT::NHWC &&
inp.value().shape().cast<ShapeValue>().ndim == 4) {
unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
} else {
unified_inputs[i] = inputs[i];
}
}
return unified_inputs;
}
ValueRefList elemwise_rule(
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
FT format = get_inputs_format(inputs, t);
if (format == FT::NHWC && auto_convert) {
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
return t.wrap_outputs(
imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}
ValueRefList concat_rule( ValueRefList concat_rule(
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert, const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) { const FormatTransformation& t) {
...@@ -362,6 +389,7 @@ ValueRefList concat_rule( ...@@ -362,6 +389,7 @@ ValueRefList concat_rule(
if (!(format == FT::NHWC && auto_convert)) { if (!(format == FT::NHWC && auto_convert)) {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
} }
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
// TODO: handle 5D NHWC Tensor from group conv // TODO: handle 5D NHWC Tensor from group conv
auto axis = op.axis; auto axis = op.axis;
if (axis == 2 || axis == 3) { if (axis == 2 || axis == 3) {
...@@ -372,7 +400,7 @@ ValueRefList concat_rule( ...@@ -372,7 +400,7 @@ ValueRefList concat_rule(
return t.wrap_outputs( return t.wrap_outputs(
imperative::apply( imperative::apply(
*Concat::make(axis, op.comp_node, op.scope()), *Concat::make(axis, op.comp_node, op.scope()),
t.unwrap_inputs(inputs)), t.unwrap_inputs(unified_inputs)),
format); format);
} }
...@@ -415,7 +443,6 @@ ValueRefList adaptive_pooling_rule( ...@@ -415,7 +443,6 @@ ValueRefList adaptive_pooling_rule(
// clang-format off // clang-format off
#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \ #define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
cb(Elemwise) \
cb(CompiledOp) \ cb(CompiledOp) \
cb(SubgraphOp) cb(SubgraphOp)
...@@ -501,6 +528,7 @@ struct FormatRuleRegistry { ...@@ -501,6 +528,7 @@ struct FormatRuleRegistry {
register_format_rule(subtensor_rule<IndexingMultiAxisVec>); register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
register_format_rule(setsubtensor_rule<SetSubtensor>); register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(elemwise_rule);
register_format_rule(concat_rule); register_format_rule(concat_rule);
register_format_rule(batchnorm_rule); register_format_rule(batchnorm_rule);
register_format_rule(adaptive_pooling_rule); register_format_rule(adaptive_pooling_rule);
...@@ -515,7 +543,6 @@ struct FormatRuleRegistry { ...@@ -515,7 +543,6 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation( ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) { const Operator& op, Span<ValueRef> inputs) {
// mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue // all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册