未验证 提交 9c611210 编写于 作者: L Leo Chen 提交者: GitHub

Fix prelu for compatibility with saved model of old version (#26052)

* fix prelu for compatibility with saved model of old version

* reshape alpha
上级 1893cd6b
......@@ -2308,7 +2308,8 @@ class PRelu(layers.Layer):
#NOTE(zhiqiu): The _alpha_shape should be [1, channel] + [1] * len(input_shape[2:]), not [1, channel, 1, 1].
# However, the suffix 1 in the list is useless, since the tensor is viewed as one demension array during kernel calculation.
# And, input_shape is not required when mode is 'channel', so it is simplified.
self._alpha_shape = [1, channel]
#NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
self._alpha_shape = [1, channel, 1, 1]
elif mode == 'element':
assert isinstance(input_shape, (
list, tuple
......
......@@ -9670,7 +9670,8 @@ def prelu(x, mode, param_attr=None, name=None):
) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
#NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified.
alpha_shape = [1, x.shape[1]]
#NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element':
assert len(
x.shape
......
......@@ -51,21 +51,22 @@ class PReluTest(OpTest):
if self.attrs == {'mode': "all"}:
alpha_np = np.random.uniform(-1, -0.5, (1))
elif self.attrs == {'mode': "channel"}:
alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1]])
alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
self.inputs = {'X': x_np, 'Alpha': alpha_np}
# NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100] to [1, 100, 1, 1] since np operands could not be broadcast together with shapes (2,100,3,4) (1,100)
# NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:])
# since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1)
reshaped_alpha = self.inputs['Alpha']
if self.attrs == {'mode': "channel"}:
self.inputs['Alpha'] = np.reshape(
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'],
0.) * self.inputs['Alpha']
out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha
assert out_np is not self.inputs['X']
self.outputs = {'Out': out_np}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册