Commit 08f7a957 authored by Megvii Engine Team

fix(imperative): fix the parameter acquisition problem of generalnorm

GitOrigin-RevId: 9d0e689c24307c00cd2060ae420180b775be2c5b
Parent 9e6544bf
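At a glance: `GeneralNorm` now takes `normalized_shape` explicitly at construction time instead of lazily deriving the parameter shapes from the first input inside `forward`, and the C++ side gains empty-tensor handling. A usage sketch of the changed module API (shapes copied from the docstring examples further down):

    import numpy as np
    from megengine import Tensor
    import megengine.module as M

    inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
    # normalized_shape first, then normalized_axis; (2, 4) are the sizes of axes 0 and 2
    m = M.GeneralNorm((2, 4), (0, 2))
    out = m(inp)
    print(out.numpy().shape)  # (2, 3, 4, 4)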
@@ -24,11 +24,7 @@ void GeneralNormBase::deduce_layout_fwd(
         unnormalized_shape[idx] = data.shape[i];
     TensorLayout unnormalized_layout =
             TensorLayout(unnormalized_shape, dtype::Float32());
-    if (idx == 0) {
-        unnormalized_layout.ndim = 1;
-        unnormalized_layout.shape[0] = 1;
-    } else
-        unnormalized_layout.ndim = idx;
+    unnormalized_layout.ndim = idx;
     unnormalized_layout.init_contiguous_stride();
     dst = data;
......
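The removed special case forced a placeholder (1,) layout when every axis is normalized (idx == 0); with the fix, the unnormalized layout simply keeps ndim == idx and becomes 0-dimensional in that case. A minimal Python mirror of the deduction logic (hypothetical helper, not MegEngine API):

    def unnormalized_shape(data_shape, normalized_axis):
        # keep only the axes that are NOT normalized, preserving their order
        return tuple(s for i, s in enumerate(data_shape) if i not in normalized_axis)

    assert unnormalized_shape((2, 3, 4, 4), {2, 3}) == (2, 3)
    # all axes normalized: the result is now an empty (0-dim) shape, not (1,)
    assert unnormalized_shape((2, 3, 4, 4), {0, 1, 2, 3}) == ()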
@@ -71,6 +71,7 @@ __all__ = [
     "dropout",
     "embedding",
     "gelu",
+    "general_norm",
     "group_norm",
     "hsigmoid",
     "instance_norm",
@@ -1173,25 +1174,24 @@ def general_norm(
         See :math:`\beta` in :class:`~.GeneralNorm`.
         eps: a value added to the denominator for numerical stability. Default: 1e-5
     """
-    if isinstance(normalized_axis, int):
-        normalized_axis = [
-            normalized_axis,
-        ]
-    elif isinstance(normalized_axis, tuple):
-        normalized_axis = list(normalized_axis)
-    else:
-        assert isinstance(normalized_axis, list), "not support normalized_axis type"
+    if not isinstance(normalized_axis, Sequence):
+        normalized_axis = [normalized_axis]
+    assert isinstance(normalized_axis, (list, tuple))
     assert len(normalized_axis) > 0, "normalization axis not specified"
-    normalized_axis.sort()
-    if len(normalized_axis) != 1:
-        assert normalized_axis[0] >= 0
-    elif normalized_axis[0] == -1:
-        normalized_axis[0] = inp.shape[inp.ndim - 1]
+    normalized_axis = [num + inp.ndim if num < 0 else num for num in normalized_axis]
+    assert normalized_axis == sorted(
+        normalized_axis
+    ), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format(
+        sorted(normalized_axis), normalized_axis
+    )
+    assert (
+        normalized_axis[-1] < inp.ndim
+    ), "the maximum axis in normalized_axis is greater than inp_shape.ndim"
     assert len(set(normalized_axis)) == len(
         normalized_axis
-    ), "there are duplicate axis in list normalized_axis"
+    ), "there are duplicate axis in normalized_axis"
     _reshape = []
     _rereshape = []
@@ -1200,7 +1200,7 @@ def general_norm(
     ) != (len(normalized_axis) - 1)
     if _need_reshape:
         get_logger().warning(
-            "normalized_axis is discontinuous, and performance may be poor."
+            "normalized_axis is discontinuous, and performance may be poor"
         )
         unnormalized_axis = list(set(range(inp.ndim)) - set(normalized_axis))
         unnormalized_axis.sort()
......
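The new validation first maps negative axes onto their positive equivalents and only then checks ordering, bounds, and duplicates. A standalone sketch of the same logic for inp.ndim == 4 and axes (1, -1):

    ndim = 4
    normalized_axis = [1, -1]
    # canonicalize: -1 becomes ndim - 1 == 3
    normalized_axis = [a + ndim if a < 0 else a for a in normalized_axis]
    assert normalized_axis == sorted(normalized_axis)          # [1, 3] is already sorted
    assert normalized_axis[-1] < ndim                          # largest axis in bounds
    assert len(set(normalized_axis)) == len(normalized_axis)   # no duplicates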
+import typing as T
 import numpy as np
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter
+from ..logger import get_logger
 from .init import ones_, zeros_
 from .module import Module
@@ -209,7 +212,8 @@ class GeneralNorm(Module):
     The standard-deviation is calculated via the biased estimator.

     Args:
-        normalized_axis(int, list or tuple): the axis of input needs to be normalized. Default: -1
+        normalized_shape(int, list or tuple): the shape of the input that is to be normalized; must be specified when affine is True, in which case it is used directly to initialize weight/bias, so make sure the order is correct. Default: None
+        normalized_axis(int, list or tuple): the axes of the input to be normalized, in one-to-one correspondence with normalized_shape. Default: -1
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
@@ -220,45 +224,85 @@ class GeneralNorm(Module):
     Examples:
         >>> import numpy as np
         >>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
-        >>> m = M.GeneralNorm((0, 2))
+        >>> m = M.GeneralNorm((2, 4), (0, 2))
+        >>> out = m(inp)
+        >>> out.numpy().shape
+        (2, 3, 4, 4)
+        >>> m = M.GeneralNorm((3, 4), (1, -1))  # Note: -1 resolves to axis 3 here.
         >>> out = m(inp)
         >>> out.numpy().shape
         (2, 3, 4, 4)
+        >>> m = M.GeneralNorm((2, 4, 3), (0, 2, 1))  # Incorrect initialization: the order of normalized_axis is wrong; it should be m = M.GeneralNorm((2, 3, 4), (0, 1, 2)).
+        >>> m = M.GeneralNorm((2, 4, 3), (0, -2, 1))  # Incorrect initialization: the order of normalized_axis is wrong; it should be m = M.GeneralNorm((2, 3, 4), (0, 1, -2)).
+        >>> m = M.GeneralNorm((3, 4), (3, -1))  # Incorrect initialization: axis=-1 and axis=3 are the same axis, namely axis=3.
     """

-    def __init__(self, normalized_axis=-1, eps=1e-05, affine=True, **kwargs):
+    def __init__(
+        self, normalized_shape=None, normalized_axis=0, eps=1e-05, affine=True, **kwargs
+    ):
         super().__init__(**kwargs)
-        if isinstance(normalized_axis, int):
-            normalized_axis = (normalized_axis,)
-        if isinstance(normalized_axis, list):
-            normalized_axis.sort()
-        if isinstance(normalized_axis, tuple):
-            normalized_axis = sorted(normalized_axis)
-        self.normalized_axis = tuple(normalized_axis)
         self.eps = eps
         self.affine = affine
-        self.weight = None
-        self.bias = None
+        if self.affine:
+            assert (
+                normalized_shape is not None
+            ), "normalized_shape must be specified when affine is true"
+            assert (
+                normalized_axis is not None
+            ), "normalized_axis must be specified when affine is true"
+            if not isinstance(normalized_axis, T.Sequence):
+                normalized_axis = [normalized_axis]
+            if not isinstance(normalized_shape, T.Sequence):
+                normalized_shape = [normalized_shape]
+            assert isinstance(normalized_axis, (list, tuple))
+            assert isinstance(normalized_shape, (list, tuple))
+            assert len(normalized_axis) == len(
+                normalized_shape
+            ), "The size of normalized_axis and normalized_shape are different"
+            assert len(set(normalized_axis)) == len(
+                normalized_axis
+            ), "there are duplicate axis in list normalized_axis"
+            self.weight = Parameter(np.ones(normalized_shape, dtype="float32"))
+            self.bias = Parameter(np.zeros(normalized_shape, dtype="float32"))
+        else:
+            self.weight = None
+            self.bias = None
+        self.normalized_shape = normalized_shape
+        self.normalized_axis = normalized_axis
         self.reset_parameters()

     def reset_parameters(self):
-        if self.affine and self.weight and self.bias:
+        if self.affine:
             ones_(self.weight)
             zeros_(self.bias)

     def forward(self, x):
-        if self.affine and not self.weight and not self.bias:
-            shape = []
-            if len(self.normalized_axis) == 1 and self.normalized_axis[0] == -1:
-                shape = x.shape[x.ndim - 1]
-            else:
-                for axis in self.normalized_axis:
-                    shape.append(x.shape[axis])
-            self.weight = Parameter(np.ones(shape, dtype="float32"))
-            self.bias = Parameter(np.zeros(shape, dtype="float32"))
+        self.normalized_axis = [
+            num + x.ndim if num < 0 else num for num in self.normalized_axis
+        ]
+        assert self.normalized_axis == sorted(
+            self.normalized_axis
+        ), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format(
+            sorted(self.normalized_axis), self.normalized_axis
+        )
+        inp_shape = x.shape
+        for i in range(len(self.normalized_axis)):
+            assert (
+                inp_shape[self.normalized_axis[i]] == self.normalized_shape[i]
+            ), "inp.shape={}, normalized_axis={}, normalized_shape={}, inp.shape[normalized_axis[{}]]({}) != normalized_shape[{}]({})".format(
+                x.shape,
+                self.normalized_axis,
+                self.normalized_shape,
+                i,
+                inp_shape[self.normalized_axis[i]],
+                i,
+                self.normalized_shape[i],
+            )
         x = F.nn.general_norm(
             x, self.normalized_axis, self.affine, self.weight, self.bias, self.eps
@@ -266,5 +310,5 @@ class GeneralNorm(Module):
         return x

     def _module_info_string(self) -> str:
-        s = "normalized_axis={normalized_axis}, eps={eps}, affine={affine}"
+        s = "normalized_shape={normalized_shape}, normalized_axis={normalized_axis}, eps={eps}, affine={affine}"
         return s.format(**self.__dict__)
......
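Since `forward` now cross-checks the input against the stored `normalized_shape`, a mismatched input fails loudly instead of silently building wrongly sized parameters. A sketch of both outcomes, reusing the imports from the first sketch above (shapes chosen to match the docstring example):

    import numpy as np
    from megengine import Tensor
    import megengine.module as M

    m = M.GeneralNorm((3, 4), (1, -1))
    ok = Tensor(np.ones((2, 3, 4, 4), dtype="float32"))
    out = m(ok)  # passes: ok.shape[1] == 3 and ok.shape[3] == 4

    bad = Tensor(np.ones((2, 5, 4, 4), dtype="float32"))
    # m(bad) raises AssertionError, roughly:
    #   inp.shape[normalized_axis[0]](5) != normalized_shape[0](3)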
@@ -77,6 +77,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     auto mean = Tensor::make(mean_layout, cn);
     auto rstd = Tensor::make(rstd_layout, cn);

+    if (inputs[0]->layout().is_empty()) {
+        return {out, mean, rstd};
+    }
+
     if (p.affine) {
         caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
     } else {
......
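Together with the ALLOW_EMPTY_SHAPE flags and the empty-tensor guard in the hunks that follow, this early return is what lets an empty input flow through the imperative path. A hedged sketch, assuming the positional signature used by the module's `forward` above:

    import numpy as np
    from megengine import Tensor
    import megengine.functional as F

    empty = Tensor(np.zeros((0, 3, 4, 4), dtype="float32"))
    out = F.nn.general_norm(empty, [3], False, None, None, 1e-5)
    # the output comes back empty rather than tripping an assert in the kernel
    print(out.numpy().shape)  # expected: (0, 3, 4, 4)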
@@ -22,6 +22,9 @@ GeneralNormForward::GeneralNormForward(
     output(0)->dtype(data->dtype());
     output(1)->dtype(dtype::Float32());
     output(2)->dtype(dtype::Float32());
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 GeneralNormForward::GeneralNormForward(
@@ -33,6 +36,9 @@ GeneralNormForward::GeneralNormForward(
     output(0)->dtype(data->dtype());
     output(1)->dtype(dtype::Float32());
     output(2)->dtype(dtype::Float32());
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVarArray GeneralNormForward::make(
@@ -90,6 +96,12 @@ size_t GeneralNormForward::get_workspace_size_bytes(
 }

 void GeneralNormForward::scn_do_execute() {
+    if (input(0)->dev_tensor().empty()) {
+        mgb_assert(
+                output(0)->dev_tensor().empty() && output(1)->dev_tensor().empty() &&
+                output(2)->dev_tensor().empty());
+        return;
+    }
     if (param().affine) {
         megdnn_opr()->exec(
                 input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
@@ -105,6 +117,16 @@ void GeneralNormForward::scn_do_execute() {
     }
 }

+GeneralNormForward::NodeProp* GeneralNormForward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    if (input().size() == 3) {
+        ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+        ret->add_dep_type_existing_var(input(2), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(GeneralNormForward) {
     auto p = opr.param();
......
@@ -30,6 +30,7 @@ private:
             const TensorShapeArray& input_shapes,
             const TensorShapeArray& output_shapes) const override;
     void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
 };

 using GeneralNorm = GeneralNormForward;
......