提交 3f3a256e 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix conv* dtype promotion

GitOrigin-RevId: 3f03790cfc2ecf2f2c05e1ea5a68be0bc0e84bb2
上级 536506c3
......@@ -9,7 +9,7 @@
# pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const
......@@ -157,6 +157,12 @@ def conv1d(
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
......@@ -234,6 +240,12 @@ def conv2d(
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -297,6 +309,12 @@ def conv3d(
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3D(
pad_d=pad[D],
......@@ -364,6 +382,12 @@ def conv_transpose2d(
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
......@@ -482,6 +506,12 @@ def local_conv2d(
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
op = builtin.GroupLocal(
stride_h=stride_h,
stride_w=stride_w,
......@@ -527,6 +557,12 @@ def conv_transpose3d(
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
op = builtin.Convolution3DBackwardData(
pad_d=pad[D],
pad_h=pad[H],
......
......@@ -939,7 +939,7 @@ class ConvTranspose3d(_ConvNd):
ichl = self.in_channels
ochl = self.out_channels
kt, kh, kw = self.kernel_size
return (ochl, ichl, kt, kh, kw)
return (ichl, ochl, kt, kh, kw)
def _infer_bias_shape(self):
# Assume format is NCTHW
......
......@@ -9,11 +9,41 @@
import itertools
import numpy as np
import pytest
import megengine.module as M
from megengine import Parameter, tensor
from megengine.functional.debug_param import (
get_execution_strategy,
set_execution_strategy,
)
from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d
@pytest.fixture
def reproducible():
old = get_execution_strategy()
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
yield
set_execution_strategy(old)
# NOTE: test in module for convenience. should really test in functional
@pytest.mark.parametrize(
"name",
["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"],
)
def test_conv_dtype_promotion(name, reproducible):
N, Ci, Co, K = 2, 16, 32, 3
S = (7,) * int(name[-2])
if "Local" in name:
m = getattr(M, name)(Ci, Co, *S, K)
else:
m = getattr(M, name)(Ci, Co, K)
x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
def test_conv_transpose2d():
SH, SW = 3, 1
PH, PW = 2, 0
......@@ -163,6 +193,7 @@ def test_conv_transpose3d():
)
out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
assert conv_transpose3d.weight.numpy().shape == weight.shape
conv_transpose3d.weight = Parameter(weight)
out_meg = conv_transpose3d.forward(tensor(inp))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册