From 0d78034913ccc4bafc71011cd8ec2bc7c4cee715 Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Mon, 22 Feb 2021 18:00:33 +0800 Subject: [PATCH] fix lrn static bug, test=release/2.0 (#30982) --- python/paddle/nn/functional/norm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index fcda579332..7ba42f880c 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -483,17 +483,26 @@ def local_response_norm(x, channel_last = True if data_format[-1] == "C" else False + from functools import reduce + sum_sizes = reduce(lambda x, y: x * y, sizes[1:]) + div = paddle.unsqueeze(paddle.multiply(x, x), axis=1) if not channel_last: pad4d_shape = [0, 0, size // 2, (size - 1) // 2] pool2d_shape = (size, 1) - reshape_shape = [sizes[0], 1, sizes[1], sizes[2], -1] + reshape_shape = [ + sizes[0], 1, sizes[1], sizes[2], + int(sum_sizes / (sizes[1] * sizes[2])) + ] pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2] pool3d_shape = (size, 1, 1) else: pad4d_shape = [size // 2, (size - 1) // 2, 0, 0] pool2d_shape = (1, size) - reshape_shape = [sizes[0], 1, sizes[1], -1, sizes[-1]] + reshape_shape = [ + sizes[0], 1, sizes[1], int(sum_sizes / (sizes[1] * sizes[-1])), + sizes[-1] + ] pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0] pool3d_shape = (1, 1, size) -- GitLab