提交 40c8134a 编写于 作者: M Megvii Engine Team 提交者: 黄信达

feat(dnn,src,imperative): add instancenorm

GitOrigin-RevId: f71ae8ce3b9ea184fa5b7c871cd54657c4293785
上级 cfe9f4c2
...@@ -1269,6 +1269,11 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), ...@@ -1269,6 +1269,11 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
.add_fields('uint64', 'normalized_size', '1') .add_fields('uint64', 'normalized_size', '1')
) )
(pdef('Dropout')
.add_fields('float32', 'drop_prob', '0')
.add_fields('uint64', 'seed', '0')
)
(pdef('GroupNorm') (pdef('GroupNorm')
.add_fields('bool', 'affine', 'true') .add_fields('bool', 'affine', 'true')
.add_fields('float32', 'eps', '1e-5f') .add_fields('float32', 'eps', '1e-5f')
...@@ -1276,11 +1281,6 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), ...@@ -1276,11 +1281,6 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
.add_enum_alias('Format', 'Convolution') .add_enum_alias('Format', 'Convolution')
) )
(pdef('Dropout')
.add_fields('float32', 'drop_prob', '0')
.add_fields('uint64', 'seed', '0')
)
(pdef('RNNCell'). (pdef('RNNCell').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2') add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2')
) )
......
...@@ -11,6 +11,9 @@ void GroupNormBase::deduce_layout_fwd( ...@@ -11,6 +11,9 @@ void GroupNormBase::deduce_layout_fwd(
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight); MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias); MEGDNN_MARK_USED_VAR(bias);
megdnn_assert(
param().format == param::GroupNorm::Format::NCHW,
"The input of GroupNorm should be in NCHW format.");
size_t N = data.shape[0]; size_t N = data.shape[0];
size_t group = param().group; size_t group = param().group;
TensorLayout unnormalized_layout({N, group}, dtype::Float32()); TensorLayout unnormalized_layout({N, group}, dtype::Float32());
...@@ -39,6 +42,10 @@ void GroupNormBase::check_layout_fwd( ...@@ -39,6 +42,10 @@ void GroupNormBase::check_layout_fwd(
megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str()); megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str());
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());
megdnn_assert(data.ndim == 4, "Only supports input of dim 4");
megdnn_assert(
param().format == param::GroupNorm::Format::NCHW,
"The input of GroupNorm should be in NCHW format.");
auto p = param(); auto p = param();
size_t C = data.shape[1]; size_t C = data.shape[1];
size_t group = p.group; size_t group = p.group;
...@@ -110,6 +117,10 @@ void GroupNormBackward::check_exec( ...@@ -110,6 +117,10 @@ void GroupNormBackward::check_exec(
megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str()); megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str());
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());
megdnn_assert(data.ndim == 4, "Only supports input of dim 4");
megdnn_assert(
param().format == param::GroupNorm::Format::NCHW,
"The input of GroupNorm should be in NCHW format.");
if (p.affine) { if (p.affine) {
megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str()); megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str());
megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str()); megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str());
......
...@@ -68,11 +68,11 @@ private: ...@@ -68,11 +68,11 @@ private:
}; };
} // namespace megdnn } // namespace megdnn
/*! /*!
* \brief iterate though each operator class name; useful for explicit * \brief iterate though each operator class name; useful for explicit
* instantialization of create_operator<> templates * instantialization of create_operator<> templates
*/ */
// clang-format off // clang-format off
#define MEGDNN_FOREACH_OPR_CLASS(cb) \ #define MEGDNN_FOREACH_OPR_CLASS(cb) \
cb(ConvolutionForward) \ cb(ConvolutionForward) \
cb(ConvolutionBackwardData) \ cb(ConvolutionBackwardData) \
......
...@@ -433,7 +433,7 @@ __global__ void GetBackwardParamsCUDAKernel( ...@@ -433,7 +433,7 @@ __global__ void GetBackwardParamsCUDAKernel(
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]); const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
sum1 += ds[index] * scale_v; sum1 += ds[index] * scale_v;
sum2 += db[index] * scale_v; sum2 += db[index] * scale_v;
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]); const T scale_c = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
p1[index] = scale_c * var_inv; p1[index] = scale_c * var_inv;
} }
......
...@@ -27,22 +27,16 @@ void GroupNormForwardImpl::exec( ...@@ -27,22 +27,16 @@ void GroupNormForwardImpl::exec(
rstd.layout, workspace.size); rstd.layout, workspace.size);
auto p = param(); auto p = param();
using Format = param::GroupNorm::Format;
float eps = p.eps; float eps = p.eps;
int group = p.group; int group = p.group;
bool affine = p.affine; bool affine = p.affine;
auto layout = data.layout; auto layout = data.layout;
auto format = p.format;
size_t N, C, H, W, imsize; size_t N, C, H, W, imsize;
if (data.layout.ndim == 4 && format == Format::NCHW) {
N = layout.shape[0]; N = layout.shape[0];
C = layout.shape[1]; C = layout.shape[1];
H = layout.shape[2]; H = layout.shape[2];
W = layout.shape[3]; W = layout.shape[3];
imsize = H * W; imsize = H * W;
} else {
megdnn_throw(ssprintf("Unspport groupnorm input"));
}
auto stream = cuda_stream(handle()); auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::group_norm; using namespace ::megdnn::cuda::group_norm;
...@@ -94,22 +88,16 @@ void GroupNormBackwardImpl::exec( ...@@ -94,22 +88,16 @@ void GroupNormBackwardImpl::exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size); ddata.layout, dweight.layout, dbias.layout, workspace.size);
auto p = param(); auto p = param();
using Format = param::GroupNorm::Format;
bool affine = p.affine; bool affine = p.affine;
float eps = p.eps; float eps = p.eps;
int group = p.group; int group = p.group;
auto layout = data.layout; auto layout = data.layout;
auto format = p.format;
size_t N, C, H, W, imsize; size_t N, C, H, W, imsize;
if (layout.ndim == 4 && format == Format::NCHW) {
N = layout.shape[0]; N = layout.shape[0];
C = layout.shape[1]; C = layout.shape[1];
H = layout.shape[2]; H = layout.shape[2];
W = layout.shape[3]; W = layout.shape[3];
imsize = H * W; imsize = H * W;
} else {
megdnn_throw(ssprintf("Unspport groupnorm input"));
}
auto stream = cuda_stream(handle()); auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::group_norm; using namespace ::megdnn::cuda::group_norm;
......
...@@ -54,7 +54,7 @@ void forward( ...@@ -54,7 +54,7 @@ void forward(
} else { } else {
for (size_t j = 0; j < inner_size; j++) { for (size_t j = 0; j < inner_size; j++) {
dst.ptr<T>()[i * inner_size + j] = dst.ptr<T>()[i * inner_size + j] =
(data.ptr<T>()[i * inner_size + j] - slice_mean) / slice_std; (data.ptr<T>()[i * inner_size + j] - slice_mean) * slice_std;
} }
} }
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean); mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean);
......
...@@ -62,6 +62,7 @@ __all__ = [ ...@@ -62,6 +62,7 @@ __all__ = [
"gelu", "gelu",
"group_norm", "group_norm",
"hsigmoid", "hsigmoid",
"instance_norm",
"hswish", "hswish",
"indexing_one_hot", "indexing_one_hot",
"layer_norm", "layer_norm",
...@@ -1025,6 +1026,35 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: ...@@ -1025,6 +1026,35 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return output return output
def instance_norm(
inp: Tensor,
affine: bool,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
):
r"""Applies instance normalization to the input.
Refer to :class:`~.InstanceNorm` for more information.
Args:
inp: input tensor.
affine: whether to use learnable affine parameters (weight, bias)
weight: scaling tensor in the learnable affine parameters.
See :math:`\gamma` in :class:`~.InstanceNorm`.
bias: bias tensor in the learnable affine parameters.
See :math:`\beta` in :class:`~.InstanceNorm`.
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""
op = builtin.InstanceNorm(affine=affine, eps=eps)
if affine:
assert weight is not None, "weight must be provided if affine is True"
assert bias is not None, "bias must be provided if affine is True"
return apply(op, inp, weight, bias)[0]
else:
return apply(op, inp)[0]
def group_norm( def group_norm(
inp: Tensor, inp: Tensor,
num_groups: int, num_groups: int,
...@@ -1033,20 +1063,25 @@ def group_norm( ...@@ -1033,20 +1063,25 @@ def group_norm(
bias: Optional[Tensor] = None, bias: Optional[Tensor] = None,
eps: float = 1e-5, eps: float = 1e-5,
): ):
r"""Applies Group Normalization over a mini-batch of inputs as described in r"""Applies group normalization to the input.
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
Refer to :class:`~.GroupNorm` for more information.
Args: Args:
inp: input tensor. inp: input tensor.
num_groups: number of groups to separate the channels into num_groups: number of groups to separate the channels into
affine: whether to use weight and bias See :attr:`num_groups` in :class:`~.GroupNorm`.
weight: must not be None when the affine is true affine: whether to use learnable affine parameters (weight, bias)
bias: must not be None when the affine is true weight: scaling tensor in the learnable affine parameters.
See :math:`\gamma` in :class:`~.GroupNorm`.
bias: bias tensor in the learnable affine parameters.
See :math:`\beta` in :class:`~.GroupNorm`.
eps: a value added to the denominator for numerical stability. Default: 1e-5 eps: a value added to the denominator for numerical stability. Default: 1e-5
""" """
op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,) op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,)
if affine: if affine:
assert weight is not None and bias is not None assert weight is not None, "weight must be provided if affine is True"
assert bias is not None, "bias must be provided if affine is True"
return apply(op, inp, weight, bias)[0] return apply(op, inp, weight, bias)[0]
else: else:
return apply(op, inp)[0] return apply(op, inp)[0]
...@@ -1060,15 +1095,19 @@ def layer_norm( ...@@ -1060,15 +1095,19 @@ def layer_norm(
bias: Optional[Tensor] = None, bias: Optional[Tensor] = None,
eps: float = 1e-5, eps: float = 1e-5,
): ):
r"""Applies layer normalization to the input. Support tensor of any shape as input. r"""Applies layer normalization to the input.
Reference: https://arxiv.org/pdf/1803.08494.pdf.
Refer to :class:`~.LayerNorm` for more information.
Args: Args:
inp: input tensor. inp: input tensor.
normalized_shape: the shape that you want to be normalizated normalized_shape: the shape that you want to be normalizated
affine: whether to use weight and bias See :attr:`normalized_shape` in :class:`~.LayerNorm`.
weight: must not be None when the affine is true affine: whether to use learnable affine parameters (weight, bias)
bias: must not be None when the affine is true weight: scaling tensor in the learnable affine parameters.
See :math:`\gamma` in :class:`~.LayerNorm`.
bias: bias tensor in the learnable affine parameters.
See :math:`\beta` in :class:`~.LayerNorm`.
eps: a value added to the denominator for numerical stability. Default: 1e-5 eps: a value added to the denominator for numerical stability. Default: 1e-5
""" """
if isinstance(normalized_shape, int): if isinstance(normalized_shape, int):
...@@ -1088,10 +1127,10 @@ def layer_norm( ...@@ -1088,10 +1127,10 @@ def layer_norm(
normalized_size=normalized_size, normalized_size=normalized_size,
) )
if affine: if affine:
assert weight is not None and bias is not None assert weight is not None, "weight must be provided if affine is True"
assert bias is not None, "bias must be provided if affine is True"
return apply(op, inp, weight, bias)[0] return apply(op, inp, weight, bias)[0]
else: else:
# assert weight is None and bias is None
return apply(op, inp)[0] return apply(op, inp)[0]
......
...@@ -9,8 +9,33 @@ from .module import Module ...@@ -9,8 +9,33 @@ from .module import Module
class GroupNorm(Module): class GroupNorm(Module):
"""Simple implementation of GroupNorm. Only support 4d tensor now. r"""Applies Group Normalization over a mini-batch of inputs
Reference: https://arxiv.org/pdf/1803.08494.pdf. Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the each group.
:math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of
attr:`num_channels` if :attr:`affine` is ``True``.
Args:
num_groups (int): number of groups that divided from channels.
num_channels (int): number of channels expected in input
eps: a value added to the denominator for numerical stability. Default: 1e-5
affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
Shape:
- Input: :math:`(N, C, H, W)` (now only support NCHW format tensor)
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
>>> m = M.GroupNorm(3, 3)
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
""" """
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
...@@ -48,9 +73,33 @@ class GroupNorm(Module): ...@@ -48,9 +73,33 @@ class GroupNorm(Module):
class InstanceNorm(Module): class InstanceNorm(Module):
"""Simple implementation of InstanceNorm. Only support 4d tensor now. r"""Applies Instance Normalization over a mini-batch of inputs
Reference: https://arxiv.org/abs/1607.08022. Refer to `Instance Normalization https://arxiv.org/abs/1607.08022`__
Note that InstanceNorm equals using GroupNome with num_groups=num_channels.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch.
:math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of
attr:`num_channels` if :attr:`affine` is ``True``.
Note that InstanceNorm equals using GroupNorm with num_groups = num_channels.
Args:
num_channels (int): number of channels expected in input
eps: a value added to the denominator for numerical stability. Default: 1e-5
affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
Shape:
- Input: :math:`(N, C, H, W)` (now only support NCHW format tensor)
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
>>> m = M.InstanceNorm(3)
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
""" """
def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs): def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
...@@ -72,20 +121,7 @@ class InstanceNorm(Module): ...@@ -72,20 +121,7 @@ class InstanceNorm(Module):
zeros_(self.bias) zeros_(self.bias)
def forward(self, x): def forward(self, x):
N, C, H, W = x.shape x = F.nn.instance_norm(x, self.affine, self.weight, self.bias, self.eps)
format = x.format
assert C == self.num_channels
x = x.reshape(N, C, -1)
mean = x.mean(axis=2, keepdims=True)
var = (x ** 2).mean(axis=2, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(N, C, H, W)
if self.affine:
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
# FIXME(czh): remove this after making it a builtin op.
if format == "nhwc":
x = mge.amp.convert_tensor_format(x, inplace=False)
return x return x
def _module_info_string(self) -> str: def _module_info_string(self) -> str:
...@@ -94,8 +130,38 @@ class InstanceNorm(Module): ...@@ -94,8 +130,38 @@ class InstanceNorm(Module):
class LayerNorm(Module): class LayerNorm(Module):
"""Simple implementation of LayerNorm. Support tensor of any shape as input. r"""Applies Layer Normalization over a mini-batch of inputs
Reference: https://arxiv.org/pdf/1803.08494.pdf. Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`affine` is ``True``.
The standard-deviation is calculated via the biased estimator.
Args:
normalized_shape(int or tuple): input shape from an expected input of size
size :math:`[*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1]]`.
If it is a single integer, this module will normalize over the last dimension
which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
affine: this module has learnable affine parameters (weight, bias) when affine is set to be True.
Shape:
- Input: :math:`(N, *)` (2-D, 3-D, 4-D or 5-D tensor)
- Output: :math:`(N, *)` (same shape as input)
Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
>>> m = M.LayerNorm((4, 4))
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
""" """
def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs): def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs):
......
#! /usr/bin/env python3 #! /usr/bin/env python3
import argparse import argparse
import ntpath
import os import os
import pathlib import pathlib
import platform import platform
......
...@@ -511,8 +511,7 @@ class InternalGraph: ...@@ -511,8 +511,7 @@ class InternalGraph:
inp = F.zeros(shape = (3, 4)) inp = F.zeros(shape = (3, 4))
traced_module = tm.trace_module(net, inp) traced_module = tm.trace_module(net, inp)
Will produce the following ``InternalGraph``:: Will produce the following ``InternalGraph``:
print(traced_module.graph) print(traced_module.graph)
.. code-block:: text .. code-block:: text
...@@ -2463,6 +2462,7 @@ def trace_module( ...@@ -2463,6 +2462,7 @@ def trace_module(
with active_module_tracer().patcher: with active_module_tracer().patcher:
global_scope = InternalGraph(name="top", qualname=net_name) global_scope = InternalGraph(name="top", qualname=net_name)
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod, True) builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe( NodeMixin.wrap_safe(
......
...@@ -16,6 +16,7 @@ from megengine.module import ( ...@@ -16,6 +16,7 @@ from megengine.module import (
Conv2d, Conv2d,
Dropout, Dropout,
GroupNorm, GroupNorm,
InstanceNorm,
Linear, Linear,
MaxPool2d, MaxPool2d,
Module, Module,
...@@ -703,9 +704,15 @@ def test_module_compatible(): ...@@ -703,9 +704,15 @@ def test_module_compatible():
@pytest.mark.skip(reason="pytest aborted") @pytest.mark.skip(reason="pytest aborted")
def test_grou_norm(): @pytest.mark.parametrize("affine", [True, False])
def test_grou_norm(affine):
num_groups = 256
num_channels = 256
weight_np = np.random.uniform(-0.5, 0.5, (num_channels))
bias_np = np.random.uniform(-0.5, 0.5, (num_channels))
class OriginGroupNormFunc(Module): class OriginGroupNormFunc(Module):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): def __init__(self, eps=1e-5, affine=True, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
assert num_channels % num_groups == 0 assert num_channels % num_groups == 0
self.num_groups = num_groups self.num_groups = num_groups
...@@ -713,8 +720,8 @@ def test_grou_norm(): ...@@ -713,8 +720,8 @@ def test_grou_norm():
self.eps = eps self.eps = eps
self.affine = affine self.affine = affine
if self.affine: if self.affine:
self.weight = Parameter(np.ones(num_channels, dtype=np.float32)) self.weight = Parameter(weight_np)
self.bias = Parameter(np.zeros(num_channels, dtype=np.float32)) self.bias = Parameter(bias_np)
else: else:
self.weight = None self.weight = None
self.bias = None self.bias = None
...@@ -732,36 +739,105 @@ def test_grou_norm(): ...@@ -732,36 +739,105 @@ def test_grou_norm():
) )
return x return x
inp = np.random.randn(2, 256, 10, 16).astype("float32") inp = np.random.uniform(-0.5, 0.5, (2, num_channels, 10, 16)).astype("float32")
mge_inp = Tensor(inp) mge_inp = Tensor(inp)
mge_m = GroupNorm(32, 256) mge_m = GroupNorm(num_groups, num_channels, affine=affine)
mge_m.weight = Parameter(weight_np)
mge_m.bias = Parameter(bias_np)
ori_inp = Tensor(inp) ori_inp = Tensor(inp)
ori_m = OriginGroupNormFunc(32, 256) ori_m = OriginGroupNormFunc(affine=affine)
targets = np.array(2)
mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters())
ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters())
mge_gm = mge.autodiff.GradManager().attach((*mge_m.parameters(), mge_inp))
ori_gm = mge.autodiff.GradManager().attach((*ori_m.parameters(), ori_inp))
dy = Tensor(np.random.uniform(-0.5, 0.5, inp.shape))
for i in range(2): for i in range(2):
with mge_gm: with mge_gm:
mge_output = mge_m(mge_inp) mge_output = mge_m(mge_inp)
loss = F.loss.square_loss(
mge_output.sum(), mge.tensor(targets, dtype=np.float32) mge_gm.backward(mge_output, dy)
)
mge_gm.backward(loss)
with ori_gm: with ori_gm:
ori_output = ori_m(ori_inp) ori_output = ori_m(ori_inp)
loss = F.loss.square_loss(
ori_output.sum(), mge.tensor(targets, dtype=np.float32) ori_gm.backward(ori_output, dy)
np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
np.testing.assert_allclose(
ori_inp.grad.numpy(), mge_inp.grad.numpy(), atol=1e-05
)
if affine == True:
np.testing.assert_allclose(
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), atol=1e-05
)
np.testing.assert_allclose(
mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), atol=1e-05
)
@pytest.mark.parametrize("affine", [True, False])
def test_instance_norm(affine):
num_channels = 4
weight_np = np.random.uniform(-0.5, 0.5, (num_channels))
bias_np = np.random.uniform(-0.5, 0.5, (num_channels))
class OriginInstanceNormFunc(Module):
def __init__(self, eps=1e-5, affine=True, **kwargs):
super().__init__(**kwargs)
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if self.affine:
self.weight = Parameter(weight_np)
self.bias = Parameter(bias_np)
else:
self.weight = None
self.bias = None
def forward(self, x):
N, C, H, W = x.shape
x = x.reshape(N, self.num_channels, -1)
mean = x.mean(axis=2, keepdims=True)
var = (x * x).mean(axis=2, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(N, C, H, W)
if self.affine:
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(
1, -1, 1, 1
) )
ori_gm.backward(loss) return x
inp = np.random.uniform(-0.5, 0.5, (2, num_channels, 10, 16)).astype("float32")
mge_inp = Tensor(inp)
mge_m = InstanceNorm(num_channels, affine=affine)
mge_m.weight = Parameter(weight_np)
mge_m.bias = Parameter(bias_np)
ori_inp = Tensor(inp)
ori_m = OriginInstanceNormFunc(affine=affine)
mge_im = mge.autodiff.GradManager().attach((*mge_m.parameters(), mge_inp))
ori_im = mge.autodiff.GradManager().attach((*ori_m.parameters(), ori_inp))
dy = Tensor(np.random.uniform(-0.5, 0.5, inp.shape))
for i in range(2):
with mge_im:
mge_output = mge_m(mge_inp)
mge_im.backward(mge_output, dy)
with ori_im:
ori_output = ori_m(ori_inp)
ori_im.backward(ori_output, dy)
np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05) np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
np.testing.assert_allclose( np.testing.assert_allclose(
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03 ori_inp.grad.numpy(), mge_inp.grad.numpy(), atol=1e-04
)
if affine == True:
np.testing.assert_allclose(
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), atol=1e-04
) )
np.testing.assert_allclose( np.testing.assert_allclose(
mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03 mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), atol=1e-04
) )
#include "megbrain/opr/dnn/instance_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace instance_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const InstanceNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::InstanceNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::InstanceNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& instance_norm = def.cast_final_safe<InstanceNorm>();
size_t nr_inp = inputs.size();
auto affine = instance_norm.affine;
mgb_assert(
(nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
"num of inputs of pooling should be 1 or 3 but you give %zu",
inputs.size());
auto&& inp = inputs[0];
auto& inp_cn = inp.comp_node;
if (inp.layout.ndim == 0) {
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}}},
false};
}
size_t C = inputs[0].layout.shape[1];
auto p = instance_norm.param();
p.group = C;
DnnOprHelper<megdnn::GroupNorm> dnn_opr(p);
auto&& [oup_layout, mean_layout, rstd_layout] =
dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{});
return {{{oup_layout, inp_cn, {}},
{mean_layout, inp_cn, {}},
{rstd_layout, inp_cn, {}}},
true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<InstanceNorm>();
size_t nr_inp = inputs.size();
auto p = op_def.param();
mgb_assert(
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
"num of inputs of instancenorm should be 1 or 3 but you give %zu",
inputs.size());
auto cn = inputs[0]->comp_node();
using Format = megdnn::param::GroupNorm::Format;
mgb_assert(p.format == Format::NCHW, "only support inputs in shape NCHW.");
size_t C = inputs[0]->shape()[1];
p.group = C;
DnnOprCaller<megdnn::GroupNorm> caller(cn, p);
auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>(
inputs[0]->layout(), TensorLayout{}, TensorLayout{});
auto out = Tensor::make(oup_layout, cn);
auto mean = Tensor::make(mean_layout, cn);
auto rstd = Tensor::make(rstd_layout, cn);
if (p.affine) {
caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
} else {
megdnn::TensorND empty_dnn;
caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd);
}
return {out, mean, rstd};
}
OP_TRAIT_REG(InstanceNorm, InstanceNorm)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace instance_norm
} // namespace mgb::imperative
\ No newline at end of file
...@@ -439,6 +439,7 @@ ValueRefList DTypePromoteTransformation::apply_transformation( ...@@ -439,6 +439,7 @@ ValueRefList DTypePromoteTransformation::apply_transformation(
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
} }
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
......
8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py 4489d2cae7002dbfef1359c4d4c8141a ../../dnn/scripts/opr_param_defs.py
4bd0317fd84b5065c8d88a7ca6241908 ../../src/core/include/megbrain/ir/ops.td 465dad57c288e2e2d5fb356f0baef7d7 ../../src/core/include/megbrain/ir/ops.td
cb32cb1ef6b2ef4a7defaeb02ecd36e3 generated/opdef.h.inl 5d850de38a2583233da0fac2f5be1f1d generated/opdef.h.inl
1c0230f60ddf3459de2aa4e16c1e2957 generated/opdef.cpp.inl dfb41c4fba4727b9474074c38ca169db generated/opdef.cpp.inl
f6cbfd25f0d61e7b94c687733f5ae9b9 generated/opdef.py.inl e71787dd73df41d5f967af66a0a5e71e generated/opdef.py.inl
3a023199c39ea5611975b902a882bbba generated/opdef.cpy.inl b67068fb053c20a065255c971e3d2082 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
...@@ -4419,6 +4419,110 @@ OP_TRAIT_REG(InplaceAdd, InplaceAdd) ...@@ -4419,6 +4419,110 @@ OP_TRAIT_REG(InplaceAdd, InplaceAdd)
.props(InplaceAdd_props_impl) .props(InplaceAdd_props_impl)
.make_name(InplaceAdd_make_name_impl); .make_name(InplaceAdd_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNorm);
namespace {
size_t InstanceNorm_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<InstanceNorm>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.affine));
val = mgb::hash_pair_combine(val, mgb::hash(op_.eps));
val = mgb::hash_pair_combine(val, mgb::hash(op_.group));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
return val;
}
bool InstanceNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<InstanceNorm>(),
&&b_ = rhs_.cast_final_safe<InstanceNorm>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.affine != b_.affine) return false;
if (a_.eps != b_.eps) return false;
if (a_.group != b_.group) return false;
if (a_.format != b_.format) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> InstanceNorm_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<InstanceNorm>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("affine", std::to_string(op_.affine));
props_.emplace_back("eps", std::to_string(op_.eps));
props_.emplace_back("group", std::to_string(op_.group));
switch (op_.format){
case InstanceNorm::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case InstanceNorm::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case InstanceNorm::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case InstanceNorm::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case InstanceNorm::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case InstanceNorm::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case InstanceNorm::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case InstanceNorm::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case InstanceNorm::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case InstanceNorm::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case InstanceNorm::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case InstanceNorm::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case InstanceNorm::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case InstanceNorm::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case InstanceNorm::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case InstanceNorm::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case InstanceNorm::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case InstanceNorm::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
return props_;
}
std::string InstanceNorm_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<InstanceNorm>();
static_cast<void>(op_);
return "InstanceNorm";
}
} // anonymous namespace
OP_TRAIT_REG(InstanceNorm, InstanceNorm)
.hash(InstanceNorm_hash_impl)
.is_same_st(InstanceNorm_is_same_st_impl)
.props(InstanceNorm_props_impl)
.make_name(InstanceNorm_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LAMBUpdate); MGB_DYN_TYPE_OBJ_FINAL_IMPL(LAMBUpdate);
namespace { namespace {
......
...@@ -12787,6 +12787,178 @@ void _init_py_InplaceAdd(py::module m) { ...@@ -12787,6 +12787,178 @@ void _init_py_InplaceAdd(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(InplaceAdd::typeinfo(), &py_type).second); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(InplaceAdd::typeinfo(), &py_type).second);
} }
void _init_py_InstanceNorm_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<InstanceNorm::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(InstanceNorm) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"affine", serialization<decltype(opdef.affine)>::dump(opdef.affine)},
{"eps", serialization<decltype(opdef.eps)>::dump(opdef.eps)},
{"group", serialization<decltype(opdef.group)>::dump(opdef.group)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("affine");
if (iter != state.end()) {
opdef.affine = serialization<decltype(opdef.affine)>::load(iter->second);
}
}
{
auto&& iter = state.find("eps");
if (iter != state.end()) {
opdef.eps = serialization<decltype(opdef.eps)>::load(iter->second);
}
}
{
auto&& iter = state.find("group");
if (iter != state.end()) {
opdef.group = serialization<decltype(opdef.group)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(InstanceNorm)
int PyOp(InstanceNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL};
PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &affine, &eps, &group, &format, &scope))
return -1;
if (affine) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst().affine =
py::cast<decltype(InstanceNorm::affine)>(py::handle(affine));
} CATCH_ALL(-1)
}
if (eps) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst().eps =
py::cast<decltype(InstanceNorm::eps)>(py::handle(eps));
} CATCH_ALL(-1)
}
if (group) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst().group =
py::cast<decltype(InstanceNorm::group)>(py::handle(group));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(InstanceNorm)*>(self)->inst().format =
py::cast<decltype(InstanceNorm::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(InstanceNorm)::py_getsetters[] = {
{const_cast<char*>("affine"), py_get_generic(InstanceNorm, affine), py_set_generic(InstanceNorm, affine), const_cast<char*>("affine"), NULL},
{const_cast<char*>("eps"), py_get_generic(InstanceNorm, eps), py_set_generic(InstanceNorm, eps), const_cast<char*>("eps"), NULL},
{const_cast<char*>("group"), py_get_generic(InstanceNorm, group), py_set_generic(InstanceNorm, group), const_cast<char*>("group"), NULL},
{const_cast<char*>("format"), py_get_generic(InstanceNorm, format), py_set_generic(InstanceNorm, format), const_cast<char*>("format"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(InstanceNorm)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(InstanceNorm)::getstate, METH_NOARGS, "InstanceNorm getstate"},
{const_cast<char*>("__setstate__"), PyOp(InstanceNorm)::setstate, METH_VARARGS, "InstanceNorm setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(InstanceNorm)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(InstanceNorm)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(InstanceNorm)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(InstanceNorm)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, affine: bool = ..., eps: float = ..., group: int = ..., format: Union[str, Format] = ...) -> None\n"
};
void _init_py_InstanceNorm(py::module m) {
using py_op = PyOp(InstanceNorm);
auto& py_type = PyOpType(InstanceNorm);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.InstanceNorm";
py_type.tp_basicsize = sizeof(PyOp(InstanceNorm));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "InstanceNorm";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(InstanceNorm), &PyOp(InstanceNorm)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_InstanceNorm_Format(py_type);
PyType_Modified(&py_type);
m.add_object("InstanceNorm", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(InstanceNorm::typeinfo(), &py_type).second);
}
PyOpDefBegin(LAMBUpdate) // { PyOpDefBegin(LAMBUpdate) // {
static PyGetSetDef py_getsetters[]; static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[]; static PyMethodDef tp_methods[];
...@@ -22258,6 +22430,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { ...@@ -22258,6 +22430,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_IndexingSetMultiAxisVec(m); \ _init_py_IndexingSetMultiAxisVec(m); \
_init_py_IndexingSetOneHot(m); \ _init_py_IndexingSetOneHot(m); \
_init_py_InplaceAdd(m); \ _init_py_InplaceAdd(m); \
_init_py_InstanceNorm(m); \
_init_py_LAMBUpdate(m); \ _init_py_LAMBUpdate(m); \
_init_py_LRN(m); \ _init_py_LRN(m); \
_init_py_LSQ(m); \ _init_py_LSQ(m); \
......
...@@ -1167,6 +1167,23 @@ public: ...@@ -1167,6 +1167,23 @@ public:
} }
}; };
class InstanceNorm : public OpDefImplBase<InstanceNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Format = ::megdnn::param::GroupNorm::Format;
bool affine = true;
float eps = 1e-5f;
uint32_t group = 1;
Format format = ::megdnn::param::GroupNorm::Format::NCHW;
InstanceNorm() = default;
InstanceNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); }
InstanceNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {}
::megdnn::param::GroupNorm param() const {
return {affine, eps, group, format};
}
};
class LAMBUpdate : public OpDefImplBase<LAMBUpdate> { class LAMBUpdate : public OpDefImplBase<LAMBUpdate> {
MGB_DYN_TYPE_OBJ_FINAL_DECL; MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
...@@ -1339,6 +1339,17 @@ py::class_<InplaceAdd, std::shared_ptr<InplaceAdd>, OpDef> InplaceAddInst(m, "In ...@@ -1339,6 +1339,17 @@ py::class_<InplaceAdd, std::shared_ptr<InplaceAdd>, OpDef> InplaceAddInst(m, "In
InplaceAddInst InplaceAddInst
.def(py::init<>()); .def(py::init<>());
py::class_<InstanceNorm, std::shared_ptr<InstanceNorm>, OpDef> InstanceNormInst(m, "InstanceNorm");
InstanceNormInst.attr("Format") = AdaptivePoolingInst.attr("Format");
InstanceNormInst
.def(py::init<bool, float, uint32_t, ::megdnn::param::GroupNorm::Format, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {})
.def_readwrite("affine", &InstanceNorm::affine)
.def_readwrite("eps", &InstanceNorm::eps)
.def_readwrite("group", &InstanceNorm::group)
.def_readwrite("format", &InstanceNorm::format);
py::class_<LAMBUpdate, std::shared_ptr<LAMBUpdate>, OpDef> LAMBUpdateInst(m, "LAMBUpdate"); py::class_<LAMBUpdate, std::shared_ptr<LAMBUpdate>, OpDef> LAMBUpdateInst(m, "LAMBUpdate");
LAMBUpdateInst LAMBUpdateInst
......
...@@ -525,6 +525,8 @@ def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; ...@@ -525,6 +525,8 @@ def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>;
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>;
def InstanceNorm: MgbHashableOp<"InstanceNorm",[GroupNormParam]>;
def RNN: MgbHashableOp<"RNN", [RNNParam]>; def RNN: MgbHashableOp<"RNN", [RNNParam]>;
def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>; def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/group_norm.h" #include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/instance_norm.h"
#include "megbrain/opr/dnn/layer_norm.h" #include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lrn.h"
...@@ -736,6 +737,215 @@ struct OprLoadDumpImplV2<opr::GroupNormBackward, 0> { ...@@ -736,6 +737,215 @@ struct OprLoadDumpImplV2<opr::GroupNormBackward, 0> {
} }
}; };
template <>
struct OprMaker<opr::InstanceNorm, 0> {
using Param = opr::GroupNorm::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 3) {
return opr::InstanceNorm::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 1);
return opr::InstanceNorm::make(i[0], param, config)[0].node()->owner_opr();
}
}
};
template <>
struct OprLoadDumpImplV2<opr::InstanceNorm, 0> {
using Opr = opr::InstanceNorm;
using Param = opr::GroupNorm::Param;
using ElemwiseParam = opr::Elemwise::Param;
using ReduceParam = opr::Reduce::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
auto graph = inputs[0]->owner_graph();
auto comp_node = inputs[0]->comp_node();
// std::unique_ptr<StaticInferManager> m_static_infer_manager;
auto opr_param = opr->cast_final_safe<opr::InstanceNorm>().param();
float eps = opr_param.eps;
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5));
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps));
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node});
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node});
auto origin_shape = opr::GetVarShape::make(inputs[0]).node();
TensorShape input_shape =
inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]);
size_t N = input_shape[0];
size_t C = input_shape[1];
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3];
int size = inner_size / C;
HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = N;
ptr[1] = C;
ptr[2] = size;
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node});
auto inp = opr::Reshape::make(inputs[0], target_shape);
auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2});
auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL});
auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2});
auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL});
auto var =
opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB});
auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD});
auto sqrt =
opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW});
auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB});
auto temp_inp =
opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV});
auto res = opr::Reshape::make(temp_inp, origin_shape);
if (inputs.size() == 3) {
auto mul_temp =
opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL});
auto res = opr::Elemwise::make(
{mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD});
return res.node()->owner_opr();
} else {
return res.node()->owner_opr();
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
// auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
// return OprMaker<opr::AdaptivePooling,0>::make(ctx.read_param<Param>(),
// inputs, ctx.graph(), config);
return OprMaker<opr::InstanceNorm, 0>::make(
ctx.read_param<Param>(), inputs, ctx.graph(), config);
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::InstanceNormBackward, 0> {
using Param = opr::GroupNormBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 5) {
return opr::InstanceNormBackward::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 4);
return opr::InstanceNormBackward::make(
i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
}
};
template <>
struct OprLoadDumpImplV2<opr::InstanceNormBackward, 0> {
using Opr = opr::InstanceNormBackward;
using Param = opr::GroupNormBackward::Param;
using ElemwiseParam = opr::Elemwise::Param;
using ReduceParam = opr::Reduce::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
auto rstd = inputs[4];
auto graph = inputs[1]->owner_graph();
auto comp_node = inputs[1]->comp_node();
auto opr_param = opr->cast_final_safe<opr::InstanceNormBackward>().param();
float eps = opr_param.eps;
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5));
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps));
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node});
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node});
auto const_node =
opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node});
TensorShape input_shape =
inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]);
auto origin_shape = opr::GetVarShape::make(inputs[1]).node();
size_t N = input_shape[0];
size_t C = input_shape[1];
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3];
int size = inner_size / C;
HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = N;
ptr[1] = C;
ptr[2] = size;
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node});
auto inp = opr::Reshape::make(inputs[1], target_shape);
auto temp_rstd =
opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD});
auto sqrt =
opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW});
auto slice_std = opr::Elemwise::make(
{const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV});
auto sub_mean =
opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB});
auto x_hat =
opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL});
x_hat = opr::Reshape::make(x_hat, origin_shape);
auto size_node =
opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node});
auto temp1 = opr::Elemwise::make(
{slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV});
auto dx_hat =
opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL});
HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32());
auto* ptr2 = tshape.ptr<dt_int32>();
ptr2[0] = N;
ptr2[1] = C;
ptr2[2] = 1;
ptr2[3] = input_shape[2];
ptr2[4] = input_shape[3];
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node});
x_hat = opr::Reshape::make(x_hat, target_shape);
dx_hat = opr::Reshape::make(dx_hat, target_shape);
auto temp2 =
opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL});
ptr2[2] = 1;
ptr2[3] = 1;
ptr2[4] = 1;
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node});
auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape);
auto sum_dx_hat =
opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape);
auto temp4 =
opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL});
auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB});
auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB});
auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL});
auto dx = opr::Reshape::make(dx_temp, origin_shape);
return dx.node()->owner_opr();
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
return OprMaker<opr::InstanceNormBackward, 0>::make(
ctx.read_param<Param>(), inputs, ctx.graph(), config);
}
};
template <class MegDNNConv = megdnn::LocalShare> template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 { struct MakeLocalShareCaller2 {
template <typename Opr> template <typename Opr>
...@@ -961,6 +1171,8 @@ MGB_SEREG_OPR(LayerNorm, 0); ...@@ -961,6 +1171,8 @@ MGB_SEREG_OPR(LayerNorm, 0);
MGB_SEREG_OPR(LayerNormBackward, 0); MGB_SEREG_OPR(LayerNormBackward, 0);
MGB_SEREG_OPR(GroupNorm, 0); MGB_SEREG_OPR(GroupNorm, 0);
MGB_SEREG_OPR(GroupNormBackward, 0); MGB_SEREG_OPR(GroupNormBackward, 0);
MGB_SEREG_OPR(InstanceNorm, 0);
MGB_SEREG_OPR(InstanceNormBackward, 0);
MGB_SEREG_OPR(RNNCellForward, 6); MGB_SEREG_OPR(RNNCellForward, 6);
MGB_SEREG_OPR(LSTMCellForward, 7); MGB_SEREG_OPR(LSTMCellForward, 7);
MGB_SEREG_OPR(RNNForward, 3); MGB_SEREG_OPR(RNNForward, 3);
...@@ -977,6 +1189,15 @@ MGB_SEREG_OPR_V2( ...@@ -977,6 +1189,15 @@ MGB_SEREG_OPR_V2(
GroupNormBackward, 0, GroupNormBackward, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNormBackward, 0>::replace_opr), (mgb::serialization::OprLoadDumpImplV2<opr::GroupNormBackward, 0>::replace_opr),
VERSION_2, CURRENT_VERSION); VERSION_2, CURRENT_VERSION);
MGB_SEREG_OPR_V2(
InstanceNorm, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::InstanceNorm, 0>::replace_opr),
VERSION_2, CURRENT_VERSION);
MGB_SEREG_OPR_V2(
InstanceNormBackward, 0,
(mgb::serialization::OprLoadDumpImplV2<
opr::InstanceNormBackward, 0>::replace_opr),
VERSION_2, CURRENT_VERSION);
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
......
...@@ -77,11 +77,28 @@ void GroupNormForward::get_output_var_shape( ...@@ -77,11 +77,28 @@ void GroupNormForward::get_output_var_shape(
size_t GroupNormForward::get_workspace_size_bytes( size_t GroupNormForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes, const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const { const TensorShapeArray& output_shapes) const {
return intl::MegDNNOprMethInvoker<megdnn::GroupNormForward>::get_workspace_in_bytes( #define in(x) \
megdnn_opr(), this, input_shapes, output_shapes); { input_shapes[x], input(x)->dtype() }
#define out(x) \
{ output_shapes[x], output(x)->dtype() }
auto&& param = megdnn_opr()->param();
if (param.affine) {
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), in(2), out(0), out(1), out(2));
} else {
TensorLayout temp_weight = TensorLayout(dtype::Float32());
TensorLayout temp_bias = TensorLayout(dtype::Float32());
return megdnn_opr()->get_workspace_in_bytes(
in(0), temp_weight, temp_bias, out(0), out(1), out(2));
}
#undef in
#undef out
} }
void GroupNormForward::scn_do_execute() { void GroupNormForward::scn_do_execute() {
mgb_assert(
param().format == Param::Format::NCHW,
"only support inputs in shape NCHW.");
if (param().affine) { if (param().affine) {
megdnn_opr()->exec( megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
...@@ -214,11 +231,30 @@ void GroupNormBackward::init_output_dtype() { ...@@ -214,11 +231,30 @@ void GroupNormBackward::init_output_dtype() {
size_t GroupNormBackward::get_workspace_size_bytes( size_t GroupNormBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes, const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const { const TensorShapeArray& output_shapes) const {
return intl::MegDNNOprMethInvoker<megdnn::GroupNormBackward>:: #define in(x) \
get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes); { input_shapes[x], input(x)->dtype() }
#define out(x) \
{ output_shapes[x], output(x)->dtype() }
auto&& param = megdnn_opr()->param();
if (param.affine) {
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), in(2), in(3), in(4), out(0), out(1), out(2));
} else {
TensorLayout temp_weight = TensorLayout(dtype::Float32());
TensorLayout temp_dweight = TensorLayout(dtype::Float32());
TensorLayout temp_dbias = TensorLayout(dtype::Float32());
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), temp_weight, in(2), in(3), out(0), temp_dweight,
temp_dbias);
}
#undef in
#undef out
} }
void GroupNormBackward::scn_do_execute() { void GroupNormBackward::scn_do_execute() {
mgb_assert(
param().format == Param::Format::NCHW,
"only support inputs in shape NCHW.");
if (param().affine) { if (param().affine) {
megdnn_opr()->exec( megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
......
#include "megbrain/opr/dnn/instance_norm.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"
#include "../internal/megdnn_opr_wrapper.inl"
using namespace mgb;
using namespace opr;
/* ==================== InstanceNormForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNormForward);
InstanceNormForward::InstanceNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "instance_norm", {data, weight, bias}} {
init_megdnn_opr(*this, param);
add_input({data, weight, bias});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
InstanceNormForward::InstanceNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "instance_norm", {data}} {
init_megdnn_opr(*this, param);
add_input({data});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
SymbolVarArray InstanceNormForward::make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param,
const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<InstanceNormForward>(
data.node(), weight.node(), bias.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray InstanceNormForward::make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<InstanceNormForward>(
data.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void InstanceNormForward::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
out_shape[0] = inp_shape[0];
size_t N = inp_shape[0].shape[0];
size_t C = inp_shape[0].shape[1];
TensorShape unnormalized_shape{N, C};
out_shape[1] = unnormalized_shape;
out_shape[2] = unnormalized_shape;
}
size_t InstanceNormForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
#define in(x) \
{ input_shapes[x], input(x)->dtype() }
#define out(x) \
{ output_shapes[x], output(x)->dtype() }
auto&& param = megdnn_opr()->param();
if (param.affine) {
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), in(2), out(0), out(1), out(2));
} else {
TensorLayout temp_weight = TensorLayout(dtype::Float32());
TensorLayout temp_bias = TensorLayout(dtype::Float32());
return megdnn_opr()->get_workspace_in_bytes(
in(0), temp_weight, temp_bias, out(0), out(1), out(2));
}
#undef in
#undef out
}
void InstanceNormForward::scn_do_execute() {
auto p = param();
mgb_assert(p.format == Param::Format::NCHW, "only support inputs in shape NCHW.");
size_t C = input(0)->dev_tensor().shape()[1];
auto opr = const_cast<megdnn::GroupNormForward*>(megdnn_opr());
opr->param().group = C;
mgb_assert(C != 0, "error param!");
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), {}, {},
output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(InstanceNormForward) {
auto p = opr.param();
SymbolVarArray grad;
VarNodeArray ret;
if (p.affine) {
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx);
grad = InstanceNormBackward::make(
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
opr.param());
} else {
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
grad = InstanceNormBackward::make(
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
}
uint32_t nr_ret = p.affine ? 3 : 1;
for (uint32_t i = 0; i < nr_ret; ++i) {
ret.push_back(grad[i].node());
}
return ret;
}
#endif
/* ==================== InstanceNormBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(InstanceNormBackward);
InstanceNormBackward::InstanceNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"instance_norm_backward",
{diff, data, weight, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, weight, mean, rstd});
}
InstanceNormBackward::InstanceNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"instance_norm_backward",
{diff, data, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, mean, rstd});
auto mark_empty_var = [&](VarNode* var) {
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
};
mark_empty_var(output(1));
mark_empty_var(output(2));
}
SymbolVarArray InstanceNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<InstanceNormBackward>(
diff.node(), data.node(), weight.node(), mean.node(),
rstd.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray InstanceNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<InstanceNormBackward>(
diff.node(), data.node(), mean.node(), rstd.node(),
param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void InstanceNormBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
if (param().affine) {
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
} else {
TensorShape empty;
empty.ndim = 0;
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
}
this->init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::GroupNormBackward>::val);
}
void InstanceNormBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(2)->dtype());
}
size_t InstanceNormBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
#define in(x) \
{ input_shapes[x], input(x)->dtype() }
#define out(x) \
{ output_shapes[x], output(x)->dtype() }
auto&& param = megdnn_opr()->param();
if (param.affine) {
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), in(2), in(3), in(4), out(0), out(1), out(2));
} else {
TensorLayout temp_weight = TensorLayout(dtype::Float32());
TensorLayout temp_dweight = TensorLayout(dtype::Float32());
TensorLayout temp_dbias = TensorLayout(dtype::Float32());
return megdnn_opr()->get_workspace_in_bytes(
in(0), in(1), temp_weight, in(2), in(3), out(0), temp_dweight,
temp_dbias);
}
#undef in
#undef out
}
void InstanceNormBackward::scn_do_execute() {
auto p = param();
mgb_assert(p.format == Param::Format::NCHW, "only support inputs in shape NCHW.");
size_t C = input(0)->dev_tensor().shape()[1];
auto opr = const_cast<megdnn::GroupNormBackward*>(megdnn_opr());
opr->param().group = C;
mgb_assert(C != 0, "error param!");
if (p.affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(3)));
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
{}, input(2)->dev_tensor().as_megdnn(),
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
{}, {}, intl::get_megdnn_workspace_from_var(output(3)));
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
InstanceNormForward, intl::MegDNNOprWrapperFwd<megdnn::GroupNormForward>) // {
public:
MGE_WIN_DECLSPEC_FUC InstanceNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC InstanceNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void get_output_var_shape(
const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
using InstanceNorm = InstanceNormForward;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
InstanceNormBackward, intl::MegDNNOprWrapperBwd<megdnn::GroupNormBackward>) // {
public:
MGE_WIN_DECLSPEC_FUC InstanceNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC InstanceNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param = {}, const OperatorNodeConfig& config = {});
private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include "megbrain/opr/dnn/instance_norm.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megdnn/oprs.h"
#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>
using namespace mgb;
namespace {
using Param = opr::InstanceNormForward::Param;
void run_forward(bool is_affine) {
using Checker = AutoOprChecker<3, 3>;
Param param;
param.eps = 1e-5;
param.affine = is_affine;
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto out =
opr::InstanceNormForward::make(inputs[0], inputs[1], inputs[2], param);
return {out[0], out[1], out[2]};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
->create_operator<megdnn::GroupNormForward>();
auto inp_shape = inp[0]->shape();
auto n_slices = inp_shape[0];
auto C = inp_shape[1];
param.group = C;
opr->param() = param;
dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize(inp_shape);
dest[1].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, C});
dest[2].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, C});
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(),
dest[1].layout(), dest[2].layout()));
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(),
{workspace.data(), workspace.size()});
};
auto gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f);
src = *src_gen(src.shape(), src.comp_node());
};
Checker::RunOptions option;
option.numdiff_max_err = 1e-4;
Checker checker{make_graph, fwd};
checker.set_input_generator(0, gen);
checker.set_input_generator(1, gen);
checker.set_input_generator(2, gen);
checker.set_input_allow_grad(0, false);
checker.set_input_allow_grad(1, false);
checker.set_input_allow_grad(2, false);
checker.set_output_allow_grad(0, false);
checker.set_output_allow_grad(1, false);
checker.set_output_allow_grad(2, false);
checker.run({TensorShape{2, 16, 2, 10}, TensorShape{16}, TensorShape{16}}, option)
.run({TensorShape{2, 32, 2, 10}, TensorShape{32}, TensorShape{32}}, option)
.run({TensorShape{2, 64, 16, 10}, TensorShape{64}, TensorShape{64}},
option);
}
TEST(TestOprDNN, InstanceNormForward) {
REQUIRE_GPU(1);
run_forward(true);
}
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册