提交 43e4f429 编写于 作者: S SunAhong1993

fix the lrn

上级 1cbb878c
......@@ -18,3 +18,4 @@ from .pad_two_input import PadWithTwoInput
from .pad_all_dim2 import PadAllDim2
from .pad_all_dim4 import PadAllDim4
from .pad_all_dim4_one_input import PadAllDim4WithOneInput
from .lrn import LocalResponseNorm
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class LocalResponseNorm(object):
def __init__(self,
size,
alpha=1e-4,
beta=0.75,
k=1.):
self.size = size
self.alpha = alpha
self.beta = beta
self.k = k
def __call__(self, x):
sizes = x.shape
dim = len(sizes)
if dim < 3:
raise ValueError(
'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim))
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
pad4d_shape = [self.size // 2, (self.size - 1) // 2, 0, 0]
pool2d_shape = (1, self.size)
pad5d_shape = [self.size // 2, (self.size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, self.size)
if dim == 3:
div = paddle.nn.functional.pad(div, pad=pad4d_shape)
div = paddle.nn.functional.avg_pool2d(
div, kernel_size=pool2d_shape, stride=1)
div = paddle.squeeze(div, axis=1)
else:
tmp = paddle.unsqueeze(x, axis=1)
reshape_shape = paddle.shape(tmp)
new_reshape_shape = paddle.cast(reshape_shape, "float32")
index = paddle.full(shape=[1], fill_value=-2, dtype="int32")
value = paddle.full(shape=[1], fill_value=-1, dtype="float32")
new_reshape_shape = paddle.scatter(new_reshape_shape, index, value)
new_reshape_shape = paddle.cast(new_reshape_shape, "int32")
div = paddle.reshape(div, shape=reshape_shape)
div = paddle.nn.functional.pad(div,
pad=pad5d_shape,
data_format='NCDHW')
div = paddle.nn.functional.avg_pool3d(
div, kernel_size=pool3d_shape, stride=1)
div = paddle.reshape(paddle.squeeze(div, axis=1), sizes)
div = paddle.scale(div, scale=self.alpha, bias=self.k)
div = paddle.pow(div, self.beta)
res = paddle.divide(x, div)
return res
\ No newline at end of file
......@@ -542,7 +542,6 @@ class OpSet9():
value=value,
mode=string(mode))
else:
print(pads_len)
raise Exception("The padding value is wrong!")
if not op_independent:
return node.name + '_paded'
......@@ -2031,7 +2030,7 @@ class OpSet9():
'k': bias
}
self.paddle_graph.add_layer(
"paddle.nn.LocalResponseNorm",
"custom_layer:LocalResponseNorm",
inputs={"x": val_x.name},
outputs=layer_outputs,
**layer_attrs)
......@@ -17,4 +17,5 @@ from .one_hot import one_hot
from .pad_two_input import pad_with_two_input
from .pad_all_dim2 import pad_all_dim2
from .pad_all_dim4 import pad_all_dim4
from .pad_all_dim4_one_input import pad_all_dim4_one_input
\ No newline at end of file
from .pad_all_dim4_one_input import pad_all_dim4_one_input
from .lrn import local_response_norm
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def local_response_norm(x, size, alpha=1e-4, beta=0.75, k=1.):
sizes = x.shape
dim = len(sizes)
if dim < 3:
raise ValueError(
'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim))
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
pad4d_shape = [size // 2, (size - 1) // 2, 0, 0]
pool2d_shape = (1, size)
pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, size)
if dim == 3:
div = paddle.nn.functional.pad(div, pad=pad4d_shape)
div = paddle.nn.functional.avg_pool2d(
div, kernel_size=pool2d_shape, stride=1)
div = paddle.squeeze(div, axis=1)
else:
tmp = paddle.unsqueeze(x, axis=1)
reshape_shape = paddle.shape(tmp)
new_reshape_shape = paddle.cast(reshape_shape, "float32")
index = paddle.full(shape=[1], fill_value=-2, dtype="int32")
value = paddle.full(shape=[1], fill_value=-1, dtype="float32")
new_reshape_shape = paddle.scatter(new_reshape_shape, index, value)
new_reshape_shape = paddle.cast(new_reshape_shape, "int32")
div = paddle.reshape(div, shape=reshape_shape)
div = paddle.nn.functional.pad(div,
pad=pad5d_shape,
data_format='NCDHW')
div = paddle.nn.functional.avg_pool3d(
div, kernel_size=pool3d_shape, stride=1)
div = paddle.reshape(paddle.squeeze(div, axis=1), sizes)
div = paddle.scale(div, scale=alpha, bias=k)
div = paddle.pow(div, beta)
res = paddle.divide(x, div)
return res
\ No newline at end of file
......@@ -1797,7 +1797,7 @@ class OpSet9():
'k': bias
}
self.paddle_graph.add_layer(
'paddle.nn.functional.local_response_norm',
"custom_layer:local_response_norm",
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册