Commit 9d535d7a authored by Megvii Engine Team

feat(xla): improve lower rule

GitOrigin-RevId: 55d43fe0f3666ef233612505a925662161b50bec
Parent d8917c22
@@ -91,13 +91,13 @@ class xla_trace(trace):
set_use_xla_backend(self.orig_use_xla)
def convert_params_to_xla(self):
from ..device import coalesce_free_memory
from ..utils.module_utils import get_expand_structure
from ..tensor import Tensor
backend = self.xla_exec.backend
devices = backend.local_devices()
_, device_id, _ = CompNode(get_default_device()).physical_locator
default_cn = CompNode(get_default_device())
_, device_id, _ = default_cn.physical_locator
device_index = (
0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
)
@@ -114,7 +114,7 @@ class xla_trace(trace):
if np_array.shape == ():
np_array = np_array[np.newaxis]
xla_array = backend.buffer_from_pyval(np_array, device)
tensor._reset(Tensor(xla_array))
tensor._reset(Tensor(xla_array, device=default_cn))
for attr, _ in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
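For orientation, a minimal sketch of the per-parameter conversion the two hunks above adjust; `backend`, `device`, `default_cn`, and the `Tensor` class are taken as parameters here (the first hunk shows how they are obtained), and the change is that the wrapping `Tensor` is now pinned to `device=default_cn`:

import numpy as np

def convert_param_sketch(param, backend, device, default_cn, Tensor):
    # Promote scalars to 1-d arrays before building the XLA buffer.
    np_array = param.numpy()
    if np_array.shape == ():
        np_array = np_array[np.newaxis]
    xla_array = backend.buffer_from_pyval(np_array, device)
    # The fix: create the wrapping MegEngine tensor on the default comp node.
    param._reset(Tensor(xla_array, device=default_cn))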
@@ -232,7 +232,7 @@ class xla_trace(trace):
return_vals.append(outputs[self.outkey2idx[i]])
keeped_features = []
for i in self.keeped_activation:
keeped_features.append(outputs[self.outkey2idx[i]])
keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn))
out_tensors = []
for array in return_vals:
if array is not None:
......
@@ -49,15 +49,16 @@ def convolution_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
if opr.sparse == mops.BatchConvBias.Sparse.DENSE:
feature_group_count, batch_group_count = 1, 1
else:
assert ic == oc, "dwconv only support ic == oc"
assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
feature_group_count, batch_group_count = ic, 1
feature_group_count, batch_group_count = weight.shape[0], 1
if opr.format == mops.AdaptivePooling.Format.NCHW:
assert (
weight.shape[1] == 1 and weight.shape[2] == 1
), f"weight shape error: {weight.shape}"
xla_weight_shape = [weight.shape[i] for i in [0, 2, 3, 4]]
xla_weight_shape = [
weight.shape[0] * weight.shape[1],
weight.shape[2],
weight.shape[3],
weight.shape[4],
]
weight = reshape(weight, xla_weight_shape)
feature_group_count = ir_utils.i64_attr(feature_group_count)
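As a hedged aside on what the grouped branch above computes: MegEngine stores a grouped-convolution kernel as a 5-D tensor (G, O//G, I//G, kh, kw), while the XLA convolution takes a 4-D kernel (O, I//G, kh, kw) together with feature_group_count = G, so the first two dimensions are folded. A small numpy check using the groups=4 case added to test_conv2d below:

import numpy as np

# (G, O//G, I//G, kh, kw) -> (G*(O//G), I//G, kh, kw), with feature_group_count = G
w = np.zeros((4, 4, 4, 1, 1), dtype=np.float32)  # grouped kernel from the new test, groups=4
feature_group_count = w.shape[0]
folded = w.reshape(w.shape[0] * w.shape[1], w.shape[2], w.shape[3], w.shape[4])
assert folded.shape == (16, 4, 1, 1) and feature_group_count == 4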
@@ -159,14 +160,16 @@ def _conv_general_vjp_rhs_padding(
return list(zip(pads_lo, pads_hi))
@register_lower_rule("ConvolutionBackwardDataV2")
@register_lower_rule("ConvolutionBackwardDataV2", mops.ConvolutionBackwardData)
def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3
assert (
ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1
), "dilate_conv is not support now"
weight, dout, inp = args[0], args[1], args[2]
if len(args) == 3:
weight, dout, inp = args[0], args[1], args[2]
else:
weight, dout, inp = args[0], args[1], None
if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW:
dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))
inp_spec, weight_spec, out_spec = dnums
@@ -177,8 +180,8 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
ph, pw = ctx.param["pad_h"], ctx.param["pad_w"]
padding = ((ph, ph), (pw, pw))
weight_shape = weight.shape
inp_shape = inp.shape
ic = inp.shape[1] # NCHW
inp_shape = inp.shape if inp is not None else ctx.vars_out[0].shape
ic = inp_shape[1] # NCHW
oc = weight.shape[0] # OIHW or O11HW for dwconv
t_weight_spec = (weight_spec[1], weight_spec[0]) + weight_spec[2:]
dnums = hlo.ConvDimensionNumbers.get(
@@ -196,11 +199,23 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
feature_group_count, batch_group_count = 1, 1
else:
assert ic == oc, "only support dpwise conv currently"
assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
feature_group_count, batch_group_count = ic, 1
weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
weight_shape = weight.shape
assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
feature_group_count, batch_group_count = weight.shape[0], 1
weight_shape = [
weight.shape[1],
weight.shape[0] * weight.shape[2],
weight.shape[3],
weight.shape[4],
]
weight = weight.transpose((1, 0, 2, 3, 4))
weight = weight.reshape(weight_shape)
weight_shape = [
weight_shape[1],
weight_shape[0],
weight_shape[2],
weight_shape[3],
]
padding = _conv_general_vjp_lhs_padding(
np.take(inp_shape, inp_hw),
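A hedged numeric walk-through of the grouped-weight rearrangement above; unlike the forward rule, the backward-data rule transposes before folding, turning (G, O//G, I//G, kh, kw) into (O//G, G, I//G, kh, kw) and then into (O//G, G*(I//G), kh, kw). With the same groups=4 kernel:

import numpy as np

w = np.zeros((4, 4, 4, 1, 1), dtype=np.float32)  # (G, O//G, I//G, kh, kw)
feature_group_count = w.shape[0]                 # G = 4
w_t = w.transpose((1, 0, 2, 3, 4))               # (O//G, G, I//G, kh, kw)
w_4d = w_t.reshape(w.shape[1], w.shape[0] * w.shape[2], w.shape[3], w.shape[4])
assert w_4d.shape == (4, 16, 1, 1) and feature_group_count == 4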
@@ -262,11 +277,15 @@ def conv_backward_filter_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]
if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
feature_group_count, batch_group_count = 1, 1
else:
assert ic == oc, "only support dpwise conv currently"
assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
feature_group_count, batch_group_count = ic, 1
weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
weight_shape = weight.shape
assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
feature_group_count, batch_group_count = weight.shape[0], 1
weight_shape = [
weight_shape[2],
weight_shape[0] * weight_shape[1],
weight_shape[3],
weight_shape[4],
]
if batch_group_count > 1:
feature_group_count = batch_group_count
batch_group_count = 1
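And the corresponding hedged check for the backward-filter rule above, which only derives the 4-D target shape (I//G, G*(O//G), kh, kw) for the filter gradient from the 5-D grouped weight:

w_shape = (4, 4, 4, 1, 1)                 # (G, O//G, I//G, kh, kw), groups=4
feature_group_count = w_shape[0]          # 4
target = [w_shape[2], w_shape[0] * w_shape[1], w_shape[3], w_shape[4]]
assert target == [4, 16, 1, 1]            # (I//G, O, kh, kw)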
......
@@ -138,7 +138,9 @@ def reduce_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
else:
assert len(args) == 2
src_shape = args[0].shape
tgt_shape = list(ctx.module_context.get_value(ctx.vars_in[1]))
if src_shape == ctx.vars_out[0].shape:
return args[0]
tgt_shape = list(ctx.vars_out[0].shape)
tgt_shape = [1,] * (len(src_shape) - len(tgt_shape)) + tgt_shape
src_idx, tgt_idx, axes = 0, 0, []
while src_idx < len(src_shape) and tgt_idx < len(tgt_shape):
......
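A hedged sketch of where the reduce_lower change above leads; the matching loop itself is elided here, so the axis recovery below is only one plausible reading: left-pad the target shape with 1s to the source rank, then treat every position where the padded target is 1 but the source dim is not as a reduced axis.

def reduced_axes_sketch(src_shape, tgt_shape):
    # Mirrors the new early return: nothing to reduce when the shapes already match.
    if tuple(src_shape) == tuple(tgt_shape):
        return []
    tgt = [1] * (len(src_shape) - len(tgt_shape)) + list(tgt_shape)
    return [i for i, (s, t) in enumerate(zip(src_shape, tgt)) if t == 1 and s != 1]

assert reduced_axes_sketch((4, 16, 24, 24), (4, 16, 1, 1)) == [2, 3]
assert reduced_axes_sketch((4, 16, 24, 24), (4, 16, 24, 24)) == []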
@@ -93,6 +93,74 @@ def test_conv2d():
padding=(2, 1),
groups=16,
)
tester(
(4, 16, 24, 24),
(4, 4, 4, 1, 1),
(1, 16, 1, 1),
stride=(2, 3),
padding=(2, 1),
groups=4,
)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_conv_transpose2d():
np.random.seed(123)
mge.random.seed(123)
def tester(x_shape, w_shape, b_shape, stride, padding, groups, dtype=None):
dtype = dtype or np.float32
x = tensor(0.1 * np.random.rand(*x_shape), dtype=dtype)
w = tensor(0.1 * np.random.rand(*w_shape), dtype=dtype)
b = tensor(0.1 * np.random.rand(*b_shape), dtype=dtype) if b_shape else None
y = F.conv_transpose2d(x, w, b, stride=stride, padding=padding, groups=groups)
dy = tensor(0.1 * np.random.rand(*y.shape), dtype=dtype)
gm = GradManager()
if b is not None:
@jit.xla_trace(without_host=True)
def func(x, w, b, dy):
gm.attach([x, w, b])
with gm:
y = F.conv_transpose2d(
x, w, b, stride=stride, padding=padding, groups=groups
)
gm.backward(y, dy)
return [y, x.grad, w.grad, b.grad]
mge_rsts = func(x, w, b, dy)
xla_rsts = func(x, w, b, dy)
else:
@jit.xla_trace(without_host=True)
def func(x, w, dy):
gm.attach([x, w])
with gm:
y = F.conv2d(x, w, stride=stride, padding=padding, groups=groups)
gm.backward(y, dy)
return [y, x.grad, w.grad]
mge_rsts = func(x, w, dy)
xla_rsts = func(x, w, dy)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-4)
tester(
(4, 16, 24, 24), (16, 32, 3, 3), (1, 32, 1, 1), stride=1, padding=1, groups=1
)
tester(
(4, 16, 24, 24),
(16, 32, 3, 3),
(1, 32, 1, 1),
stride=(2, 3),
padding=(2, 1),
groups=1,
)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
......