提交 8a222d2c 编写于 作者: M Megvii Engine Team

fix(mge): replace functional paramter

GitOrigin-RevId: af3c68c588ff319651a6545bfc4ee1e428ffc1e6
上级 27aa648b
......@@ -12,6 +12,7 @@ from typing import Union
import numpy as np
from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from ..ops import builtin
......@@ -87,6 +88,7 @@ def _matmul(inp1, inp2):
inp1 = inp1.astype(dtype)
if inp2.dtype != dtype:
inp2 = inp2.astype(dtype)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
)
......
......@@ -11,6 +11,7 @@ import math
from functools import lru_cache
from typing import Optional, Sequence, Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape
......@@ -1077,6 +1078,7 @@ def matmul(
dim1, dim2 = inp1.ndim, inp2.ndim
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
if dim1 == 1 and dim2 == 1: # dispatch to Dot
return dot(inp1, inp2)
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
......
......@@ -10,6 +10,7 @@
from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
......@@ -115,6 +116,7 @@ def linear(
weight: weight with shape `(out_features, in_features)`.
bias: bias with shape `(out_features,)`. Default: None
"""
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
if bias is not None:
if amp._enabled:
......@@ -185,6 +187,8 @@ def conv1d(
pad_h = padding
dilate_h = dilation
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution(
stride_h=stride_h,
......@@ -197,6 +201,7 @@ def conv1d(
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
format=conv_format,
)
(output,) = apply(op, inp, weight)
if bias is not None:
......@@ -261,6 +266,8 @@ def conv2d(
dilate_h, dilate_w = expand_hw(dilation)
sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Convolution(
stride_h=stride_h,
stride_w=stride_w,
......@@ -272,6 +279,7 @@ def conv2d(
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
format=conv_format,
)
(output,) = apply(op, inp, weight)
if bias is not None:
......@@ -403,6 +411,7 @@ def conv_transpose2d(
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
op = builtin.ConvolutionBackwardData(
stride_h=stride_h,
......@@ -474,6 +483,7 @@ def deformable_conv2d(
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.DeformableConv(
stride_h=stride_h,
......@@ -614,6 +624,7 @@ def max_pool2d(
window_h, window_w = _pair_nonzero(kernel_size)
stride_h, stride_w = _pair_nonzero(stride)
padding_h, padding_w = _pair(padding)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Pooling(
window_h=window_h,
......@@ -623,6 +634,7 @@ def max_pool2d(
pad_h=padding_h,
pad_w=padding_w,
mode="max",
format=conv_format,
)
(output,) = apply(op, inp)
return output
......@@ -656,6 +668,7 @@ def avg_pool2d(
window_h, window_w = _pair_nonzero(kernel_size)
stride_h, stride_w = _pair_nonzero(stride)
padding_h, padding_w = _pair(padding)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Pooling(
window_h=window_h,
......@@ -665,6 +678,7 @@ def avg_pool2d(
pad_h=padding_h,
pad_w=padding_w,
mode=mode,
format=conv_format,
)
(output,) = apply(op, inp)
return output
......@@ -686,8 +700,9 @@ def adaptive_max_pool2d(
"""
if isinstance(oshp, int):
oshp = (oshp, oshp)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.AdaptivePooling(mode="max", format="NCHW",)
op = builtin.AdaptivePooling(mode="max", format=conv_format,)
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
(output,) = apply(op, inp, oshp)
return output
......
......@@ -8,6 +8,7 @@
# pylint: disable=too-many-lines
from typing import Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..tensor import Tensor
......@@ -55,6 +56,8 @@ def conv_bias_activation(
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.ConvBias(
stride_h=sh,
stride_w=sw,
......@@ -63,7 +66,7 @@ def conv_bias_activation(
dilate_h=dh,
dilate_w=dw,
dtype=dtype,
format="NCHW",
format=conv_format,
strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode,
mode=conv_mode,
......@@ -114,6 +117,8 @@ def batch_conv_bias_activation(
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.BatchConvBias(
stride_h=sh,
stride_w=sw,
......@@ -122,7 +127,7 @@ def batch_conv_bias_activation(
dilate_h=dh,
dilate_w=dw,
dtype=dtype,
format="NCHW",
format=conv_format,
strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode,
mode=conv_mode,
......@@ -164,6 +169,7 @@ def conv_transpose2d(
pad_h, pad_w = _pair(padding)
stride_h, stride_w = _pair_nonzero(stride)
dilate_h, dilate_w = _pair_nonzero(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
# should be replaced by Op with bias such as ConvolutionBackwardDataBias
op = builtin.ConvolutionBackwardData(
......
......@@ -10,6 +10,7 @@ from typing import Iterable, Optional, Tuple, Union
import numpy as np
from ..core import _config
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
......@@ -143,9 +144,11 @@ def correlation(
pad_size: int (non-negative), optional, default=0) – pad for Correlation
is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
"""
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
assert conv_format == "NCHW", "Currently correlation only support NCHW mode"
op = builtin.Correlation(
format="NCHW",
format=conv_format,
kernel_size=kernel_size,
max_displacement=max_displacement,
stride1=stride1,
......@@ -215,10 +218,12 @@ def roi_align(
sample_points = (sample_points, sample_points)
sample_height, sample_width = sample_points
offset = 0.5 if aligned else 0.0
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
assert conv_format == "NCHW", "Currently roi_align only support NCHW mode"
op = builtin.ROIAlign(
mode=mode,
format="NCHW",
format=conv_format,
spatial_scale=spatial_scale,
offset=offset,
pooled_height=pooled_height,
......@@ -343,9 +348,10 @@ def remap(
[[[[1. 4.]
[4. 4.]]]]
"""
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Remap(
imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar
imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar
)
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
(result,) = apply(op, inp, map_xy)
......@@ -384,10 +390,12 @@ def warp_affine(
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
"""
conv_format = _config._get_actual_op_param(format, _config.__conv_format)
op = builtin.WarpAffine(
border_mode=border_mode,
border_val=border_val,
format=format,
format=conv_format,
imode=interp_mode,
)
out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
......@@ -466,8 +474,9 @@ def warp_perspective(
mat = mat.astype("float32")
if inp.dtype == np.float16:
inp = inp.astype("float32")
conv_format = _config._get_actual_op_param(format, _config.__conv_format)
op = builtin.WarpPerspective(
imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val
)
out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
if mat_idx is not None:
......@@ -602,7 +611,9 @@ def interpolate(
}
if inp.dtype == np.float16:
inp = inp.astype("float32")
op = builtin.Resize(imode=mode_map[mode], format="NCHW")
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
assert conv_format == "NCHW", "Currently resize only support NCHW mode"
op = builtin.Resize(imode=mode_map[mode], format=conv_format)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册