From 08f7a9576da8765b318b56d39d3bdfe69e955441 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Mar 2023 10:52:13 +0800 Subject: [PATCH] fix(imperative): fix the parameter acquisition problem of generalnorm GitOrigin-RevId: 9d0e689c24307c00cd2060ae420180b775be2c5b --- dnn/src/common/general_norm.cpp | 6 +- imperative/python/megengine/functional/nn.py | 30 +++--- .../python/megengine/module/normalization.py | 94 ++++++++++++++----- imperative/src/impl/ops/general_norm.cpp | 4 + src/opr/impl/dnn/general_norm.cpp | 22 +++++ .../include/megbrain/opr/dnn/general_norm.h | 1 + 6 files changed, 112 insertions(+), 45 deletions(-) diff --git a/dnn/src/common/general_norm.cpp b/dnn/src/common/general_norm.cpp index f031f2296..59f2a51b7 100644 --- a/dnn/src/common/general_norm.cpp +++ b/dnn/src/common/general_norm.cpp @@ -24,11 +24,7 @@ void GeneralNormBase::deduce_layout_fwd( unnormalized_shape[idx] = data.shape[i]; TensorLayout unnormalized_layout = TensorLayout(unnormalized_shape, dtype::Float32()); - if (idx == 0) { - unnormalized_layout.ndim = 1; - unnormalized_layout.shape[0] = 1; - } else - unnormalized_layout.ndim = idx; + unnormalized_layout.ndim = idx; unnormalized_layout.init_contiguous_stride(); dst = data; diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 40689fb29..22a780f8e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -71,6 +71,7 @@ __all__ = [ "dropout", "embedding", "gelu", + "general_norm", "group_norm", "hsigmoid", "instance_norm", @@ -1173,25 +1174,24 @@ def general_norm( See :math:`\beta` in :class:`~.GeneralNorm`. eps: a value added to the denominator for numerical stability. Default: 1e-5 """ - if isinstance(normalized_axis, int): - normalized_axis = [ - normalized_axis, - ] - elif isinstance(normalized_axis, tuple): - normalized_axis = list(normalized_axis) - else: - assert isinstance(normalized_axis, list), "not support normalized_axis type" + if not isinstance(normalized_axis, Sequence): + normalized_axis = [normalized_axis] + assert isinstance(normalized_axis, (list, tuple)) assert len(normalized_axis) > 0, "normalization axis not specified" - normalized_axis.sort() - if len(normalized_axis) != 1: - assert normalized_axis[0] >= 0 - elif normalized_axis[0] == -1: - normalized_axis[0] = inp.shape[inp.ndim - 1] + normalized_axis = [num + inp.ndim if num < 0 else num for num in normalized_axis] + assert normalized_axis == sorted( + normalized_axis + ), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format( + sorted(normalized_axis), normalized_axis + ) + assert ( + normalized_axis[-1] < inp.ndim, + ), "the maximum axis in normalized_axis is greater than inp_shape.ndim" assert len(set(normalized_axis)) == len( normalized_axis - ), "there are duplicate axis in list normalized_axis" + ), "there are duplicate axis in normalized_axis" _reshape = [] _rereshape = [] @@ -1200,7 +1200,7 @@ def general_norm( ) != (len(normalized_axis) - 1) if _need_reshape: get_logger().warning( - "normalized_axis is discontinuous, and performance may be poor." + "normalized_axis is discontinuous, and performance may be poor" ) unnormalized_axis = list(set(range(inp.ndim)) - set(normalized_axis)) unnormalized_axis.sort() diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index 6ef203c63..0b262d40c 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -1,9 +1,12 @@ +import typing as T + import numpy as np import megengine as mge import megengine.functional as F from megengine import Parameter +from ..logger import get_logger from .init import ones_, zeros_ from .module import Module @@ -209,7 +212,8 @@ class GeneralNorm(Module): The standard-deviation is calculated via the biased estimator. Args: - normalized_axis(int, list or tuple): the axis of input needs to be normalized. Default: -1 + normalized_shape(int, list or tuple): the shape of input needs to be normalized, normalized_shape must be specified when affine is true. When affine=true, we will directly use this shape to initialize weight/bias. Please ensure that the order is correct. Default: None + normalized_axis(int, list or tuple): the axis of input needs to be normalized, one-to-one correspondence between normalized_axis and normalized_shape. Default: -1 eps: a value added to the denominator for numerical stability. Default: 1e-5 affine: this module has learnable affine parameters (weight, bias) when affine is set to be True. @@ -220,45 +224,85 @@ class GeneralNorm(Module): Examples: >>> import numpy as np >>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4)) - >>> m = M.GeneralNorm((0, 2)) + >>> m = M.GeneralNorm((2, 4), (0, 2)) + >>> out = m(inp) + >>> out.numpy().shape + (2, 3, 4, 4) + >>> m = M.GeneralNorm((3, 4), (1, -1)) # Please be careful. >>> out = m(inp) >>> out.numpy().shape (2, 3, 4, 4) + >>> m = M.GeneralNorm((2, 4, 3), (0, 2, 1)) # Incorrect initialization, the order of normalized_axis is incorrect, should be adjusted to m = M.GeneralNorm((2, 3, 4), (0, 1, 2)). + >>> m = M.GeneralNorm((2, 4, 3), (0, -2, 1)) # Incorrect initialization, the order of normalized_axis is incorrect, should be adjusted to m = M.GeneralNorm((2, 3, 4), (0, 1, -2)). + >>> m = M.GeneralNorm((3, 4), (3, -1)) # Incorrect initialization, because axis=-1 and axis=3 are the same axis, namely axis=3. """ - def __init__(self, normalized_axis=-1, eps=1e-05, affine=True, **kwargs): + def __init__( + self, normalized_shape=None, normalized_axis=0, eps=1e-05, affine=True, **kwargs + ): super().__init__(**kwargs) - if isinstance(normalized_axis, int): - normalized_axis = (normalized_axis,) - if isinstance(normalized_axis, list): - normalized_axis.sort() - if isinstance(normalized_axis, tuple): - normalized_axis = sorted(normalized_axis) - self.normalized_axis = tuple(normalized_axis) - self.eps = eps self.affine = affine - self.weight = None - self.bias = None + + if self.affine: + assert ( + normalized_shape is not None + ), "normalized_shape must be specified when affine is true" + assert ( + normalized_axis is not None + ), "normalized_axis must be specified when affine is true" + if not isinstance(normalized_axis, T.Sequence): + normalized_axis = [normalized_axis] + if not isinstance(normalized_shape, T.Sequence): + normalized_axis = [normalized_shape] + assert isinstance(normalized_axis, (list, tuple)) + assert isinstance(normalized_shape, (list, tuple)) + + assert len(normalized_axis) == len( + normalized_shape + ), "The size of normalized_axis and normalized_shape are different" + assert len(set(normalized_axis)) == len( + normalized_axis + ), "there are duplicate axis in list normalized_axis" + + self.weight = Parameter(np.ones(normalized_shape, dtype="float32")) + self.bias = Parameter(np.zeros(normalized_shape, dtype="float32")) + else: + self.weight = None + self.bias = None + + self.normalized_shape = normalized_shape + self.normalized_axis = normalized_axis self.reset_parameters() def reset_parameters(self): - if self.affine and self.weight and self.bias: + if self.affine: ones_(self.weight) zeros_(self.bias) def forward(self, x): - if self.affine and not self.weight and not self.bias: - shape = [] - if len(self.normalized_axis) == 1 and self.normalized_axis[0] == -1: - shape = x.shape[x.ndim - 1] - else: - for axis in self.normalized_axis: - shape.append(x.shape[axis]) - - self.weight = Parameter(np.ones(shape, dtype="float32")) - self.bias = Parameter(np.zeros(shape, dtype="float32")) + self.normalized_axis = [ + num + x.ndim if num < 0 else num for num in self.normalized_axis + ] + assert self.normalized_axis == sorted( + self.normalized_axis + ), "The order of normalized_axis is incorrect, should be {}, but got {}. Please specify the values of axis in the correct order in normalized_axis".format( + sorted(self.normalized_axis), self.normalized_axis + ) + inp_shape = x.shape + for i in range(len(self.normalized_axis)): + assert ( + inp_shape[self.normalized_axis[i]] == self.normalized_shape[i] + ), "inp.shape={}, normalized_axis={}, normalized_shape={}, inp.shape[normalized_axis[{}]]({}) != normalized_shape[{}]({})".format( + x.shape, + self.normalized_axis, + self.normalized_shape, + i, + inp_shape[self.normalized_axis[i]], + i, + self.normalized_shape[i], + ) x = F.nn.general_norm( x, self.normalized_axis, self.affine, self.weight, self.bias, self.eps @@ -266,5 +310,5 @@ class GeneralNorm(Module): return x def _module_info_string(self) -> str: - s = "normalized_axis={normalized_axis}, eps={eps}, affine={affine}" + s = "normalized_shape={normalized_shape}, normalized_axis={normalized_axis}, eps={eps}, affine={affine}" return s.format(**self.__dict__) diff --git a/imperative/src/impl/ops/general_norm.cpp b/imperative/src/impl/ops/general_norm.cpp index 7d30b6e39..3a1ee51f6 100644 --- a/imperative/src/impl/ops/general_norm.cpp +++ b/imperative/src/impl/ops/general_norm.cpp @@ -77,6 +77,10 @@ SmallVector apply_on_physical_tensor( auto mean = Tensor::make(mean_layout, cn); auto rstd = Tensor::make(rstd_layout, cn); + if (inputs[0]->layout().is_empty()) { + return {out, mean, rstd}; + } + if (p.affine) { caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd); } else { diff --git a/src/opr/impl/dnn/general_norm.cpp b/src/opr/impl/dnn/general_norm.cpp index 0d1b7fe00..49d1e57c1 100644 --- a/src/opr/impl/dnn/general_norm.cpp +++ b/src/opr/impl/dnn/general_norm.cpp @@ -22,6 +22,9 @@ GeneralNormForward::GeneralNormForward( output(0)->dtype(data->dtype()); output(1)->dtype(dtype::Float32()); output(2)->dtype(dtype::Float32()); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } GeneralNormForward::GeneralNormForward( @@ -33,6 +36,9 @@ GeneralNormForward::GeneralNormForward( output(0)->dtype(data->dtype()); output(1)->dtype(dtype::Float32()); output(2)->dtype(dtype::Float32()); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } SymbolVarArray GeneralNormForward::make( @@ -90,6 +96,12 @@ size_t GeneralNormForward::get_workspace_size_bytes( } void GeneralNormForward::scn_do_execute() { + if (input(0)->dev_tensor().empty()) { + mgb_assert( + output(0)->dev_tensor().empty() && output(1)->dev_tensor().empty() && + output(2)->dev_tensor().empty()); + return; + } if (param().affine) { megdnn_opr()->exec( input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), @@ -105,6 +117,16 @@ void GeneralNormForward::scn_do_execute() { } } +GeneralNormForward::NodeProp* GeneralNormForward::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY); + if (input().size() == 3) { + ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY); + ret->add_dep_type_existing_var(input(2), NodeProp::DepType::VALUE_ALLOW_EMPTY); + } + return ret; +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(GeneralNormForward) { auto p = opr.param(); diff --git a/src/opr/include/megbrain/opr/dnn/general_norm.h b/src/opr/include/megbrain/opr/dnn/general_norm.h index bddcfffed..276e5630a 100644 --- a/src/opr/include/megbrain/opr/dnn/general_norm.h +++ b/src/opr/include/megbrain/opr/dnn/general_norm.h @@ -30,6 +30,7 @@ private: const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const override; void scn_do_execute() override; + NodeProp* do_make_node_prop() const override; }; using GeneralNorm = GeneralNormForward; -- GitLab