提交 8dbc602e 作者: Megvii Engine Team

fix(imperative): fix ConvBwdData layout issue

GitOrigin-RevId: e3657f8829adf2288e9822d7dfd7e281961bb081
上级 cafc3be9
...@@ -7,12 +7,14 @@ import numpy as np ...@@ -7,12 +7,14 @@ import numpy as np
import pytest import pytest
from utils import opr_test from utils import opr_test
import megengine as mge
import megengine.amp as amp import megengine.amp as amp
import megengine.config as config import megengine.config as config
import megengine.core.ops.builtin as builtin import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype import megengine.core.tensor.dtype as dtype
import megengine.functional as F import megengine.functional as F
import megengine.jit as jit import megengine.jit as jit
import megengine.module as M
from megengine import Parameter, Tensor, is_cuda_available, tensor from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
...@@ -1637,6 +1639,16 @@ def test_conv_transpose2d(): ...@@ -1637,6 +1639,16 @@ def test_conv_transpose2d():
output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32) output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32)
) )
@mge.jit.trace()
def func():
deconv = M.ConvTranspose2d(16, 33, (3, 5), (2, 3), (3, 4))
x = Tensor(np.random.rand(20, 16, 50, 100))
for i in range(20):
y = deconv(x._broadcast(F.concat([x.shape, x.shape])[:4]))
mge._sync()
func()
def test_conv_transpose3d(): def test_conv_transpose3d():
m = ConvTranspose3d( m = ConvTranspose3d(
......
...@@ -174,13 +174,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -174,13 +174,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if (filter.ndim && diff.ndim) { if (filter.ndim && diff.ndim) {
// deduce_layout won't override existing dtype // deduce_layout won't override existing dtype
dnn_opr.opr().deduce_layout(filter, diff, output_layout); dnn_opr.opr().deduce_layout(filter, diff, output_layout);
if (inputs.size() == 3) { } else {
if (!inputs[2].value.empty()) { dnn_opr.opr().deduce_dtype(filter.dtype, diff.dtype, output_layout.dtype);
cg::copy_tensor_value_to_shape(output_layout, inputs[2].value); }
output_layout.init_contiguous_stride(); if (inputs.size() == 3) {
} else { if (!inputs[2].value.empty()) {
output_layout.ndim = 0; cg::copy_tensor_value_to_shape(output_layout, inputs[2].value);
} output_layout.init_contiguous_stride();
} else {
output_layout.ndim = 0;
} }
} }
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
...@@ -202,8 +204,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -202,8 +204,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
if (inputs.size() == 3) { if (inputs.size() == 3) {
cg::copy_tensor_value_to_shape( cg::copy_tensor_value_to_shape(
out_layout, inputs[2]->get_value().proxy_to_default_cpu()); out_layout, inputs[2]->get_value().proxy_to_default_cpu());
out_layout.init_contiguous_stride(); } else {
dnn_opr.op()->deduce_layout(
inputs[0]->layout(), inputs[1]->layout(), out_layout);
} }
out_layout.init_contiguous_stride();
return out_layout; return out_layout;
} }
}(); }();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册