From 9d535d7ac34cbb9feba71e418d7a84198e65306a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 5 Jul 2023 20:11:28 +0800 Subject: [PATCH] feat(xla): improve lower rule GitOrigin-RevId: 55d43fe0f3666ef233612505a925662161b50bec --- .../python/megengine/jit/xla_backend.py | 8 +-- imperative/python/megengine/xla/rules/nn.py | 59 ++++++++++------ .../python/megengine/xla/rules/reduction.py | 4 +- .../test/unit/xla/functional/test_xla_nn.py | 68 +++++++++++++++++++ 4 files changed, 114 insertions(+), 25 deletions(-) diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py index a41427eaa..1c7a8c0ea 100644 --- a/imperative/python/megengine/jit/xla_backend.py +++ b/imperative/python/megengine/jit/xla_backend.py @@ -91,13 +91,13 @@ class xla_trace(trace): set_use_xla_backend(self.orig_use_xla) def convert_params_to_xla(self): - from ..device import coalesce_free_memory from ..utils.module_utils import get_expand_structure from ..tensor import Tensor backend = self.xla_exec.backend devices = backend.local_devices() - _, device_id, _ = CompNode(get_default_device()).physical_locator + default_cn = CompNode(get_default_device()) + _, device_id, _ = default_cn.physical_locator device_index = ( 0 if len(devices) == 0 else [d.id for d in devices].index(device_id) ) @@ -114,7 +114,7 @@ class xla_trace(trace): if np_array.shape == (): np_array = np_array[np.newaxis] xla_array = backend.buffer_from_pyval(np_array, device) - tensor._reset(Tensor(xla_array)) + tensor._reset(Tensor(xla_array, device=default_cn)) for attr, _ in self.attr_to_key.items(): param = get_expand_structure(attr[0], attr[1]) @@ -232,7 +232,7 @@ class xla_trace(trace): return_vals.append(outputs[self.outkey2idx[i]]) keeped_features = [] for i in self.keeped_activation: - keeped_features.append(outputs[self.outkey2idx[i]]) + keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn)) out_tensors = [] for array in return_vals: if array is not None: diff --git a/imperative/python/megengine/xla/rules/nn.py b/imperative/python/megengine/xla/rules/nn.py index 69063bbb3..909041e64 100644 --- a/imperative/python/megengine/xla/rules/nn.py +++ b/imperative/python/megengine/xla/rules/nn.py @@ -49,15 +49,16 @@ def convolution_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): if opr.sparse == mops.BatchConvBias.Sparse.DENSE: feature_group_count, batch_group_count = 1, 1 else: - assert ic == oc, "dwconv only support ic == oc" assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 + feature_group_count, batch_group_count = weight.shape[0], 1 if opr.format == mops.AdaptivePooling.Format.NCHW: - assert ( - weight.shape[1] == 1 and weight.shape[2] == 1 - ), f"weight shape error: {weight.shape}" - xla_weight_shape = [weight.shape[i] for i in [0, 2, 3, 4]] + xla_weight_shape = xla_weight_shape = [ + weight.shape[0] * weight.shape[1], + weight.shape[2], + weight.shape[3], + weight.shape[4], + ] weight = reshape(weight, xla_weight_shape) feature_group_count = ir_utils.i64_attr(feature_group_count) @@ -159,14 +160,16 @@ def _conv_general_vjp_rhs_padding( return list(zip(pads_lo, pads_hi)) -@register_lower_rule("ConvolutionBackwardDataV2") +@register_lower_rule("ConvolutionBackwardDataV2", mops.ConvolutionBackwardData) def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3 assert ( ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1 ), "dilate_conv is not support now" - weight, dout, inp = args[0], args[1], args[2] + if len(args) == 3: + weight, dout, inp = args[0], args[1], args[2] + else: + weight, dout, inp = args[0], args[1], None if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW: dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3)) inp_spec, weight_spec, out_spec = dnums @@ -177,8 +180,8 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): ph, pw = ctx.param["pad_h"], ctx.param["pad_w"] padding = ((ph, ph), (pw, pw)) weight_shape = weight.shape - inp_shape = inp.shape - ic = inp.shape[1] # NCHW + inp_shape = inp.shape if inp else ctx.vars_out[0].shape + ic = inp_shape[1] # NCHW oc = weight.shape[0] # OIHW or O11HW for dwconv t_weight_spec = (weight_spec[1], weight_spec[0]) + weight_spec[2:] dnums = hlo.ConvDimensionNumbers.get( @@ -196,11 +199,23 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE: feature_group_count, batch_group_count = 1, 1 else: - assert ic == oc, "only support dpwise conv currently" - assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 - weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]] + weight_shape = weight.shape + assert len(weight_shape) == 5, "mge dpconv weight dim is 5" + feature_group_count, batch_group_count = weight.shape[0], 1 + weight_shape = [ + weight.shape[1], + weight.shape[0] * weight.shape[2], + weight.shape[3], + weight.shape[4], + ] + weight = weight.transpose((1, 0, 2, 3, 4)) weight = weight.reshape(weight_shape) + weight_shape = [ + weight_shape[1], + weight_shape[0], + weight_shape[2], + weight_shape[3], + ] padding = _conv_general_vjp_lhs_padding( np.take(inp_shape, inp_hw), @@ -262,11 +277,15 @@ def conv_backward_filter_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]] if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE: feature_group_count, batch_group_count = 1, 1 else: - assert ic == oc, "only support dpwise conv currently" - assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 - weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]] - + weight_shape = weight.shape + assert len(weight_shape) == 5, "mge dpconv weight dim is 5" + feature_group_count, batch_group_count = weight.shape[0], 1 + weight_shape = [ + weight_shape[2], + weight_shape[0] * weight_shape[1], + weight_shape[3], + weight_shape[4], + ] if batch_group_count > 1: feature_group_count = batch_group_count batch_group_count = 1 diff --git a/imperative/python/megengine/xla/rules/reduction.py b/imperative/python/megengine/xla/rules/reduction.py index 5246a605d..05f4022a4 100644 --- a/imperative/python/megengine/xla/rules/reduction.py +++ b/imperative/python/megengine/xla/rules/reduction.py @@ -138,7 +138,9 @@ def reduce_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): else: assert len(args) == 2 src_shape = args[0].shape - tgt_shape = list(ctx.module_context.get_value(ctx.vars_in[1])) + if src_shape == ctx.vars_out[0].shape: + return args[0] + tgt_shape = list(ctx.vars_out[0].shape) tgt_shape = [1,] * (len(src_shape) - len(tgt_shape)) + tgt_shape src_idx, tgt_idx, axes = 0, 0, [] while src_idx < len(src_shape) and tgt_idx < len(tgt_shape): diff --git a/imperative/python/test/unit/xla/functional/test_xla_nn.py b/imperative/python/test/unit/xla/functional/test_xla_nn.py index b12edb5e1..0188d82cb 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_nn.py +++ b/imperative/python/test/unit/xla/functional/test_xla_nn.py @@ -93,6 +93,74 @@ def test_conv2d(): padding=(2, 1), groups=16, ) + tester( + (4, 16, 24, 24), + (4, 4, 4, 1, 1), + (1, 16, 1, 1), + stride=(2, 3), + padding=(2, 1), + groups=4, + ) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_conv_transpose2d(): + np.random.seed(123) + mge.random.seed(123) + + def tester(x_shape, w_shape, b_shape, stride, padding, groups, dtype=None): + dtype = dtype or np.float32 + x = tensor(0.1 * np.random.rand(*x_shape), dtype=dtype) + w = tensor(0.1 * np.random.rand(*w_shape), dtype=dtype) + b = tensor(0.1 * np.random.rand(*b_shape), dtype=dtype) if b_shape else None + y = F.conv_transpose2d(x, w, b, stride=stride, padding=padding, groups=groups) + dy = tensor(0.1 * np.random.rand(*y.shape), dtype=dtype) + + gm = GradManager() + + if b is not None: + + @jit.xla_trace(without_host=True) + def func(x, w, b, dy): + gm.attach([x, w, b]) + with gm: + y = F.conv_transpose2d( + x, w, b, stride=stride, padding=padding, groups=groups + ) + gm.backward(y, dy) + return [y, x.grad, w.grad, b.grad] + + mge_rsts = func(x, w, b, dy) + xla_rsts = func(x, w, b, dy) + else: + + @jit.xla_trace(without_host=True) + def func(x, w, dy): + gm.attach([x, w]) + with gm: + y = F.conv2d(x, w, stride=stride, padding=padding, groups=groups) + gm.backward(y, dy) + return [y, x.grad, w.grad] + + mge_rsts = func(x, w, dy) + xla_rsts = func(x, w, dy) + + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-4) + + tester( + (4, 16, 24, 24), (16, 32, 3, 3), (1, 32, 1, 1), stride=1, padding=1, groups=1 + ) + tester( + (4, 16, 24, 24), + (16, 32, 3, 3), + (1, 32, 1, 1), + stride=(2, 3), + padding=(2, 1), + groups=1, + ) @pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") -- GitLab