Commit c1c1d6d1 authored by Megvii Engine Team

feat(imperative): region restricted conv support bias in python

GitOrigin-RevId: 9a2c1ee27a0ca576f98c072d2854ebe59a2ff5ce
Parent f287501e
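
In short: the functional API gains an optional `bias` tensor, and the `RegionRestrictedConv` module gains a `bias` flag defaulting to `True`. A minimal usage sketch of the functional form, assuming the mask shapes `(N, IH, IW)` / `(N, OH, OW)` and the `(1, out_channels, 1, 1)` bias layout used by the tests below (shapes are illustrative, not a documented contract):

import numpy as np
import megengine.functional as F
from megengine import tensor

src = tensor(np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2))  # N, C, H, W
weight = tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))  # (groups, OC//g, IC//g, KH, KW)
rin = tensor(np.ones((1, 2, 2), dtype=np.int32))  # input region mask, assumed (N, IH, IW)
rout = tensor(np.ones((1, 1, 1), dtype=np.int32))  # output region mask, assumed (N, OH, OW)
bias = tensor(np.ones((1, 2, 1, 1), dtype=np.float32))  # per-channel bias, broadcastable to output

out = F.region_restricted_conv(src, weight, rin, rout, bias=bias, groups=2)
print(out.numpy().shape)  # (1, 2, 1, 1)
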
@@ -1980,6 +1980,7 @@ def region_restricted_conv(
weight: Tensor,
rin: Tensor,
rout: Tensor,
bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
@@ -1994,6 +1995,9 @@ def region_restricted_conv(
Args:
inp: feature map of the convolution operation.
weight: convolution kernel.
rin: input region mask.
rout: output region mask.
bias: bias added to the result of convolution (if given).
stride: stride of the 2D region restricted convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
@@ -2027,6 +2031,8 @@ def region_restricted_conv(
sparse=sparse_type,
)
(output,) = apply(op, inp, weight, rin, rout)
if bias is not None:
output += bias
return output
......
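
The functional path applies the bias as a plain elementwise add after the convolution (`output += bias` in the hunk above), so the only shape requirement is ordinary NumPy-style broadcasting against the NCHW result; a mismatch surfaces at runtime, not at op construction. A minimal sketch of that contract, with hypothetical shapes:

import numpy as np

output = np.zeros((1, 2, 1, 1), dtype=np.float32)  # conv result: N, C, OH, OW
bias = np.ones((1, 2, 1, 1), dtype=np.float32)  # per-channel bias broadcasts over N, OH, OW
output += bias  # same semantics as the `output += bias` in the diff above
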
@@ -1040,6 +1040,7 @@ class RegionRestrictedConv(_ConvNd):
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups,
in_channels // groups, height, width)``. Default: 1
bias: whether to add a bias onto the result of convolution. Default: True
conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
@@ -1071,6 +1072,7 @@ class RegionRestrictedConv(_ConvNd):
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
groups: int,
bias: bool = True,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
@@ -1095,7 +1097,7 @@ class RegionRestrictedConv(_ConvNd):
0,
dilation,
groups,
False,
bias,
**kwargs,
)
@@ -1133,7 +1135,7 @@ class RegionRestrictedConv(_ConvNd):
(self.padding[1], self.padding[1]),
)
def calc_conv(self, inp, weight, rin, rout):
def calc_conv(self, inp, weight, rin, rout, bias):
assert self.padding_mode in [
"zeros",
"reflect",
@@ -1144,6 +1146,7 @@ class RegionRestrictedConv(_ConvNd):
weight,
rin,
rout,
bias,
self.stride,
self.padding,
self.dilation,
@@ -1153,4 +1156,4 @@ class RegionRestrictedConv(_ConvNd):
)
def forward(self, inp, rin, rout):
return self.calc_conv(inp, self.weight, rin, rout)
return self.calc_conv(inp, self.weight, rin, rout, self.bias)
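
On the module side, the constructor's `bias` flag (previously hard-coded to `False`) is forwarded to `_ConvNd`, and `forward` threads `self.bias` through `calc_conv` into the functional call. A minimal module-level sketch, with the same assumed shapes as the functional example above:

import numpy as np
import megengine.module as M
from megengine import tensor

m = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2, bias=True)
src = tensor(np.ones((1, 2, 2, 2), dtype=np.float32))
rin = tensor(np.ones((1, 2, 2), dtype=np.int32))
rout = tensor(np.ones((1, 1, 1), dtype=np.int32))
out = m(src, rin, rout)  # self.bias is added inside calc_conv when bias=True
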
@@ -930,7 +930,8 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
def test_region_restricted_conv_forward_backward_naive():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_naive(bias):
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager
@@ -943,15 +944,22 @@ def test_region_restricted_conv_forward_backward_naive():
cpu_src = tensor(src_1, device=handle)
cpu_filter = tensor(filter_1, device=handle)
gm = GradManager().attach([cpu_src, cpu_filter])
cpu_bias = (
tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device=handle) if bias else None
)
with gm:
cpu_out = F.region_restricted_conv(
cpu_src,
cpu_filter,
tensor(rin_1, device=handle),
tensor(rout_1, device=handle),
bias=cpu_bias,
groups=2,
)
gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
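# the reference output [14, 126] below was computed without bias, so subtract it first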
if cpu_bias is not None:
cpu_out = cpu_out - cpu_bias
np.testing.assert_allclose(cpu_out, np.array([14, 126]).reshape(1, 2, 1, 1))
np.testing.assert_allclose(
cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
)
@@ -963,7 +971,8 @@ def test_region_restricted_conv_forward_backward_naive():
@pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
)
def test_region_restricted_conv_forward_backward_cuda():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_cuda(bias):
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager
@@ -998,18 +1007,23 @@ def test_region_restricted_conv_forward_backward_cuda():
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
if bias
else None
)
gm = GradManager().attach([src, filter])
with gm:
expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
src, filter, rin, rout, bias=bias_cpu, groups=GROUP
)
gm.backward(
expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
)
return src, filter
return src, filter, expected_out
expected_src, expected_filter = get_groundtruth()
expected_src, expected_filter, expected_out = get_groundtruth()
src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
@@ -1018,18 +1032,25 @@ def test_region_restricted_conv_forward_backward_cuda():
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
)
gm = GradManager().attach([src, filter])
with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gpu_out = F.region_restricted_conv(
src, filter, rin, rout, bias=bias_gpu, groups=GROUP
)
gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad)
np.testing.assert_allclose(gpu_out, expected_out)
@pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
)
def test_region_restricted_conv_forward_backward_uint8():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_uint8(bias):
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager
@@ -1063,18 +1084,23 @@ def test_region_restricted_conv_forward_backward_uint8():
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
if bias
else None
)
gm = GradManager().attach([src, filter])
with gm:
expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
src, filter, rin, rout, bias=bias_cpu, groups=GROUP
)
gm.backward(
expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
)
return src, filter
return src, filter, expected_out
expected_src, expected_filter = get_groundtruth()
expected_src, expected_filter, expected_out = get_groundtruth()
# forward and dgrad/wgrad
src = tensor(
@@ -1084,23 +1110,22 @@ def test_region_restricted_conv_forward_backward_uint8():
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
)
gm = GradManager().attach([src, filter])
with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gpu_out = F.region_restricted_conv(
src, filter, rin, rout, bias=bias_gpu, groups=GROUP
)
gm.backward(
gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
)
# assert uint8 gpu result close to cpu result
np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad)
def test_region_restricted_conv():
test_region_restricted_conv_forward_backward_naive()
if is_cuda_available():
test_region_restricted_conv_forward_backward_cuda()
test_region_restricted_conv_forward_backward_uint8()
np.testing.assert_allclose(gpu_out, expected_out)
def test_conv2d_autocast():
......