提交 8dbc602e 编写于 作者: M Megvii Engine Team

fix(imperative): fix ConvBwdData layout issue

GitOrigin-RevId: e3657f8829adf2288e9822d7dfd7e281961bb081
上级 cafc3be9
......@@ -7,12 +7,14 @@ import numpy as np
import pytest
from utils import opr_test
import megengine as mge
import megengine.amp as amp
import megengine.config as config
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
import megengine.jit as jit
import megengine.module as M
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import use_symbolic_shape
......@@ -1637,6 +1639,16 @@ def test_conv_transpose2d():
output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32)
)
@mge.jit.trace()
def func():
deconv = M.ConvTranspose2d(16, 33, (3, 5), (2, 3), (3, 4))
x = Tensor(np.random.rand(20, 16, 50, 100))
for i in range(20):
y = deconv(x._broadcast(F.concat([x.shape, x.shape])[:4]))
mge._sync()
func()
def test_conv_transpose3d():
m = ConvTranspose3d(
......
......@@ -174,13 +174,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if (filter.ndim && diff.ndim) {
// deduce_layout won't override existing dtype
dnn_opr.opr().deduce_layout(filter, diff, output_layout);
if (inputs.size() == 3) {
if (!inputs[2].value.empty()) {
cg::copy_tensor_value_to_shape(output_layout, inputs[2].value);
output_layout.init_contiguous_stride();
} else {
output_layout.ndim = 0;
}
} else {
dnn_opr.opr().deduce_dtype(filter.dtype, diff.dtype, output_layout.dtype);
}
if (inputs.size() == 3) {
if (!inputs[2].value.empty()) {
cg::copy_tensor_value_to_shape(output_layout, inputs[2].value);
output_layout.init_contiguous_stride();
} else {
output_layout.ndim = 0;
}
}
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
......@@ -202,8 +204,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
if (inputs.size() == 3) {
cg::copy_tensor_value_to_shape(
out_layout, inputs[2]->get_value().proxy_to_default_cpu());
out_layout.init_contiguous_stride();
} else {
dnn_opr.op()->deduce_layout(
inputs[0]->layout(), inputs[1]->layout(), out_layout);
}
out_layout.init_contiguous_stride();
return out_layout;
}
}();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册