From 40c8134ad0f3d43e36bf1cfce203107a7a2da521 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 21 Nov 2022 18:11:51 +0800 Subject: [PATCH] feat(dnn,src,imperative): add instancenorm GitOrigin-RevId: f71ae8ce3b9ea184fa5b7c871cd54657c4293785 --- dnn/scripts/opr_param_defs.py | 10 +- dnn/src/common/group_norm.cpp | 11 + dnn/src/common/handle_impl.h | 10 +- dnn/src/cuda/group_norm/group_norm_cuda.cu | 2 +- dnn/src/cuda/group_norm/opr_impl.cpp | 32 +- dnn/src/naive/group_norm/opr_impl.cpp | 2 +- imperative/python/megengine/functional/nn.py | 65 +++- .../python/megengine/module/normalization.py | 108 +++++-- imperative/python/megengine/tools/mge | 1 - .../megengine/traced_module/traced_module.py | 4 +- .../python/test/unit/module/test_module.py | 120 ++++++-- imperative/src/impl/ops/instance_norm.cpp | 106 +++++++ .../impl/transformations/dtype_promote.cpp | 1 + imperative/tablegen/generated/hash.txt | 12 +- imperative/tablegen/generated/opdef.cpp.inl | 104 +++++++ imperative/tablegen/generated/opdef.cpy.inl | 173 +++++++++++ imperative/tablegen/generated/opdef.h.inl | 17 ++ imperative/tablegen/generated/opdef.py.inl | 11 + src/core/include/megbrain/ir/ops.td | 2 + src/opr/impl/dnn/dnn.sereg.h | 221 ++++++++++++++ src/opr/impl/dnn/group_norm.cpp | 44 ++- src/opr/impl/dnn/instance_norm.cpp | 282 ++++++++++++++++++ .../include/megbrain/opr/dnn/instance_norm.h | 67 +++++ src/opr/test/dnn/instance_norm.cpp | 92 ++++++ 24 files changed, 1394 insertions(+), 103 deletions(-) create mode 100644 imperative/src/impl/ops/instance_norm.cpp create mode 100644 src/opr/impl/dnn/instance_norm.cpp create mode 100644 src/opr/include/megbrain/opr/dnn/instance_norm.h create mode 100644 src/opr/test/dnn/instance_norm.cpp diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index d09dacde5..ec8e7118d 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1269,6 +1269,11 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), .add_fields('uint64', 'normalized_size', '1') ) +(pdef('Dropout') + .add_fields('float32', 'drop_prob', '0') + .add_fields('uint64', 'seed', '0') +) + (pdef('GroupNorm') .add_fields('bool', 'affine', 'true') .add_fields('float32', 'eps', '1e-5f') @@ -1276,11 +1281,6 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), .add_enum_alias('Format', 'Convolution') ) -(pdef('Dropout') - .add_fields('float32', 'drop_prob', '0') - .add_fields('uint64', 'seed', '0') -) - (pdef('RNNCell'). 
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2') ) diff --git a/dnn/src/common/group_norm.cpp b/dnn/src/common/group_norm.cpp index b4d391da8..35eaf2f20 100644 --- a/dnn/src/common/group_norm.cpp +++ b/dnn/src/common/group_norm.cpp @@ -11,6 +11,9 @@ void GroupNormBase::deduce_layout_fwd( TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { MEGDNN_MARK_USED_VAR(weight); MEGDNN_MARK_USED_VAR(bias); + megdnn_assert( + param().format == param::GroupNorm::Format::NCHW, + "The input of GroupNorm should be in NCHW format."); size_t N = data.shape[0]; size_t group = param().group; TensorLayout unnormalized_layout({N, group}, dtype::Float32()); @@ -39,6 +42,10 @@ void GroupNormBase::check_layout_fwd( megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str()); megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); + megdnn_assert(data.ndim == 4, "Only supports input of dim 4"); + megdnn_assert( + param().format == param::GroupNorm::Format::NCHW, + "The input of GroupNorm should be in NCHW format."); auto p = param(); size_t C = data.shape[1]; size_t group = p.group; @@ -110,6 +117,10 @@ void GroupNormBackward::check_exec( megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str()); megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); + megdnn_assert(data.ndim == 4, "Only supports input of dim 4"); + megdnn_assert( + param().format == param::GroupNorm::Format::NCHW, + "The input of GroupNorm should be in NCHW format."); if (p.affine) { megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str()); megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str()); diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 14396c36e..d53c73079 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -68,11 +68,11 @@ private: }; } // namespace megdnn -/*! - * \brief iterate though each operator class name; useful for explicit - * instantialization of create_operator<> templates - */ -// clang-format off + /*! + * \brief iterate though each operator class name; useful for explicit + * instantialization of create_operator<> templates + */ + // clang-format off #define MEGDNN_FOREACH_OPR_CLASS(cb) \ cb(ConvolutionForward) \ cb(ConvolutionBackwardData) \ diff --git a/dnn/src/cuda/group_norm/group_norm_cuda.cu b/dnn/src/cuda/group_norm/group_norm_cuda.cu index 8240218d6..c88baab57 100644 --- a/dnn/src/cuda/group_norm/group_norm_cuda.cu +++ b/dnn/src/cuda/group_norm/group_norm_cuda.cu @@ -433,7 +433,7 @@ __global__ void GetBackwardParamsCUDAKernel( const T scale_v = scale == nullptr ? T(1) : static_cast(scale[c]); sum1 += ds[index] * scale_v; sum2 += db[index] * scale_v; - const T scale_c = scale == nullptr ? T(0) : static_cast(scale[c]); + const T scale_c = scale == nullptr ? 
T(1) : static_cast(scale[c]); p1[index] = scale_c * var_inv; } diff --git a/dnn/src/cuda/group_norm/opr_impl.cpp b/dnn/src/cuda/group_norm/opr_impl.cpp index c10a0e9bb..e2a168f37 100644 --- a/dnn/src/cuda/group_norm/opr_impl.cpp +++ b/dnn/src/cuda/group_norm/opr_impl.cpp @@ -27,22 +27,16 @@ void GroupNormForwardImpl::exec( rstd.layout, workspace.size); auto p = param(); - using Format = param::GroupNorm::Format; float eps = p.eps; int group = p.group; bool affine = p.affine; auto layout = data.layout; - auto format = p.format; size_t N, C, H, W, imsize; - if (data.layout.ndim == 4 && format == Format::NCHW) { - N = layout.shape[0]; - C = layout.shape[1]; - H = layout.shape[2]; - W = layout.shape[3]; - imsize = H * W; - } else { - megdnn_throw(ssprintf("Unspport groupnorm input")); - } + N = layout.shape[0]; + C = layout.shape[1]; + H = layout.shape[2]; + W = layout.shape[3]; + imsize = H * W; auto stream = cuda_stream(handle()); using namespace ::megdnn::cuda::group_norm; @@ -94,22 +88,16 @@ void GroupNormBackwardImpl::exec( diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, ddata.layout, dweight.layout, dbias.layout, workspace.size); auto p = param(); - using Format = param::GroupNorm::Format; bool affine = p.affine; float eps = p.eps; int group = p.group; auto layout = data.layout; - auto format = p.format; size_t N, C, H, W, imsize; - if (layout.ndim == 4 && format == Format::NCHW) { - N = layout.shape[0]; - C = layout.shape[1]; - H = layout.shape[2]; - W = layout.shape[3]; - imsize = H * W; - } else { - megdnn_throw(ssprintf("Unspport groupnorm input")); - } + N = layout.shape[0]; + C = layout.shape[1]; + H = layout.shape[2]; + W = layout.shape[3]; + imsize = H * W; auto stream = cuda_stream(handle()); using namespace ::megdnn::cuda::group_norm; diff --git a/dnn/src/naive/group_norm/opr_impl.cpp b/dnn/src/naive/group_norm/opr_impl.cpp index a0b1ba02b..c09ee3adb 100644 --- a/dnn/src/naive/group_norm/opr_impl.cpp +++ b/dnn/src/naive/group_norm/opr_impl.cpp @@ -54,7 +54,7 @@ void forward( } else { for (size_t j = 0; j < inner_size; j++) { dst.ptr()[i * inner_size + j] = - (data.ptr()[i * inner_size + j] - slice_mean) / slice_std; + (data.ptr()[i * inner_size + j] - slice_mean) * slice_std; } } mean.ptr()[i] = static_cast(slice_mean); diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 0c3fbad09..2ce096a52 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -62,6 +62,7 @@ __all__ = [ "gelu", "group_norm", "hsigmoid", + "instance_norm", "hswish", "indexing_one_hot", "layer_norm", @@ -1025,6 +1026,35 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: return output +def instance_norm( + inp: Tensor, + affine: bool, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +): + r"""Applies instance normalization to the input. + + Refer to :class:`~.InstanceNorm` for more information. + + Args: + inp: input tensor. + affine: whether to use learnable affine parameters (weight, bias) + weight: scaling tensor in the learnable affine parameters. + See :math:`\gamma` in :class:`~.InstanceNorm`. + bias: bias tensor in the learnable affine parameters. + See :math:`\beta` in :class:`~.InstanceNorm`. + eps: a value added to the denominator for numerical stability. 
Default: 1e-5 + """ + op = builtin.InstanceNorm(affine=affine, eps=eps) + if affine: + assert weight is not None, "weight must be provided if affine is True" + assert bias is not None, "bias must be provided if affine is True" + return apply(op, inp, weight, bias)[0] + else: + return apply(op, inp)[0] + + def group_norm( inp: Tensor, num_groups: int, @@ -1033,20 +1063,25 @@ def group_norm( bias: Optional[Tensor] = None, eps: float = 1e-5, ): - r"""Applies Group Normalization over a mini-batch of inputs as described in - the paper `Group Normalization `__ + r"""Applies group normalization to the input. + + Refer to :class:`~.GroupNorm` for more information. Args: inp: input tensor. num_groups: number of groups to separate the channels into - affine: whether to use weight and bias - weight: must not be None when the affine is true - bias: must not be None when the affine is true + See :attr:`num_groups` in :class:`~.GroupNorm`. + affine: whether to use learnable affine parameters (weight, bias) + weight: scaling tensor in the learnable affine parameters. + See :math:`\gamma` in :class:`~.GroupNorm`. + bias: bias tensor in the learnable affine parameters. + See :math:`\beta` in :class:`~.GroupNorm`. eps: a value added to the denominator for numerical stability. Default: 1e-5 """ op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,) if affine: - assert weight is not None and bias is not None + assert weight is not None, "weight must be provided if affine is True" + assert bias is not None, "bias must be provided if affine is True" return apply(op, inp, weight, bias)[0] else: return apply(op, inp)[0] @@ -1060,15 +1095,19 @@ def layer_norm( bias: Optional[Tensor] = None, eps: float = 1e-5, ): - r"""Applies layer normalization to the input. Support tensor of any shape as input. - Reference: https://arxiv.org/pdf/1803.08494.pdf. + r"""Applies layer normalization to the input. + + Refer to :class:`~.LayerNorm` for more information. Args: inp: input tensor. normalized_shape: the shape that you want to be normalizated - affine: whether to use weight and bias - weight: must not be None when the affine is true - bias: must not be None when the affine is true + See :attr:`normalized_shape` in :class:`~.LayerNorm`. + affine: whether to use learnable affine parameters (weight, bias) + weight: scaling tensor in the learnable affine parameters. + See :math:`\gamma` in :class:`~.LayerNorm`. + bias: bias tensor in the learnable affine parameters. + See :math:`\beta` in :class:`~.LayerNorm`. eps: a value added to the denominator for numerical stability. Default: 1e-5 """ if isinstance(normalized_shape, int): @@ -1088,10 +1127,10 @@ def layer_norm( normalized_size=normalized_size, ) if affine: - assert weight is not None and bias is not None + assert weight is not None, "weight must be provided if affine is True" + assert bias is not None, "bias must be provided if affine is True" return apply(op, inp, weight, bias)[0] else: - # assert weight is None and bias is None return apply(op, inp)[0] diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index e9a9bdb9a..e0790cb63 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -9,8 +9,33 @@ from .module import Module class GroupNorm(Module): - """Simple implementation of GroupNorm. Only support 4d tensor now. - Reference: https://arxiv.org/pdf/1803.08494.pdf. 
+    r"""Applies Group Normalization over a mini-batch of inputs.
+    Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`__.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated separately over each group.
+    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :attr:`num_channels` if :attr:`affine` is ``True``.
+
+    Args:
+        num_groups (int): number of groups into which the channels are divided.
+        num_channels (int): number of channels expected in the input.
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
+
+    Shape:
+        - Input: :math:`(N, C, H, W)` (now only supports tensors in NCHW format)
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+
+    Examples:
+        >>> import numpy as np
+        >>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
+        >>> m = M.GroupNorm(3, 3)
+        >>> out = m(inp)
+        >>> out.numpy().shape
+        (2, 3, 4, 4)
     """
 
     def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
@@ -48,9 +73,33 @@ class GroupNorm(Module):
 
 
 class InstanceNorm(Module):
-    """Simple implementation of InstanceNorm. Only support 4d tensor now.
-    Reference: https://arxiv.org/abs/1607.08022.
-    Note that InstanceNorm equals using GroupNome with num_groups=num_channels.
+    r"""Applies Instance Normalization over a mini-batch of inputs.
+    Refer to `Instance Normalization <https://arxiv.org/abs/1607.08022>`__.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per channel, separately for each sample in a mini-batch.
+    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :attr:`num_channels` if :attr:`affine` is ``True``.
+    Note that InstanceNorm equals using GroupNorm with num_groups = num_channels.
+
+    Args:
+        num_channels (int): number of channels expected in the input.
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
+
+    Shape:
+        - Input: :math:`(N, C, H, W)` (now only supports tensors in NCHW format)
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+
+    Examples:
+        >>> import numpy as np
+        >>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
+        >>> m = M.InstanceNorm(3)
+        >>> out = m(inp)
+        >>> out.numpy().shape
+        (2, 3, 4, 4)
     """
 
     def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
@@ -72,20 +121,7 @@ class InstanceNorm(Module):
             zeros_(self.bias)
 
     def forward(self, x):
-        N, C, H, W = x.shape
-        format = x.format
-        assert C == self.num_channels
-        x = x.reshape(N, C, -1)
-        mean = x.mean(axis=2, keepdims=True)
-        var = (x ** 2).mean(axis=2, keepdims=True) - mean * mean
-
-        x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(N, C, H, W)
-        if self.affine:
-            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-        # FIXME(czh): remove this after making it a builtin op.
-        if format == "nhwc":
-            x = mge.amp.convert_tensor_format(x, inplace=False)
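+        # the builtin op normalizes each (N, C) slice of the input and then
+        # applies the optional per-channel affine transform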
+        x = F.nn.instance_norm(x, self.affine, self.weight, self.bias, self.eps)
         return x
 
     def _module_info_string(self) -> str:
@@ -94,8 +130,38 @@ class InstanceNorm(Module):
 
 
 class LayerNorm(Module):
-    """Simple implementation of LayerNorm. Support tensor of any shape as input.
-    Reference: https://arxiv.org/pdf/1803.08494.pdf.
+    r"""Applies Layer Normalization over a mini-batch of inputs.
+    Refer to `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated separately over the last
+    dimensions, which have to be of the shape specified by
+    :attr:`normalized_shape`.
+    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :attr:`normalized_shape` if :attr:`affine` is ``True``.
+    The standard-deviation is calculated via the biased estimator.
+
+    Args:
+        normalized_shape(int or tuple): input shape from an expected input of
+            size :math:`[*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1]]`.
+            If it is a single integer, this module will normalize over the last dimension
+            which is expected to be of that specific size.
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
+
+    Shape:
+        - Input: :math:`(N, *)` (2-D, 3-D, 4-D or 5-D tensor)
+        - Output: :math:`(N, *)` (same shape as input)
+
+    Examples:
+        >>> import numpy as np
+        >>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
+        >>> m = M.LayerNorm((4, 4))
+        >>> out = m(inp)
+        >>> out.numpy().shape
+        (2, 3, 4, 4)
     """
 
     def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs):
diff --git a/imperative/python/megengine/tools/mge b/imperative/python/megengine/tools/mge
index 3990811e2..2ac246810 100755
--- a/imperative/python/megengine/tools/mge
+++ b/imperative/python/megengine/tools/mge
@@ -1,6 +1,5 @@
 #! /usr/bin/env python3
 import argparse
-import ntpath
 import os
 import pathlib
 import platform
diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py
index 0c33fe39a..dbce38f8f 100644
--- a/imperative/python/megengine/traced_module/traced_module.py
+++ b/imperative/python/megengine/traced_module/traced_module.py
@@ -511,8 +511,7 @@ class InternalGraph:
         inp = F.zeros(shape = (3, 4))
         traced_module = tm.trace_module(net, inp)
 
-    Will produce the following ``InternalGraph``::
-
+    Will produce the following ``InternalGraph``:
         print(traced_module.graph)
 
    .. code-block:: text
 
@@ -2463,6 +2462,7 @@ def trace_module(
     with active_module_tracer().patcher:
         global_scope = InternalGraph(name="top", qualname=net_name)
         active_module_tracer().push_scope(global_scope)
+
         builder = TracedModuleBuilder(mod, True)
         NodeMixin.wrap_safe(
diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py
index 157a132c8..49a3fab42 100644
--- a/imperative/python/test/unit/module/test_module.py
+++ b/imperative/python/test/unit/module/test_module.py
@@ -16,6 +16,7 @@ from megengine.module import (
     Conv2d,
     Dropout,
     GroupNorm,
+    InstanceNorm,
     Linear,
     MaxPool2d,
     Module,
@@ -703,9 +704,15 @@ def test_module_compatible():
 
 
 @pytest.mark.skip(reason="pytest aborted")
-def test_grou_norm():
+@pytest.mark.parametrize("affine", [True, False])
+def test_group_norm(affine):
+    num_groups = 256
+    num_channels = 256
+    weight_np = np.random.uniform(-0.5, 0.5, (num_channels))
+    bias_np = np.random.uniform(-0.5, 0.5, (num_channels))
+
     class OriginGroupNormFunc(Module):
-        def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
+        def __init__(self, eps=1e-5, affine=True, **kwargs):
             super().__init__(**kwargs)
             assert num_channels % num_groups == 0
             self.num_groups = num_groups
@@ -713,8 +720,8 @@
             self.eps = eps
             self.affine = affine
             if self.affine:
-                self.weight = Parameter(np.ones(num_channels, dtype=np.float32))
-                self.bias = Parameter(np.zeros(num_channels, dtype=np.float32))
+                self.weight = Parameter(weight_np)
+                self.bias = Parameter(bias_np)
             else:
                 self.weight = None
                 self.bias = None
@@ -732,36 +739,105 @@
             )
         return x
 
-    inp = np.random.randn(2, 256, 10, 16).astype("float32")
+    inp = np.random.uniform(-0.5, 0.5, (2, num_channels, 10, 16)).astype("float32")
     mge_inp = Tensor(inp)
-    mge_m = GroupNorm(32, 256)
-
+    mge_m = GroupNorm(num_groups, num_channels, affine=affine)
+    mge_m.weight = Parameter(weight_np)
+    mge_m.bias = Parameter(bias_np)
     ori_inp = Tensor(inp)
-    ori_m = OriginGroupNormFunc(32, 256)
-
-    targets = np.array(2)
-    mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters())
-    ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters())
+    ori_m = OriginGroupNormFunc(affine=affine)
+    mge_gm = mge.autodiff.GradManager().attach((*mge_m.parameters(), mge_inp))
+    ori_gm = mge.autodiff.GradManager().attach((*ori_m.parameters(), ori_inp))
+    dy = Tensor(np.random.uniform(-0.5, 0.5, inp.shape))
 
     for i in range(2):
         with mge_gm:
             mge_output = mge_m(mge_inp)
-            loss = F.loss.square_loss(
-                mge_output.sum(), mge.tensor(targets, dtype=np.float32)
-            )
-            mge_gm.backward(loss)
+
+            mge_gm.backward(mge_output, dy)
 
         with ori_gm:
             ori_output = ori_m(ori_inp)
-            loss = F.loss.square_loss(
-                ori_output.sum(), mge.tensor(targets, dtype=np.float32)
-            )
-            ori_gm.backward(loss)
+
+            ori_gm.backward(ori_output, dy)
 
         np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
         np.testing.assert_allclose(
-            mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03
+            ori_inp.grad.numpy(), mge_inp.grad.numpy(), atol=1e-05
         )
+        if affine:
+            np.testing.assert_allclose(
+                mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), atol=1e-05
+            )
+            np.testing.assert_allclose(
+                mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), atol=1e-05
+            )
+
+
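+# The builtin InstanceNorm module is checked against a pure-Python reference
+# implementation: reshape to (N, C, -1), normalize over the last axis, then
+# apply the optional per-channel affine transform.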
+@pytest.mark.parametrize("affine", [True, False])
+def test_instance_norm(affine):
+    num_channels = 4
+    weight_np = np.random.uniform(-0.5, 0.5, (num_channels))
+    bias_np = np.random.uniform(-0.5, 0.5, (num_channels))
+
+    class OriginInstanceNormFunc(Module):
+        def __init__(self, eps=1e-5, affine=True, **kwargs):
+            super().__init__(**kwargs)
+            self.num_channels = num_channels
+            self.eps = eps
+            self.affine = affine
+            if self.affine:
+                self.weight = Parameter(weight_np)
+                self.bias = Parameter(bias_np)
+            else:
+                self.weight = None
+                self.bias = None
+
+        def forward(self, x):
+            N, C, H, W = x.shape
+            x = x.reshape(N, self.num_channels, -1)
+            mean = x.mean(axis=2, keepdims=True)
+            var = (x * x).mean(axis=2, keepdims=True) - mean * mean
+            x = (x - mean) / F.sqrt(var + self.eps)
+            x = x.reshape(N, C, H, W)
+            if self.affine:
+                x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(
+                    1, -1, 1, 1
+                )
+            return x
+
+    inp = np.random.uniform(-0.5, 0.5, (2, num_channels, 10, 16)).astype("float32")
+    mge_inp = Tensor(inp)
+    mge_m = InstanceNorm(num_channels, affine=affine)
+    mge_m.weight = Parameter(weight_np)
+    mge_m.bias = Parameter(bias_np)
+
+    ori_inp = Tensor(inp)
+    ori_m = OriginInstanceNormFunc(affine=affine)
+
+    mge_im = mge.autodiff.GradManager().attach((*mge_m.parameters(), mge_inp))
+    ori_im = mge.autodiff.GradManager().attach((*ori_m.parameters(), ori_inp))
+    dy = Tensor(np.random.uniform(-0.5, 0.5, inp.shape))
+
+    for i in range(2):
+        with mge_im:
+            mge_output = mge_m(mge_inp)
+
+            mge_im.backward(mge_output, dy)
+
+        with ori_im:
+            ori_output = ori_m(ori_inp)
+
+            ori_im.backward(ori_output, dy)
+
+        np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
         np.testing.assert_allclose(
-            mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03
+            ori_inp.grad.numpy(), mge_inp.grad.numpy(), atol=1e-04
         )
+        if affine:
+            np.testing.assert_allclose(
+                mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), atol=1e-04
+            )
+            np.testing.assert_allclose(
+                mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), atol=1e-04
+            )
diff --git a/imperative/src/impl/ops/instance_norm.cpp b/imperative/src/impl/ops/instance_norm.cpp
new file mode 100644
index 000000000..aa99efcaa
--- /dev/null
+++ b/imperative/src/impl/ops/instance_norm.cpp
@@ -0,0 +1,106 @@
+#include "megbrain/opr/dnn/instance_norm.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+namespace mgb::imperative {
+
+namespace instance_norm {
+
+cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const InstanceNorm&>(def);
+    size_t nr_inp = inputs.size();
+    auto p = op.param();
+    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
+    OperatorNodeConfig config{op.make_name()};
+    if (nr_inp == 3) {
+        return opr::InstanceNorm::make(
+                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    } else {
+        return opr::InstanceNorm::make(inputs[0], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    }
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& instance_norm = def.cast_final_safe<InstanceNorm>();
+    size_t nr_inp = inputs.size();
+    auto affine = instance_norm.affine;
+    mgb_assert(
+            (nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
+            "num of inputs of instancenorm should be 1 or 3 but you give %zu",
+            inputs.size());
+
+    auto&& inp = inputs[0];
+    auto& inp_cn = inp.comp_node;
+
+    if (inp.layout.ndim == 0) {
+        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
+                false};
+    }
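+    // InstanceNorm is lowered onto the GroupNorm kernel with one group per
+    // channel, so the GroupNorm param is patched with group = C before the
+    // output layouts are deduced.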
+ size_t C = inputs[0].layout.shape[1]; + auto p = instance_norm.param(); + p.group = C; + + DnnOprHelper dnn_opr(p); + auto&& [oup_layout, mean_layout, rstd_layout] = + dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{}); + return {{{oup_layout, inp_cn, {}}, + {mean_layout, inp_cn, {}}, + {rstd_layout, inp_cn, {}}}, + true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op_def = def.cast_final_safe(); + size_t nr_inp = inputs.size(); + auto p = op_def.param(); + + mgb_assert( + (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), + "num of inputs of instancenorm should be 1 or 3 but you give %zu", + inputs.size()); + + auto cn = inputs[0]->comp_node(); + using Format = megdnn::param::GroupNorm::Format; + mgb_assert(p.format == Format::NCHW, "only support inputs in shape NCHW."); + size_t C = inputs[0]->shape()[1]; + p.group = C; + + DnnOprCaller caller(cn, p); + + auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>( + inputs[0]->layout(), TensorLayout{}, TensorLayout{}); + + auto out = Tensor::make(oup_layout, cn); + auto mean = Tensor::make(mean_layout, cn); + auto rstd = Tensor::make(rstd_layout, cn); + + if (p.affine) { + caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd); + } else { + megdnn::TensorND empty_dnn; + caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd); + } + return {out, mean, rstd}; +} + +OP_TRAIT_REG(InstanceNorm, InstanceNorm) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); + +} // namespace instance_norm +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 921c39528..6b0aa44d8 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -439,6 +439,7 @@ ValueRefList DTypePromoteTransformation::apply_transformation( return imperative::apply(op, inputs); } } + return imperative::apply(op, inputs); } diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index ce0ee0f37..7d4919be7 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py -4bd0317fd84b5065c8d88a7ca6241908 ../../src/core/include/megbrain/ir/ops.td -cb32cb1ef6b2ef4a7defaeb02ecd36e3 generated/opdef.h.inl -1c0230f60ddf3459de2aa4e16c1e2957 generated/opdef.cpp.inl -f6cbfd25f0d61e7b94c687733f5ae9b9 generated/opdef.py.inl -3a023199c39ea5611975b902a882bbba generated/opdef.cpy.inl +4489d2cae7002dbfef1359c4d4c8141a ../../dnn/scripts/opr_param_defs.py +465dad57c288e2e2d5fb356f0baef7d7 ../../src/core/include/megbrain/ir/ops.td +5d850de38a2583233da0fac2f5be1f1d generated/opdef.h.inl +dfb41c4fba4727b9474074c38ca169db generated/opdef.cpp.inl +e71787dd73df41d5f967af66a0a5e71e generated/opdef.py.inl +b67068fb053c20a065255c971e3d2082 generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index ab292471d..365071498 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -4419,6 
+4419,110 @@ OP_TRAIT_REG(InplaceAdd, InplaceAdd) .props(InplaceAdd_props_impl) .make_name(InplaceAdd_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNorm); + +namespace { +size_t InstanceNorm_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::hash(op_.affine)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.eps)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.group)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); + return val; +} +bool InstanceNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.affine != b_.affine) return false; + if (a_.eps != b_.eps) return false; + if (a_.group != b_.group) return false; + if (a_.format != b_.format) return false; + return true; +} +std::vector> InstanceNorm_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + props_.emplace_back("affine", std::to_string(op_.affine)); + props_.emplace_back("eps", std::to_string(op_.eps)); + props_.emplace_back("group", std::to_string(op_.group)); + switch (op_.format){ + case InstanceNorm::Format::NCHW: + props_.emplace_back("format", "NCHW"); + break; + case InstanceNorm::Format::NHWC: + props_.emplace_back("format", "NHWC"); + break; + case InstanceNorm::Format::NHWCD4: + props_.emplace_back("format", "NHWCD4"); + break; + case InstanceNorm::Format::NCHW4: + props_.emplace_back("format", "NCHW4"); + break; + case InstanceNorm::Format::NCHW8: + props_.emplace_back("format", "NCHW8"); + break; + case InstanceNorm::Format::NCHW32: + props_.emplace_back("format", "NCHW32"); + break; + case InstanceNorm::Format::NCHW88: + props_.emplace_back("format", "NCHW88"); + break; + case InstanceNorm::Format::NCHW44: + props_.emplace_back("format", "NCHW44"); + break; + case InstanceNorm::Format::NCHW44_DOT: + props_.emplace_back("format", "NCHW44_DOT"); + break; + case InstanceNorm::Format::NCHW4_NCHW32: + props_.emplace_back("format", "NCHW4_NCHW32"); + break; + case InstanceNorm::Format::NCHW32_NCHW4: + props_.emplace_back("format", "NCHW32_NCHW4"); + break; + case InstanceNorm::Format::NCHW4_NCHW: + props_.emplace_back("format", "NCHW4_NCHW"); + break; + case InstanceNorm::Format::NHWC_NCHW: + props_.emplace_back("format", "NHWC_NCHW"); + break; + case InstanceNorm::Format::NHWC_NCHW4_IC_SMALL: + props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); + break; + case InstanceNorm::Format::NCHW_NCHW4_IC_SMALL: + props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); + break; + case InstanceNorm::Format::CHWN4: + props_.emplace_back("format", "CHWN4"); + break; + case InstanceNorm::Format::NCHW64: + props_.emplace_back("format", "NCHW64"); + break; + case InstanceNorm::Format::NCHW4_NHWC: + props_.emplace_back("format", "NCHW4_NHWC"); + break; + default: + props_.emplace_back("format", "INVALID"); + break; + } + return props_; +} +std::string InstanceNorm_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "InstanceNorm"; +} +} // anonymous namespace +OP_TRAIT_REG(InstanceNorm, InstanceNorm) + .hash(InstanceNorm_hash_impl) + .is_same_st(InstanceNorm_is_same_st_impl) + .props(InstanceNorm_props_impl) + .make_name(InstanceNorm_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(LAMBUpdate); namespace { diff --git 
a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 363a78959..9582792fe 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -12787,6 +12787,178 @@ void _init_py_InplaceAdd(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(InplaceAdd::typeinfo(), &py_type).second); } +void _init_py_InstanceNorm_Format(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "Format", reinterpret_cast(e_type)) >= 0); +} + +PyOpDefBegin(InstanceNorm) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"affine", serialization::dump(opdef.affine)}, + {"eps", serialization::dump(opdef.eps)}, + {"group", serialization::dump(opdef.group)}, + {"format", serialization::dump(opdef.format)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("affine"); + if (iter != state.end()) { + opdef.affine = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("eps"); + if (iter != state.end()) { + opdef.eps = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("group"); + if (iter != state.end()) { + opdef.group = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("format"); + if (iter != state.end()) { + opdef.format = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(InstanceNorm) + +int PyOp(InstanceNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL}; + PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast(kwlist), &affine, &eps, &group, &format, &scope)) + return -1; + + if (affine) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().affine = + py::cast(py::handle(affine)); + } CATCH_ALL(-1) + } + + if (eps) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().eps = + py::cast(py::handle(eps)); + } CATCH_ALL(-1) + } + + if (group) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().group = + py::cast(py::handle(group)); + } CATCH_ALL(-1) + } + + if (format) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().format = + py::cast(py::handle(format)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } 
CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(InstanceNorm)::py_getsetters[] = { + {const_cast("affine"), py_get_generic(InstanceNorm, affine), py_set_generic(InstanceNorm, affine), const_cast("affine"), NULL}, + {const_cast("eps"), py_get_generic(InstanceNorm, eps), py_set_generic(InstanceNorm, eps), const_cast("eps"), NULL}, + {const_cast("group"), py_get_generic(InstanceNorm, group), py_set_generic(InstanceNorm, group), const_cast("group"), NULL}, + {const_cast("format"), py_get_generic(InstanceNorm, format), py_set_generic(InstanceNorm, format), const_cast("format"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(InstanceNorm)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(InstanceNorm)::getstate, METH_NOARGS, "InstanceNorm getstate"}, + {const_cast("__setstate__"), PyOp(InstanceNorm)::setstate, METH_VARARGS, "InstanceNorm setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(InstanceNorm)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(InstanceNorm)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(InstanceNorm)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(InstanceNorm)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self, affine: bool = ..., eps: float = ..., group: int = ..., format: Union[str, Format] = ...) -> None\n" +}; + +void _init_py_InstanceNorm(py::module m) { + using py_op = PyOp(InstanceNorm); + auto& py_type = PyOpType(InstanceNorm); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.InstanceNorm"; + py_type.tp_basicsize = sizeof(PyOp(InstanceNorm)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "InstanceNorm"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(InstanceNorm), &PyOp(InstanceNorm)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + _init_py_InstanceNorm_Format(py_type); + + PyType_Modified(&py_type); + m.add_object("InstanceNorm", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(InstanceNorm::typeinfo(), &py_type).second); +} + PyOpDefBegin(LAMBUpdate) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -22258,6 +22430,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_IndexingSetMultiAxisVec(m); \ _init_py_IndexingSetOneHot(m); \ _init_py_InplaceAdd(m); \ + _init_py_InstanceNorm(m); \ _init_py_LAMBUpdate(m); \ _init_py_LRN(m); \ _init_py_LSQ(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index a130b9e27..4b876352b 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1167,6 +1167,23 @@ public: } }; +class InstanceNorm : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + using Format = ::megdnn::param::GroupNorm::Format; + bool affine = true; + float eps = 1e-5f; + uint32_t group = 1; + Format format = ::megdnn::param::GroupNorm::Format::NCHW; + InstanceNorm() = default; + InstanceNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), 
format(format_) { set_scope(scope_); } + InstanceNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {} + ::megdnn::param::GroupNorm param() const { + return {affine, eps, group, format}; + } +}; + class LAMBUpdate : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index b3af8baf4..343fd120a 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1339,6 +1339,17 @@ py::class_, OpDef> InplaceAddInst(m, "In InplaceAddInst .def(py::init<>()); +py::class_, OpDef> InstanceNormInst(m, "InstanceNorm"); + +InstanceNormInst.attr("Format") = AdaptivePoolingInst.attr("Format"); + +InstanceNormInst + .def(py::init(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {}) + .def_readwrite("affine", &InstanceNorm::affine) + .def_readwrite("eps", &InstanceNorm::eps) + .def_readwrite("group", &InstanceNorm::group) + .def_readwrite("format", &InstanceNorm::format); + py::class_, OpDef> LAMBUpdateInst(m, "LAMBUpdate"); LAMBUpdateInst diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index eb56061a4..f2be5b4dc 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -525,6 +525,8 @@ def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; +def InstanceNorm: MgbHashableOp<"InstanceNorm",[GroupNormParam]>; + def RNN: MgbHashableOp<"RNN", [RNNParam]>; def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>; diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 9fddafa3b..7f674362e 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -6,6 +6,7 @@ #include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/group_norm.h" #include "megbrain/opr/dnn/images2neibs.h" +#include "megbrain/opr/dnn/instance_norm.h" #include "megbrain/opr/dnn/layer_norm.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" @@ -736,6 +737,215 @@ struct OprLoadDumpImplV2 { } }; +template <> +struct OprMaker { + using Param = opr::GroupNorm::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 3) { + return opr::InstanceNorm::make(i[0], i[1], i[2], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 1); + return opr::InstanceNorm::make(i[0], param, config)[0].node()->owner_opr(); + } + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::InstanceNorm; + using Param = opr::GroupNorm::Param; + using ElemwiseParam = opr::Elemwise::Param; + using ReduceParam = opr::Reduce::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto graph = inputs[0]->owner_graph(); + auto comp_node = inputs[0]->comp_node(); + // std::unique_ptr m_static_infer_manager; + auto opr_param = opr->cast_final_safe().param(); + float eps = opr_param.eps; + auto half = DTypeScalar(static_cast(0.5)); + auto param_eps = 
DTypeScalar(static_cast(eps)); + auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); + auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); + + auto origin_shape = opr::GetVarShape::make(inputs[0]).node(); + + TensorShape input_shape = + inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); + size_t N = input_shape[0]; + size_t C = input_shape[1]; + size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; + int size = inner_size / C; + HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = N; + ptr[1] = C; + ptr[2] = size; + auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); + auto inp = opr::Reshape::make(inputs[0], target_shape); + + auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2}); + auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL}); + auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2}); + auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL}); + auto var = + opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB}); + auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD}); + auto sqrt = + opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW}); + auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB}); + auto temp_inp = + opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); + auto res = opr::Reshape::make(temp_inp, origin_shape); + + if (inputs.size() == 3) { + auto mul_temp = + opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL}); + auto res = opr::Elemwise::make( + {mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD}); + return res.node()->owner_opr(); + } else { + return res.node()->owner_opr(); + } + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + // auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + + // return OprMaker::make(ctx.read_param(), + // inputs, ctx.graph(), config); + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; + +// OprMaker in MGB_SEREG_OPR only support unique output opr +template <> +struct OprMaker { + using Param = opr::GroupNormBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 5) { + return opr::InstanceNormBackward::make( + i[0], i[1], i[2], i[3], i[4], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 4); + return opr::InstanceNormBackward::make( + i[0], i[1], i[2], i[3], param, config)[0] + .node() + ->owner_opr(); + } + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::InstanceNormBackward; + using Param = opr::GroupNormBackward::Param; + using ElemwiseParam = opr::Elemwise::Param; + using ReduceParam = opr::Reduce::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto rstd = inputs[4]; + auto graph = inputs[1]->owner_graph(); + auto comp_node = inputs[1]->comp_node(); + auto opr_param = opr->cast_final_safe().param(); + float eps = opr_param.eps; + auto half = DTypeScalar(static_cast(0.5)); + auto param_eps = 
DTypeScalar(static_cast(eps)); + auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); + auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); + auto const_node = + opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node}); + + TensorShape input_shape = + inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); + auto origin_shape = opr::GetVarShape::make(inputs[1]).node(); + size_t N = input_shape[0]; + size_t C = input_shape[1]; + size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; + int size = inner_size / C; + HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = N; + ptr[1] = C; + ptr[2] = size; + auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); + auto inp = opr::Reshape::make(inputs[1], target_shape); + + auto temp_rstd = + opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD}); + auto sqrt = + opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW}); + auto slice_std = opr::Elemwise::make( + {const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); + auto sub_mean = + opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB}); + auto x_hat = + opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL}); + x_hat = opr::Reshape::make(x_hat, origin_shape); + auto size_node = + opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node}); + auto temp1 = opr::Elemwise::make( + {slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV}); + + auto dx_hat = + opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL}); + HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32()); + auto* ptr2 = tshape.ptr(); + ptr2[0] = N; + ptr2[1] = C; + ptr2[2] = 1; + ptr2[3] = input_shape[2]; + ptr2[4] = input_shape[3]; + target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); + x_hat = opr::Reshape::make(x_hat, target_shape); + dx_hat = opr::Reshape::make(dx_hat, target_shape); + auto temp2 = + opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL}); + ptr2[2] = 1; + ptr2[3] = 1; + ptr2[4] = 1; + target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); + auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape); + auto sum_dx_hat = + opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape); + auto temp4 = + opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL}); + auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB}); + auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB}); + auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL}); + auto dx = opr::Reshape::make(dx_temp, origin_shape); + return dx.node()->owner_opr(); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; + template struct MakeLocalShareCaller2 { template @@ -961,6 +1171,8 @@ MGB_SEREG_OPR(LayerNorm, 0); MGB_SEREG_OPR(LayerNormBackward, 0); MGB_SEREG_OPR(GroupNorm, 0); MGB_SEREG_OPR(GroupNormBackward, 0); +MGB_SEREG_OPR(InstanceNorm, 0); +MGB_SEREG_OPR(InstanceNormBackward, 0); MGB_SEREG_OPR(RNNCellForward, 6); MGB_SEREG_OPR(LSTMCellForward, 7); MGB_SEREG_OPR(RNNForward, 3); @@ -977,6 +1189,15 @@ MGB_SEREG_OPR_V2( GroupNormBackward, 0, (mgb::serialization::OprLoadDumpImplV2::replace_opr), 
VERSION_2, CURRENT_VERSION); +MGB_SEREG_OPR_V2( + InstanceNorm, 0, + (mgb::serialization::OprLoadDumpImplV2::replace_opr), + VERSION_2, CURRENT_VERSION); +MGB_SEREG_OPR_V2( + InstanceNormBackward, 0, + (mgb::serialization::OprLoadDumpImplV2< + opr::InstanceNormBackward, 0>::replace_opr), + VERSION_2, CURRENT_VERSION); } // namespace opr } // namespace mgb diff --git a/src/opr/impl/dnn/group_norm.cpp b/src/opr/impl/dnn/group_norm.cpp index b70544136..48b5fc0b9 100644 --- a/src/opr/impl/dnn/group_norm.cpp +++ b/src/opr/impl/dnn/group_norm.cpp @@ -77,11 +77,28 @@ void GroupNormForward::get_output_var_shape( size_t GroupNormForward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const { - return intl::MegDNNOprMethInvoker::get_workspace_in_bytes( - megdnn_opr(), this, input_shapes, output_shapes); +#define in(x) \ + { input_shapes[x], input(x)->dtype() } +#define out(x) \ + { output_shapes[x], output(x)->dtype() } + auto&& param = megdnn_opr()->param(); + if (param.affine) { + return megdnn_opr()->get_workspace_in_bytes( + in(0), in(1), in(2), out(0), out(1), out(2)); + } else { + TensorLayout temp_weight = TensorLayout(dtype::Float32()); + TensorLayout temp_bias = TensorLayout(dtype::Float32()); + return megdnn_opr()->get_workspace_in_bytes( + in(0), temp_weight, temp_bias, out(0), out(1), out(2)); + } +#undef in +#undef out } void GroupNormForward::scn_do_execute() { + mgb_assert( + param().format == Param::Format::NCHW, + "only support inputs in shape NCHW."); if (param().affine) { megdnn_opr()->exec( input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), @@ -214,11 +231,30 @@ void GroupNormBackward::init_output_dtype() { size_t GroupNormBackward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const { - return intl::MegDNNOprMethInvoker:: - get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes); +#define in(x) \ + { input_shapes[x], input(x)->dtype() } +#define out(x) \ + { output_shapes[x], output(x)->dtype() } + auto&& param = megdnn_opr()->param(); + if (param.affine) { + return megdnn_opr()->get_workspace_in_bytes( + in(0), in(1), in(2), in(3), in(4), out(0), out(1), out(2)); + } else { + TensorLayout temp_weight = TensorLayout(dtype::Float32()); + TensorLayout temp_dweight = TensorLayout(dtype::Float32()); + TensorLayout temp_dbias = TensorLayout(dtype::Float32()); + return megdnn_opr()->get_workspace_in_bytes( + in(0), in(1), temp_weight, in(2), in(3), out(0), temp_dweight, + temp_dbias); + } +#undef in +#undef out } void GroupNormBackward::scn_do_execute() { + mgb_assert( + param().format == Param::Format::NCHW, + "only support inputs in shape NCHW."); if (param().affine) { megdnn_opr()->exec( input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), diff --git a/src/opr/impl/dnn/instance_norm.cpp b/src/opr/impl/dnn/instance_norm.cpp new file mode 100644 index 000000000..babfdc57a --- /dev/null +++ b/src/opr/impl/dnn/instance_norm.cpp @@ -0,0 +1,282 @@ +#include "megbrain/opr/dnn/instance_norm.h" + +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/utility.h" + +#include "../internal/megdnn_opr_wrapper.inl" + +using namespace mgb; +using namespace opr; + +/* ==================== InstanceNormForward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNormForward); + +InstanceNormForward::InstanceNormForward( + VarNode* data, VarNode* weight, VarNode* 
bias, const Param& param, + const OperatorNodeConfig& config) + : Super{data->owner_graph(), config, "instance_norm", {data, weight, bias}} { + init_megdnn_opr(*this, param); + + add_input({data, weight, bias}); + output(0)->dtype(data->dtype()); + output(1)->dtype(dtype::Float32()); + output(2)->dtype(dtype::Float32()); +} + +InstanceNormForward::InstanceNormForward( + VarNode* data, const Param& param, const OperatorNodeConfig& config) + : Super{data->owner_graph(), config, "instance_norm", {data}} { + init_megdnn_opr(*this, param); + + add_input({data}); + output(0)->dtype(data->dtype()); + output(1)->dtype(dtype::Float32()); + output(2)->dtype(dtype::Float32()); +} + +SymbolVarArray InstanceNormForward::make( + SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, + const OperatorNodeConfig& config) { + auto outs = data.node() + ->owner_graph() + ->insert_opr(std::make_unique( + data.node(), weight.node(), bias.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +SymbolVarArray InstanceNormForward::make( + SymbolVar data, const Param& param, const OperatorNodeConfig& config) { + auto outs = data.node() + ->owner_graph() + ->insert_opr(std::make_unique( + data.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +void InstanceNormForward::get_output_var_shape( + const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { + out_shape[0] = inp_shape[0]; + size_t N = inp_shape[0].shape[0]; + size_t C = inp_shape[0].shape[1]; + TensorShape unnormalized_shape{N, C}; + out_shape[1] = unnormalized_shape; + out_shape[2] = unnormalized_shape; +} + +size_t InstanceNormForward::get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const { +#define in(x) \ + { input_shapes[x], input(x)->dtype() } +#define out(x) \ + { output_shapes[x], output(x)->dtype() } + auto&& param = megdnn_opr()->param(); + if (param.affine) { + return megdnn_opr()->get_workspace_in_bytes( + in(0), in(1), in(2), out(0), out(1), out(2)); + } else { + TensorLayout temp_weight = TensorLayout(dtype::Float32()); + TensorLayout temp_bias = TensorLayout(dtype::Float32()); + return megdnn_opr()->get_workspace_in_bytes( + in(0), temp_weight, temp_bias, out(0), out(1), out(2)); + } +#undef in +#undef out +} + +void InstanceNormForward::scn_do_execute() { + auto p = param(); + mgb_assert(p.format == Param::Format::NCHW, "only support inputs in shape NCHW."); + size_t C = input(0)->dev_tensor().shape()[1]; + auto opr = const_cast(megdnn_opr()); + opr->param().group = C; + mgb_assert(C != 0, "error param!"); + if (param().affine) { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); + } else { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), {}, {}, + output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); + } +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(InstanceNormForward) { + auto p = opr.param(); + SymbolVarArray grad; + VarNodeArray ret; + if (p.affine) { + mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", 
+        grad = InstanceNormBackward::make(
+                out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
+                opr.param());
+    } else {
+        mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
+        grad = InstanceNormBackward::make(
+                out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
+    }
+
+    uint32_t nr_ret = p.affine ? 3 : 1;
+    for (uint32_t i = 0; i < nr_ret; ++i) {
+        ret.push_back(grad[i].node());
+    }
+    return ret;
+}
+#endif
+
+/* ==================== InstanceNormBackward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNormBackward);
+
+InstanceNormBackward::InstanceNormBackward(
+        VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super({diff->owner_graph(),
+                 config,
+                 "instance_norm_backward",
+                 {diff, data, weight, mean, rstd}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({diff, data, weight, mean, rstd});
+}
+
+InstanceNormBackward::InstanceNormBackward(
+        VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({diff->owner_graph(),
+                 config,
+                 "instance_norm_backward",
+                 {diff, data, mean, rstd}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({diff, data, mean, rstd});
+    auto mark_empty_var = [&](VarNode* var) {
+        var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
+    };
+    mark_empty_var(output(1));
+    mark_empty_var(output(2));
+}
+
+SymbolVarArray InstanceNormBackward::make(
+        SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
+        SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
+    auto outs = diff.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<InstanceNormBackward>(
+                                diff.node(), data.node(), weight.node(), mean.node(),
+                                rstd.node(), param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+SymbolVarArray InstanceNormBackward::make(
+        SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
+        const Param& param, const OperatorNodeConfig& config) {
+    auto outs = diff.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<InstanceNormBackward>(
+                                diff.node(), data.node(), mean.node(), rstd.node(),
+                                param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+void InstanceNormBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
+    if (param().affine) {
+        mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
+        mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
+    } else {
+        TensorShape empty;
+        empty.ndim = 0;
+        mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
+        mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
+    }
+    this->init_output_static_infer_desc_workspace(
+            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::GroupNormBackward>::val);
+}
+
+void InstanceNormBackward::init_output_dtype() {
+    output(0)->dtype(input(1)->dtype());
+    output(1)->dtype(input(2)->dtype());
+    output(2)->dtype(input(2)->dtype());
+}
+
+size_t InstanceNormBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+#define in(x) \
+    { input_shapes[x], input(x)->dtype() }
+#define out(x) \
+    { output_shapes[x], output(x)->dtype() }
+    auto&& param = megdnn_opr()->param();
+    if (param.affine) {
+        return megdnn_opr()->get_workspace_in_bytes(
+                in(0), in(1), in(2), in(3), in(4), out(0), out(1), out(2));
+    } else {
+        TensorLayout temp_weight = TensorLayout(dtype::Float32());
+        TensorLayout temp_dweight = TensorLayout(dtype::Float32());
+        TensorLayout temp_dbias = TensorLayout(dtype::Float32());
+        return megdnn_opr()->get_workspace_in_bytes(
+                in(0), in(1), temp_weight, in(2), in(3), out(0), temp_dweight,
+                temp_dbias);
+    }
+#undef in
+#undef out
+}
+
+void InstanceNormBackward::scn_do_execute() {
+    auto p = param();
+    mgb_assert(p.format == Param::Format::NCHW, "only supports inputs in NCHW format.");
+    size_t C = input(0)->dev_tensor().shape()[1];
+    mgb_assert(C != 0, "the channel dimension of the input must be nonzero.");
+    auto opr = const_cast<megdnn::GroupNormBackward*>(megdnn_opr());
+    opr->param().group = C;
+
+    if (p.affine) {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+                input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
+                input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+                output(1)->dev_tensor().as_megdnn(),
+                output(2)->dev_tensor().as_megdnn(),
+                intl::get_megdnn_workspace_from_var(output(3)));
+    } else {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+                {}, input(2)->dev_tensor().as_megdnn(),
+                input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+                {}, {}, intl::get_megdnn_workspace_from_var(output(3)));
+    }
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/include/megbrain/opr/dnn/instance_norm.h b/src/opr/include/megbrain/opr/dnn/instance_norm.h
new file mode 100644
index 000000000..27fa477a7
--- /dev/null
+++ b/src/opr/include/megbrain/opr/dnn/instance_norm.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs.h"
+
+namespace mgb {
+namespace opr {
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        InstanceNormForward, intl::MegDNNOprWrapperFwd<megdnn::GroupNormForward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC InstanceNormForward(
+            VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
+            const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC InstanceNormForward(
+            VarNode* data, const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar data, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void get_output_var_shape(
+            const TensorShapeArray& inp_shape,
+            TensorShapeArray& out_shape) const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+using InstanceNorm = InstanceNormForward;
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        InstanceNormBackward, intl::MegDNNOprWrapperBwd<megdnn::GroupNormBackward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC InstanceNormBackward(
+            VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
+            const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC InstanceNormBackward(
+            VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
+            const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
+            SymbolVar rstd, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
+} // namespace opr
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/test/dnn/instance_norm.cpp b/src/opr/test/dnn/instance_norm.cpp
new file mode 100644
index 000000000..be7a9f28a
--- /dev/null
+++ b/src/opr/test/dnn/instance_norm.cpp
@@ -0,0 +1,92 @@
+#include "megbrain/opr/dnn/instance_norm.h"
+#include "megbrain/comp_node_env.h"
+#include "megbrain/test/autocheck.h"
+#include "megbrain/test/helper.h"
+#include "megbrain/test/megdnn_helper.h"
+
+#include "megdnn/oprs.h"
+
+#include
+#include
+#include
+#include
+
+using namespace mgb;
+
+namespace {
+using Param = opr::InstanceNormForward::Param;
+
+void run_forward(bool is_affine) {
+    using Checker = AutoOprChecker<3, 3>;
+
+    Param param;
+    param.eps = 1e-5;
+    param.affine = is_affine;
+
+    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
+        auto out =
+                opr::InstanceNormForward::make(inputs[0], inputs[1], inputs[2], param);
+        return {out[0], out[1], out[2]};
+    };
+
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+        auto opr =
+                MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
+                        ->create_operator<megdnn::GroupNormForward>();
+        auto inp_shape = inp[0]->shape();
+        auto n_slices = inp_shape[0];
+        auto C = inp_shape[1];
+        param.group = C;
+        opr->param() = param;
+
+        dest[0].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize(inp_shape);
+        dest[1].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize({n_slices, C});
+        dest[2].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize({n_slices, C});
+        std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
+                inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(),
+                dest[1].layout(), dest[2].layout()));
+        opr->exec(
+                inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
+                dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(),
+                {workspace.data(), workspace.size()});
+    };
+
+    auto gen = [&](HostTensorND& src) {
+        HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f);
+        src = *src_gen(src.shape(), src.comp_node());
+    };
+
+    Checker::RunOptions option;
+    option.numdiff_max_err = 1e-4;
+    Checker checker{make_graph, fwd};
+
+    checker.set_input_generator(0, gen);
+    checker.set_input_generator(1, gen);
+    checker.set_input_generator(2, gen);
+    checker.set_input_allow_grad(0, false);
+    checker.set_input_allow_grad(1, false);
+    checker.set_input_allow_grad(2, false);
+    checker.set_output_allow_grad(0, false);
+    checker.set_output_allow_grad(1, false);
+    checker.set_output_allow_grad(2, false);
+
+    checker.run({TensorShape{2, 16, 2, 10}, TensorShape{16}, TensorShape{16}}, option)
+            .run({TensorShape{2, 32, 2, 10}, TensorShape{32}, TensorShape{32}}, option)
+            .run({TensorShape{2, 64, 16, 10}, TensorShape{64}, TensorShape{64}},
+                 option);
+}
+
+TEST(TestOprDNN, InstanceNormForward) {
+    REQUIRE_GPU(1);
+    run_forward(true);
+}
+
+} // anonymous namespace
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
--
GitLab