diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c7fb75387aa31cab0cda67d020739f0ed48db823..1f8593a1f36e74c6c94477d39bec3dce23d925f3 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9864,7 +9864,7 @@ def prelu(x, mode, param_attr=None, data_format="NCHW", name=None): #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version. #NOTE(GuoxiaWang): support NHWC data format if data_format == 'NHWC': - alpha_shape = [1, 1, 1, x.shape[1]] + alpha_shape = [1, 1, 1, x.shape[-1]] else: alpha_shape = [1, x.shape[1], 1, 1]