提交 08f7a957 编写于 作者: M Megvii Engine Team

fix(imperative): fix the parameter acquisition problem of generalnorm

GitOrigin-RevId: 9d0e689c24307c00cd2060ae420180b775be2c5b
上级 9e6544bf
......@@ -24,11 +24,7 @@ void GeneralNormBase::deduce_layout_fwd(
unnormalized_shape[idx] = data.shape[i];
TensorLayout unnormalized_layout =
TensorLayout(unnormalized_shape, dtype::Float32());
if (idx == 0) {
unnormalized_layout.ndim = 1;
unnormalized_layout.shape[0] = 1;
} else
unnormalized_layout.ndim = idx;
unnormalized_layout.ndim = idx;
unnormalized_layout.init_contiguous_stride();
dst = data;
......
......@@ -71,6 +71,7 @@ __all__ = [
"dropout",
"embedding",
"gelu",
"general_norm",
"group_norm",
"hsigmoid",
"instance_norm",
......@@ -1173,25 +1174,24 @@ def general_norm(
See :math:`\beta` in :class:`~.GeneralNorm`.
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""
if isinstance(normalized_axis, int):
normalized_axis = [
normalized_axis,
]
elif isinstance(normalized_axis, tuple):
normalized_axis = list(normalized_axis)
else:
assert isinstance(normalized_axis, list), "not support normalized_axis type"
if not isinstance(normalized_axis, Sequence):
normalized_axis = [normalized_axis]
assert isinstance(normalized_axis, (list, tuple))
assert len(normalized_axis) > 0, "normalization axis not specified"
normalized_axis.sort()
if len(normalized_axis) != 1:
assert normalized_axis[0] >= 0
elif normalized_axis[0] == -1:
normalized_axis[0] = inp.shape[inp.ndim - 1]
normalized_axis = [num + inp.ndim if num < 0 else num for num in normalized_axis]
assert normalized_axis == sorted(
normalized_axis
), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format(
sorted(normalized_axis), normalized_axis
)
assert (
normalized_axis[-1] < inp.ndim,
), "the maximum axis in normalized_axis is greater than inp_shape.ndim"
assert len(set(normalized_axis)) == len(
normalized_axis
), "there are duplicate axis in list normalized_axis"
), "there are duplicate axis in normalized_axis"
_reshape = []
_rereshape = []
......@@ -1200,7 +1200,7 @@ def general_norm(
) != (len(normalized_axis) - 1)
if _need_reshape:
get_logger().warning(
"normalized_axis is discontinuous, and performance may be poor."
"normalized_axis is discontinuous, and performance may be poor"
)
unnormalized_axis = list(set(range(inp.ndim)) - set(normalized_axis))
unnormalized_axis.sort()
......
import typing as T
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine import Parameter
from ..logger import get_logger
from .init import ones_, zeros_
from .module import Module
......@@ -209,7 +212,8 @@ class GeneralNorm(Module):
The standard-deviation is calculated via the biased estimator.
Args:
normalized_axis(int, list or tuple): the axis of input needs to be normalized. Default: -1
normalized_shape(int, list or tuple): the shape of input needs to be normalized, normalized_shape must be specified when affine is true. When affine=true, we will directly use this shape to initialize weight/bias. Please ensure that the order is correct. Default: None
normalized_axis(int, list or tuple): the axis of input needs to be normalized, one-to-one correspondence between normalized_axis and normalized_shape. Default: -1
eps: a value added to the denominator for numerical stability. Default: 1e-5
affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
......@@ -220,45 +224,85 @@ class GeneralNorm(Module):
Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
>>> m = M.GeneralNorm((0, 2))
>>> m = M.GeneralNorm((2, 4), (0, 2))
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
>>> m = M.GeneralNorm((3, 4), (1, -1)) # Please be careful.
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
>>> m = M.GeneralNorm((2, 4, 3), (0, 2, 1)) # Incorrect initialization, the order of normalized_axis is incorrect, should be adjusted to m = M.GeneralNorm((2, 3, 4), (0, 1, 2)).
>>> m = M.GeneralNorm((2, 4, 3), (0, -2, 1)) # Incorrect initialization, the order of normalized_axis is incorrect, should be adjusted to m = M.GeneralNorm((2, 3, 4), (0, 1, -2)).
>>> m = M.GeneralNorm((3, 4), (3, -1)) # Incorrect initialization, because axis=-1 and axis=3 are the same axis, namely axis=3.
"""
def __init__(self, normalized_axis=-1, eps=1e-05, affine=True, **kwargs):
def __init__(
self, normalized_shape=None, normalized_axis=0, eps=1e-05, affine=True, **kwargs
):
super().__init__(**kwargs)
if isinstance(normalized_axis, int):
normalized_axis = (normalized_axis,)
if isinstance(normalized_axis, list):
normalized_axis.sort()
if isinstance(normalized_axis, tuple):
normalized_axis = sorted(normalized_axis)
self.normalized_axis = tuple(normalized_axis)
self.eps = eps
self.affine = affine
self.weight = None
self.bias = None
if self.affine:
assert (
normalized_shape is not None
), "normalized_shape must be specified when affine is true"
assert (
normalized_axis is not None
), "normalized_axis must be specified when affine is true"
if not isinstance(normalized_axis, T.Sequence):
normalized_axis = [normalized_axis]
if not isinstance(normalized_shape, T.Sequence):
normalized_axis = [normalized_shape]
assert isinstance(normalized_axis, (list, tuple))
assert isinstance(normalized_shape, (list, tuple))
assert len(normalized_axis) == len(
normalized_shape
), "The size of normalized_axis and normalized_shape are different"
assert len(set(normalized_axis)) == len(
normalized_axis
), "there are duplicate axis in list normalized_axis"
self.weight = Parameter(np.ones(normalized_shape, dtype="float32"))
self.bias = Parameter(np.zeros(normalized_shape, dtype="float32"))
else:
self.weight = None
self.bias = None
self.normalized_shape = normalized_shape
self.normalized_axis = normalized_axis
self.reset_parameters()
def reset_parameters(self):
if self.affine and self.weight and self.bias:
if self.affine:
ones_(self.weight)
zeros_(self.bias)
def forward(self, x):
if self.affine and not self.weight and not self.bias:
shape = []
if len(self.normalized_axis) == 1 and self.normalized_axis[0] == -1:
shape = x.shape[x.ndim - 1]
else:
for axis in self.normalized_axis:
shape.append(x.shape[axis])
self.weight = Parameter(np.ones(shape, dtype="float32"))
self.bias = Parameter(np.zeros(shape, dtype="float32"))
self.normalized_axis = [
num + x.ndim if num < 0 else num for num in self.normalized_axis
]
assert self.normalized_axis == sorted(
self.normalized_axis
), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format(
sorted(self.normalized_axis), self.normalized_axis
)
inp_shape = x.shape
for i in range(len(self.normalized_axis)):
assert (
inp_shape[self.normalized_axis[i]] == self.normalized_shape[i]
), "inp.shape={}, normalized_axis={}, normalized_shape={}, inp.shape[normalized_axis[{}]]({}) != normalized_shape[{}]({})".format(
x.shape,
self.normalized_axis,
self.normalized_shape,
i,
inp_shape[self.normalized_axis[i]],
i,
self.normalized_shape[i],
)
x = F.nn.general_norm(
x, self.normalized_axis, self.affine, self.weight, self.bias, self.eps
......@@ -266,5 +310,5 @@ class GeneralNorm(Module):
return x
def _module_info_string(self) -> str:
s = "normalized_axis={normalized_axis}, eps={eps}, affine={affine}"
s = "normalized_shape={normalized_shape}, normalized_axis={normalized_axis}, eps={eps}, affine={affine}"
return s.format(**self.__dict__)
......@@ -77,6 +77,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
auto mean = Tensor::make(mean_layout, cn);
auto rstd = Tensor::make(rstd_layout, cn);
if (inputs[0]->layout().is_empty()) {
return {out, mean, rstd};
}
if (p.affine) {
caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
} else {
......
......@@ -22,6 +22,9 @@ GeneralNormForward::GeneralNormForward(
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
GeneralNormForward::GeneralNormForward(
......@@ -33,6 +36,9 @@ GeneralNormForward::GeneralNormForward(
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVarArray GeneralNormForward::make(
......@@ -90,6 +96,12 @@ size_t GeneralNormForward::get_workspace_size_bytes(
}
void GeneralNormForward::scn_do_execute() {
if (input(0)->dev_tensor().empty()) {
mgb_assert(
output(0)->dev_tensor().empty() && output(1)->dev_tensor().empty() &&
output(2)->dev_tensor().empty());
return;
}
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
......@@ -105,6 +117,16 @@ void GeneralNormForward::scn_do_execute() {
}
}
GeneralNormForward::NodeProp* GeneralNormForward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
if (input().size() == 3) {
ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
ret->add_dep_type_existing_var(input(2), NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return ret;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GeneralNormForward) {
auto p = opr.param();
......
......@@ -30,6 +30,7 @@ private:
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
NodeProp* do_make_node_prop() const override;
};
using GeneralNorm = GeneralNormForward;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册