提交 c9e56f49 编写于 作者: M Megvii Engine Team

feat(imperative/amp): add dimshuffle before creating nhwc tensor

GitOrigin-RevId: 4461f9a0d3d6428282fcd858c677263d9a08fa9f
上级 d57a0712
...@@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): ...@@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
if x.format != "nhwc": if x.format != "nhwc":
if inplace: if inplace:
# hostvalue should still be valid, so no d2h cost. # hostvalue should still be valid, so no d2h cost.
data = x.numpy().transpose(*pattern) data = x.numpy()
# reset will destroy existed backward grad # reset will destroy existed backward grad
x[...] = Tensor(data, format="nhwc") x[...] = Tensor(data, format="nhwc")
else: else:
......
...@@ -38,7 +38,7 @@ def test_basic(): ...@@ -38,7 +38,7 @@ def test_basic():
def _compare_nchw_nhwc(data, func, is_symbolic=None): def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data) x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") x2 = tensor(data, format="nhwc")
if is_symbolic is not None: if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic) func = trace(func, symbolic=is_symbolic)
out1 = func(x1) out1 = func(x1)
...@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic): ...@@ -247,8 +247,8 @@ def test_conv2d(is_symbolic):
if x.format == "nhwc": if x.format == "nhwc":
x = F.conv2d( x = F.conv2d(
x, x,
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"),
) )
assert x.format == "nhwc" assert x.format == "nhwc"
return x.numpy() return x.numpy()
...@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic): ...@@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic):
if x.format == "nhwc": if x.format == "nhwc":
x = F.conv2d( x = F.conv2d(
x, x,
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"),
groups=2, groups=2,
) )
assert x.format == "nhwc" assert x.format == "nhwc"
...@@ -286,10 +286,10 @@ def test_bn(is_symbolic): ...@@ -286,10 +286,10 @@ def test_bn(is_symbolic):
if x.format == "nhwc": if x.format == "nhwc":
oups = F.batch_norm( oups = F.batch_norm(
x.astype("float32"), x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
training=True, training=True,
inplace=False, inplace=False,
) )
......
...@@ -27,7 +27,11 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to( ...@@ -27,7 +27,11 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
pattern = {0, 3, 1, 2}; pattern = {0, 3, 1, 2};
} }
} else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) { } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
pattern = {0, 2, 3, 1}; if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
pattern = {0, 1, 3, 4, 2};
} else {
pattern = {0, 2, 3, 1};
}
} else { } else {
mgb_throw( mgb_throw(
MegBrainError, "Unsupport format conversion from %s to %s", MegBrainError, "Unsupport format conversion from %s to %s",
...@@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation(
} }
} else if (auto* create_tensor = op.as<CreateTensor>()) { } else if (auto* create_tensor = op.as<CreateTensor>()) {
auto format = create_tensor->format(); auto format = create_tensor->format();
// TODO: add dimshuffle for nhwc format if (format == FT::NHWC) {
return {wrap_output(imperative::apply(op, inputs)[0], format)}; auto output = wrap_output(imperative::apply(op, inputs)[0]);
output = to(output.cast(m_value_type), FT::NHWC, "");
return {output};
} else {
return {wrap_output(imperative::apply(op, inputs)[0], format)};
}
} else if (auto* get_attr = op.as<GetAttr>()) { } else if (auto* get_attr = op.as<GetAttr>()) {
auto&& input = inputs.item(); auto&& input = inputs.item();
if (!input.is(m_value_type)) { if (!input.is(m_value_type)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册