提交 21086965 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add local response normalization

GitOrigin-RevId: 939a4d26ddff7724237f4b4ecbac054753642dac
上级 ca4c93de
......@@ -69,6 +69,7 @@ __all__ = [
"leaky_relu",
"linear",
"local_conv2d",
"local_response_norm",
"logsigmoid",
"logsumexp",
"logsoftmax",
......@@ -1746,6 +1747,53 @@ def pad(
return output
def local_response_norm(
inp: Tensor,
kernel_size: int = 5,
k: float = 2.0,
alpha: float = 1e-4,
beta: float = 0.75,
) -> Tensor:
r"""
Apply local response normalization to the input tensor.
Args:
kernel_size: the size of the kernel to apply LRN on.
k: hyperparameter k. The default vaule is 2.0.
alpha: hyperparameter alpha. The default value is 1e-4.
beta: hyperparameter beta. The default value is 0.75.
Example:
.. testcode::
from megengine import tensor
import megengine.functional as f
import numpy as np
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
print('pass')
Outputs:
.. testoutput::
pass
"""
op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
(output,) = apply(op, inp)
return output
@lru_cache(maxsize=None)
def _get_layerPixelShuffle(device, dtype, dim_order):
@subgraph("LayerPixelShuffle", dtype, device, 3)
......
......@@ -29,6 +29,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .lrn import LocalResponseNorm
from .module import Module
from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .padding import Pad
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union
from ..functional import local_response_norm
from .module import Module
class LocalResponseNorm(Module):
r"""
Apply local response normalization to the input tensor.
Args:
kernel_size: the size of the kernel to apply LRN on.
k: hyperparameter k. The default vaule is 2.0.
alpha: hyperparameter alpha. The default value is 1e-4.
beta: hyperparameter beta. The default value is 0.75.
Example:
.. testcode::
from megengine import tensor
import megengine.module as M
import numpy as np
inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
[ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
[ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
[14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
[19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
out = op(inp)
np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
print('pass')
Outputs:
.. testoutput::
pass
"""
def __init__(
self,
kernel_size: int = 5,
k: float = 2.0,
alpha: float = 1e-4,
beta: float = 0.75,
**kwargs
):
super(LocalResponseNorm, self).__init__(**kwargs)
self.kernel_size = kernel_size
self.k = k
self.alpha = alpha
self.beta = beta
def forward(self, inp):
return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta)
......@@ -21,6 +21,7 @@
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
......@@ -654,4 +655,13 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
} // namespace padding
namespace lrn {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LRN&>(def);
mgb_assert(inputs.size() == 1);
return opr::LRN::make(inputs[0], op.param());
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
} // namespace LRN
} // namespace mgb::imperative
......@@ -422,4 +422,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> {
def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
def LRN: MgbHashableOp<"LRN", [LRNParam]>;
#endif // MGB_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册