提交 c9e56f49 编写于 作者: M Megvii Engine Team

feat(imperative/amp): add dimshuffle before creating nhwc tensor

GitOrigin-RevId: 4461f9a0d3d6428282fcd858c677263d9a08fa9f
上级 d57a0712
......@@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
if x.format != "nhwc":
if inplace:
# hostvalue should still be valid, so no d2h cost.
data = x.numpy().transpose(*pattern)
data = x.numpy()
# reset will destroy existed backward grad
x[...] = Tensor(data, format="nhwc")
else:
......
......@@ -38,7 +38,7 @@ def test_basic():
def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
x2 = tensor(data, format="nhwc")
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
out1 = func(x1)
......@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic):
if x.format == "nhwc":
x = F.conv2d(
x,
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"),
weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"),
)
assert x.format == "nhwc"
return x.numpy()
......@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic):
if x.format == "nhwc":
x = F.conv2d(
x,
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"),
weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"),
groups=2,
)
assert x.format == "nhwc"
......@@ -286,10 +286,10 @@ def test_bn(is_symbolic):
if x.format == "nhwc":
oups = F.batch_norm(
x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
training=True,
inplace=False,
)
......
......@@ -27,7 +27,11 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
pattern = {0, 3, 1, 2};
}
} else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
pattern = {0, 1, 3, 4, 2};
} else {
pattern = {0, 2, 3, 1};
}
} else {
mgb_throw(
MegBrainError, "Unsupport format conversion from %s to %s",
......@@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation(
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto format = create_tensor->format();
// TODO: add dimshuffle for nhwc format
return {wrap_output(imperative::apply(op, inputs)[0], format)};
if (format == FT::NHWC) {
auto output = wrap_output(imperative::apply(op, inputs)[0]);
output = to(output.cast(m_value_type), FT::NHWC, "");
return {output};
} else {
return {wrap_output(imperative::apply(op, inputs)[0], format)};
}
} else if (auto* get_attr = op.as<GetAttr>()) {
auto&& input = inputs.item();
if (!input.is(m_value_type)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册