Commit d57a0712 authored by Megvii Engine Team

feat(imperative/amp): add fallback for op not supported for nhwc tensor

GitOrigin-RevId: 8411ce7bdc39c262e2067179429304961bd52a68
Parent 38a9aa9f
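When no format rule exists for an op, the transformation now converts NHWC inputs back to the default layout before applying the op, instead of failing. A minimal NumPy sketch of that fallback (illustrative only; to_default is not MegEngine API):

    import numpy as np

    def to_default(x_nhwc):
        # same pattern {0, 3, 1, 2} used by FormatTransformation::to below
        return np.transpose(x_nhwc, (0, 3, 1, 2))

    x = np.arange(24, dtype="float32").reshape((1, 3, 4, 2))  # NHWC buffer
    # an op with no NHWC rule (pad here) runs on the default layout,
    # and its result is tagged with format "default"
    out = np.pad(to_default(x), ((0, 0), (0, 0), (1, 1), (1, 1)))
    print(out.shape)  # (1, 2, 5, 6)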
@@ -224,7 +224,6 @@ def test_interpolate(mode, is_symbolic):
         assert rst.format == x.format
         return rst.numpy()

-    # NHWC interpolate only supported when channel is 1 or 3
     data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
     _compare_nchw_nhwc(data, func, is_symbolic)
@@ -331,6 +330,35 @@ def test_pooling2d(pooling, is_symbolic):
     _compare_nchw_nhwc(data, func, is_symbolic)


+@pytest.mark.skip("not implemented")
+def test_where():
+    def func(x):
+        mask = tensor(
+            np.array([True, False, False, True] * 6, dtype=bool).reshape(
+                (1, 2, 3, 4)
+            )
+        )
+        y = tensor(
+            np.array([1, np.inf, np.nan, 4] * 6, dtype=np.float32).reshape((1, 2, 3, 4))
+        )
+        out = F.where(mask, x, y)
+        assert out.format == "default"
+        return out.numpy()
+
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    _compare_nchw_nhwc(data, func)
+
+
+def test_unsupported_op():
+    def func(x):
+        rst = F.nn.pad(x, pad_width=((1, 1),), mode="constant")
+        assert rst.format == "default"
+        return rst.numpy()
+
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    _compare_nchw_nhwc(data, func)


 def _compare_backward(inps, model, is_symbolic=None):
     def func(*inps):
         return model(*inps)
......
@@ -18,20 +18,20 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const FormattedTensorValue& tensor, const FT& target,
         const std::string& scope) const {
     std::vector<int32_t> pattern;
-    if (tensor.format() == FT::NHWC && target == FT::NCHW) {
+    Format format = tensor.format();
+    if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
             pattern = {0, 1, 4, 2, 3};
         } else {
             pattern = {0, 3, 1, 2};
         }
-    } else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
+    } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
         pattern = {0, 2, 3, 1};
     } else {
         mgb_throw(
                 MegBrainError, "Unsupport format conversion from %s to %s",
-                tensor.format().to_string().c_str(),
-                Format(target).to_string().c_str());
+                format.to_string().c_str(), Format(target).to_string().c_str());
     }
     auto output =
             imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
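The two dimshuffle patterns are easy to sanity-check with NumPy, assuming pattern[i] names the source axis of output axis i (the convention np.transpose uses):

    import numpy as np

    x = np.empty((2, 5, 7, 3))               # NHWC
    nchw = np.transpose(x, (0, 3, 1, 2))     # pattern {0, 3, 1, 2}
    assert nchw.shape == (2, 3, 5, 7)
    nhwc = np.transpose(nchw, (0, 2, 3, 1))  # pattern {0, 2, 3, 1}
    assert nhwc.shape == x.shape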
@@ -84,7 +84,7 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
         return out;
     } else {
         mgb_throw(
-                MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
+                MegBrainError, "Unsupported shape ndim %lu in GetAttr(Shape).",
                 shape.ndim);
     }
 }
@@ -189,7 +189,7 @@ ValueRefList reshape_rule(
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
-        auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+        auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
         auto outputs = imperative::apply(op, {nchw_src});
         return t.wrap_outputs(outputs);
     }
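Reshape breaks the axis correspondence that a format tag relies on, so the rule materializes the default layout first and leaves the output untagged. A NumPy illustration of why the source's format cannot be maintained:

    import numpy as np

    x_nhwc = np.arange(24).reshape((1, 3, 4, 2))  # NHWC view of a (1, 2, 3, 4) image
    flat_naive = x_nhwc.reshape((24,))            # element order depends on NHWC layout
    flat_default = np.transpose(x_nhwc, (0, 3, 1, 2)).reshape((24,))
    assert not np.array_equal(flat_naive, flat_default)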
@@ -204,7 +204,7 @@ ValueRefList reshape_rule(
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
-        auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+        auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
         auto outputs = imperative::apply(
                 op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
         return t.wrap_outputs(outputs);
@@ -229,7 +229,7 @@ ValueRefList broadcast_rule(
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
-        auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+        auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
         auto outputs = imperative::apply(op, {nchw_src});
         return t.wrap_outputs(outputs);
     }
@@ -244,7 +244,7 @@ ValueRefList broadcast_rule(
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
-        auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+        auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
         auto outputs = imperative::apply(
                 op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
         return t.wrap_outputs(outputs);
@@ -323,7 +323,7 @@ ValueRefList setsubtensor_rule(
     auto nhwc_inputs = ValueRefList(inputs.size());
     if (format == FT::DEFAULT || format == FT::NCHW) {
         // value for setsubtensor should transpose to match shape.
-        auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC);
+        auto nhwc_value = t.to(value, FT::NHWC);
         // make new inputs for setsubtensor
         nhwc_inputs[0] = src.value();
         nhwc_inputs[1] = nhwc_value->value();
@@ -355,14 +355,15 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
     return format;
 }

-inline ValueRefList unify_nhwc_inputs(
-        Span<ValueRef>& inputs, std::string scope, const FormatTransformation& t) {
+inline ValueRefList unify_inputs_format(
+        const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope,
+        const FormatTransformation& t) {
     ValueRefList unified_inputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
-        if (inp.format() != FT::NHWC &&
+        if (inp.format() != dst_fmt &&
             inp.value().shape().cast<ShapeValue>().ndim == 4) {
-            unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
+            unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
         }
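unify_nhwc_inputs is generalized to unify_inputs_format so the same helper can pull 4-D inputs to any target format, including FT::DEFAULT for the new fallback path. Roughly, in Python-flavored pseudocode (names illustrative; the real code operates on ValueRefs):

    def unify_inputs_format(inputs, dst_fmt, scope, t):
        # convert every 4-D input whose format differs from dst_fmt;
        # leave everything else (scalars, 5-D group-conv weights, ...) alone
        return [
            t.to(inp, dst_fmt, scope) if inp.format != dst_fmt and inp.ndim == 4 else inp
            for inp in inputs
        ]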
@@ -375,7 +376,7 @@ ValueRefList elemwise_rule(
         const FormatTransformation& t) {
     FT format = get_inputs_format(inputs, t);
     if (format == FT::NHWC && auto_convert) {
-        auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
+        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
         return t.wrap_outputs(
                 imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
     }
@@ -389,7 +390,7 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
-    auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
+    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
@@ -460,6 +461,12 @@ ValueRefList adaptive_pooling_rule(
 #define FOREACH_FORMAT_POLICY_OP(cb) \
     cb(Pooling)                      \
     cb(Convolution)

+#define FOREACH_BYPASS_OP(cb) \
+    cb(ParamPackSplit)        \
+    cb(ParamPackConcat)       \
+    cb(CollectiveComm)        \
+    cb(CheckNonFinite)
+
 // clang-format on

 // multi inputs op without params
@@ -517,6 +524,15 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
 }
 FOREACH_FORMAT_POLICY_OP(CREATE_FORMAT_POLICY_OP_RULE)

+#define CREATE_BYPASS_OP_RULE(Op)                                               \
+    ValueRefList Op##_rule(                                                     \
+            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert,    \
+            const FormatTransformation& t) {                                    \
+        return t.wrap_outputs(imperative::apply(_op, t.unwrap_inputs(inputs))); \
+    }
+FOREACH_BYPASS_OP(CREATE_BYPASS_OP_RULE)
+#undef CREATE_BYPASS_OP_RULE
 #undef CREATE_FORMAT_OP_RULE

 #define REGISTER_OP_RULE(op) register_format_rule(op##_rule);

 struct FormatRuleRegistry {
@@ -536,6 +552,7 @@ struct FormatRuleRegistry {
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
+        FOREACH_BYPASS_OP(REGISTER_OP_RULE)
     }
 } _;
 #undef REGISTER_OP_RULE
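FOREACH_BYPASS_OP follows the X-macro pattern already used by the other FOREACH_* lists: a single table of op names is expanded once into per-op rule functions and once into registrations. The same idea in Python (illustrative; t.apply and t.wrap_outputs stand in for the real calls):

    BYPASS_OPS = ["ParamPackSplit", "ParamPackConcat", "CollectiveComm", "CheckNonFinite"]
    format_rules = {}

    def make_bypass_rule(name):
        # forward the op untouched; outputs get the default format wrapper
        def rule(op, inputs, auto_convert, t):
            return t.wrap_outputs(t.apply(op, t.unwrap_inputs(inputs)))
        return rule

    for name in BYPASS_OPS:
        format_rules[name] = make_bypass_rule(name)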
@@ -549,10 +566,13 @@ ValueRefList FormatTransformation::apply_transformation(
         if (iter != format_rules.end()) {
             return iter->second(apply_op->op(), inputs, m_auto_convert, *this);
         } else {
-            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+            auto unified_inputs = unify_inputs_format(
+                    inputs, FT::DEFAULT, apply_op->op().scope(), *this);
+            return wrap_outputs(imperative::apply(op, unwrap_inputs(unified_inputs)));
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
+        // TODO: add dimshuffle for nhwc format
         return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
@@ -570,7 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
                 return {ShapeValue::make(shape)};
             }
             case GetAttr::Value: {
-                auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
+                auto nchw_src = unwrap_input(to(src, FT::DEFAULT, ""));
                 return imperative::apply(op, {nchw_src});
             }
             default:
......
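With GetAttr::Value also targeting FT::DEFAULT, reading back an NHWC tensor (e.g. via .numpy()) returns data in the default layout. A sketch of the equivalent NumPy round-trip (assuming the stored buffer is NHWC-ordered):

    import numpy as np

    stored = np.arange(24, dtype="float32").reshape((1, 3, 4, 2))  # NHWC buffer
    value = np.transpose(stored, (0, 3, 1, 2))  # what GetAttr::Value hands back
    print(value.shape)  # (1, 2, 3, 4): the logical NCHW shape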