diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py
index a41427eaa0c35acb8fd39cd8752eedeb5e94b2dd..1c7a8c0eaf565bcf35f4a4e9e9ddd383647e18c9 100644
--- a/imperative/python/megengine/jit/xla_backend.py
+++ b/imperative/python/megengine/jit/xla_backend.py
@@ -91,13 +91,13 @@ class xla_trace(trace):
         set_use_xla_backend(self.orig_use_xla)
 
     def convert_params_to_xla(self):
-        from ..device import coalesce_free_memory
         from ..utils.module_utils import get_expand_structure
         from ..tensor import Tensor
 
         backend = self.xla_exec.backend
         devices = backend.local_devices()
-        _, device_id, _ = CompNode(get_default_device()).physical_locator
+        default_cn = CompNode(get_default_device())
+        _, device_id, _ = default_cn.physical_locator
         device_index = (
             0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
         )
@@ -114,7 +114,7 @@ class xla_trace(trace):
             if np_array.shape == ():
                 np_array = np_array[np.newaxis]
             xla_array = backend.buffer_from_pyval(np_array, device)
-            tensor._reset(Tensor(xla_array))
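+            # keep the rebuilt parameter tensor on the original default device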
+            tensor._reset(Tensor(xla_array, device=default_cn))
 
         for attr, _ in self.attr_to_key.items():
             param = get_expand_structure(attr[0], attr[1])
@@ -232,7 +232,7 @@ class xla_trace(trace):
                 return_vals.append(outputs[self.outkey2idx[i]])
         keeped_features = []
         for i in self.keeped_activation:
-            keeped_features.append(outputs[self.outkey2idx[i]])
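+            # wrap kept activations as tensors on the trace device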
+            keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn))
         out_tensors = []
         for array in return_vals:
             if array is not None:
diff --git a/imperative/python/megengine/xla/rules/nn.py b/imperative/python/megengine/xla/rules/nn.py
index 69063bbb364d065368e09c27840ec9a48b564f61..909041e6444f6ff153882e55e2955d68d45cf3ad 100644
--- a/imperative/python/megengine/xla/rules/nn.py
+++ b/imperative/python/megengine/xla/rules/nn.py
@@ -49,15 +49,16 @@ def convolution_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     if opr.sparse == mops.BatchConvBias.Sparse.DENSE:
         feature_group_count, batch_group_count = 1, 1
     else:
-        assert ic == oc, "dwconv only support ic == oc"
         assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-        feature_group_count, batch_group_count = ic, 1
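+        # grouped conv: the 5-d weight is (groups, ocpg, icpg, kh, kw), so groups = weight.shape[0]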
+        feature_group_count, batch_group_count = weight.shape[0], 1
 
         if opr.format == mops.AdaptivePooling.Format.NCHW:
-            assert (
-                weight.shape[1] == 1 and weight.shape[2] == 1
-            ), f"weight shape error: {weight.shape}"
-            xla_weight_shape = [weight.shape[i] for i in [0, 2, 3, 4]]
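+            # collapse (groups, ocpg) into the output-channel dim of the 4-d XLA filter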
+            xla_weight_shape = [
+                weight.shape[0] * weight.shape[1],
+                weight.shape[2],
+                weight.shape[3],
+                weight.shape[4],
+            ]
         weight = reshape(weight, xla_weight_shape)
 
     feature_group_count = ir_utils.i64_attr(feature_group_count)
@@ -159,14 +160,16 @@ def _conv_general_vjp_rhs_padding(
     return list(zip(pads_lo, pads_hi))
 
 
-@register_lower_rule("ConvolutionBackwardDataV2")
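+# register the rule for both the opr type name and the mops class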
+@register_lower_rule("ConvolutionBackwardDataV2", mops.ConvolutionBackwardData)
 def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
-    assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3
     assert (
         ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1
     ), "dilate_conv is not support now"
 
-    weight, dout, inp = args[0], args[1], args[2]
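+    # inp may be omitted from args; its shape is then recovered from ctx.vars_out below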
+    if len(args) == 3:
+        weight, dout, inp = args[0], args[1], args[2]
+    else:
+        weight, dout, inp = args[0], args[1], None
     if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW:
         dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))
         inp_spec, weight_spec, out_spec = dnums
@@ -177,8 +180,8 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
         ph, pw = ctx.param["pad_h"], ctx.param["pad_w"]
         padding = ((ph, ph), (pw, pw))
         weight_shape = weight.shape
-        inp_shape = inp.shape
-        ic = inp.shape[1]  # NCHW
+        inp_shape = inp.shape if inp is not None else ctx.vars_out[0].shape
+        ic = inp_shape[1]  # NCHW
         oc = weight.shape[0]  # OIHW or O11HW for dwconv
         t_weight_spec = (weight_spec[1], weight_spec[0]) + weight_spec[2:]
         dnums = hlo.ConvDimensionNumbers.get(
@@ -196,11 +199,23 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
         if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
             feature_group_count, batch_group_count = 1, 1
         else:
-            assert ic == oc, "only support dpwise conv currently"
-            assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-            feature_group_count, batch_group_count = ic, 1
-            weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
+            weight_shape = weight.shape
+            assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
+            feature_group_count, batch_group_count = weight.shape[0], 1
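+            # turn the 5-d (groups, ocpg, icpg, kh, kw) weight into a 4-d (ocpg, groups * icpg, kh, kw) filter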
+            weight_shape = [
+                weight.shape[1],
+                weight.shape[0] * weight.shape[2],
+                weight.shape[3],
+                weight.shape[4],
+            ]
+            weight = weight.transpose((1, 0, 2, 3, 4))
             weight = weight.reshape(weight_shape)
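+            # weight_shape now records the swapped (groups * icpg, ocpg, kh, kw) view used below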
+            weight_shape = [
+                weight_shape[1],
+                weight_shape[0],
+                weight_shape[2],
+                weight_shape[3],
+            ]
 
         padding = _conv_general_vjp_lhs_padding(
             np.take(inp_shape, inp_hw),
@@ -262,11 +277,15 @@ def conv_backward_filter_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]
         if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
             feature_group_count, batch_group_count = 1, 1
         else:
-            assert ic == oc, "only support dpwise conv currently"
-            assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-            feature_group_count, batch_group_count = ic, 1
-            weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
-
+            weight_shape = weight.shape
+            assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
+            feature_group_count, batch_group_count = weight.shape[0], 1
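+            # the filter gradient is shaped as (icpg, groups * ocpg, kh, kw) for the grouped case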
+            weight_shape = [
+                weight_shape[2],
+                weight_shape[0] * weight_shape[1],
+                weight_shape[3],
+                weight_shape[4],
+            ]
         if batch_group_count > 1:
             feature_group_count = batch_group_count
             batch_group_count = 1
diff --git a/imperative/python/megengine/xla/rules/reduction.py b/imperative/python/megengine/xla/rules/reduction.py
index 5246a605d0f33a93c0ad4200604da122f99d6d56..05f4022a4310b859a917b6a6f8b5cb3fe7828adc 100644
--- a/imperative/python/megengine/xla/rules/reduction.py
+++ b/imperative/python/megengine/xla/rules/reduction.py
@@ -138,7 +138,9 @@ def reduce_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
     else:
         assert len(args) == 2
         src_shape = args[0].shape
-        tgt_shape = list(ctx.module_context.get_value(ctx.vars_in[1]))
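+        # reducing to the same shape is a no-op, so just forward the input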
+        if src_shape == ctx.vars_out[0].shape:
+            return args[0]
+        tgt_shape = list(ctx.vars_out[0].shape)
         tgt_shape = [1,] * (len(src_shape) - len(tgt_shape)) + tgt_shape
         src_idx, tgt_idx, axes = 0, 0, []
         while src_idx < len(src_shape) and tgt_idx < len(tgt_shape):
diff --git a/imperative/python/test/unit/xla/functional/test_xla_nn.py b/imperative/python/test/unit/xla/functional/test_xla_nn.py
index b12edb5e10d429d6fa13b641165a50656527756a..0188d82cb5df2584cf6f104689ed5307e2a6da6f 100644
--- a/imperative/python/test/unit/xla/functional/test_xla_nn.py
+++ b/imperative/python/test/unit/xla/functional/test_xla_nn.py
@@ -93,6 +93,74 @@ def test_conv2d():
         padding=(2, 1),
         groups=16,
     )
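+    # grouped conv case: the weight has the 5-d (groups, ocpg, icpg, kh, kw) layout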
+    tester(
+        (4, 16, 24, 24),
+        (4, 4, 4, 1, 1),
+        (1, 16, 1, 1),
+        stride=(2, 3),
+        padding=(2, 1),
+        groups=4,
+    )
+
+
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_conv_transpose2d():
+    np.random.seed(123)
+    mge.random.seed(123)
+
+    def tester(x_shape, w_shape, b_shape, stride, padding, groups, dtype=None):
+        dtype = dtype or np.float32
+        x = tensor(0.1 * np.random.rand(*x_shape), dtype=dtype)
+        w = tensor(0.1 * np.random.rand(*w_shape), dtype=dtype)
+        b = tensor(0.1 * np.random.rand(*b_shape), dtype=dtype) if b_shape else None
+        y = F.conv_transpose2d(x, w, b, stride=stride, padding=padding, groups=groups)
+        dy = tensor(0.1 * np.random.rand(*y.shape), dtype=dtype)
+
+        gm = GradManager()
+
+        if b is not None:
+
+            @jit.xla_trace(without_host=True)
+            def func(x, w, b, dy):
+                gm.attach([x, w, b])
+                with gm:
+                    y = F.conv_transpose2d(
+                        x, w, b, stride=stride, padding=padding, groups=groups
+                    )
+                    gm.backward(y, dy)
+                return [y, x.grad, w.grad, b.grad]
+
+            mge_rsts = func(x, w, b, dy)
+            xla_rsts = func(x, w, b, dy)
+        else:
+
+            @jit.xla_trace(without_host=True)
+            def func(x, w, dy):
+                gm.attach([x, w])
+                with gm:
+                    y = F.conv_transpose2d(x, w, stride=stride, padding=padding, groups=groups)
+                    gm.backward(y, dy)
+                return [y, x.grad, w.grad]
+
+            mge_rsts = func(x, w, dy)
+            xla_rsts = func(x, w, dy)
+
+        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
+            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-4)
+
+    tester(
+        (4, 16, 24, 24), (16, 32, 3, 3), (1, 32, 1, 1), stride=1, padding=1, groups=1
+    )
+    tester(
+        (4, 16, 24, 24),
+        (16, 32, 3, 3),
+        (1, 32, 1, 1),
+        stride=(2, 3),
+        padding=(2, 1),
+        groups=1,
+    )
 
 
 @pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")