Unverified commit 5e555279, authored by Leo Chen, committed by GitHub

Fix prelu for compatibility with saved model of old version (#26052) (#26130)

* fix prelu for compatibility with saved model of old version

* reshape alpha
Parent 1e01335e
@@ -2265,7 +2265,8 @@ class PRelu(layers.Layer):
             #NOTE(zhiqiu): The _alpha_shape should be [1, channel] + [1] * len(input_shape[2:]), not [1, channel, 1, 1].
             # However, the suffix 1 in the list is useless, since the tensor is viewed as a one-dimension array during kernel calculation.
             # And, input_shape is not required when mode is 'channel', so it is simplified.
-            self._alpha_shape = [1, channel]
+            #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
+            self._alpha_shape = [1, channel, 1, 1]
         elif mode == 'element':
             assert isinstance(input_shape, (
                 list, tuple
...
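To make the shape semantics concrete: the kernel reads alpha as a flat array of C values, so [1, C] and [1, C, 1, 1] carry the same data; only the recorded parameter shape differs, and that shape is what old saved models are checked against. Below is a minimal NumPy sketch of channel-mode PReLU built from the reference formula in the unit test further down; prelu_channel_ref is an illustrative name, not a Paddle API.

    import numpy as np

    # Minimal sketch (not Paddle code) of channel-mode PReLU:
    # out = max(x, 0) + alpha * min(x, 0), with alpha broadcast per channel.
    def prelu_channel_ref(x, alpha):
        # Reshape alpha to [1, C] + [1] * len(x.shape[2:]) so it broadcasts
        # against an input of any rank >= 2.
        alpha = alpha.reshape([1, x.shape[1]] + [1] * (x.ndim - 2))
        return np.maximum(x, 0.) + np.minimum(x, 0.) * alpha

    x = np.random.uniform(-1, 1, (2, 100, 3, 4))
    alpha = np.random.uniform(-1, -0.5, (1, 100, 1, 1))  # shape after this commit
    out = prelu_channel_ref(x, alpha)  # out.shape == (2, 100, 3, 4)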
@@ -10671,7 +10671,8 @@ def prelu(x, mode, param_attr=None, name=None):
         ) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
         #NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
         # To be consistent with Prelu, it is simplified.
-        alpha_shape = [1, x.shape[1]]
+        #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
+        alpha_shape = [1, x.shape[1], 1, 1]
     elif mode == 'element':
         assert len(
             x.shape
...
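For context, a minimal static-graph usage sketch, assuming the fluid 1.x API current at the time of this commit. The layer creates the alpha parameter itself, so the only visible effect of this change is the shape of the parameter written to and read from saved models.

    import paddle.fluid as fluid

    # Usage sketch: with mode='channel' and C = 3 input channels, prelu now
    # creates its alpha parameter with shape [1, 3, 1, 1], matching parameter
    # files saved by older releases.
    x = fluid.data(name='x', shape=[None, 3, 32, 32], dtype='float32')
    out = fluid.layers.prelu(x, mode='channel')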
@@ -51,21 +51,22 @@ class PReluTest(OpTest):
         if self.attrs == {'mode': "all"}:
             alpha_np = np.random.uniform(-1, -0.5, (1))
         elif self.attrs == {'mode': "channel"}:
-            alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1]])
+            alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
         else:
             alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
         self.inputs = {'X': x_np, 'Alpha': alpha_np}
-        # NOTE(zhiqiu): reshape inputs['Alpha'] from [1, 100] to [1, 100, 1, 1] since np operands could not be broadcast together with shapes (2,100,3,4) (1,100)
+        # NOTE(zhiqiu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:])
+        # since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1)
+        reshaped_alpha = self.inputs['Alpha']
         if self.attrs == {'mode': "channel"}:
-            self.inputs['Alpha'] = np.reshape(
+            reshaped_alpha = np.reshape(
                 self.inputs['Alpha'],
                 [1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
         out_np = np.maximum(self.inputs['X'], 0.)
-        out_np = out_np + np.minimum(self.inputs['X'],
-                                     0.) * self.inputs['Alpha']
+        out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
...
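The reshape the test now performs is not cosmetic: NumPy cannot broadcast a [1, C, 1, 1] alpha against inputs whose rank differs from 4, such as the 6-D case named in the comment above. A short sketch of just that reshape, using the shapes from the test:

    import numpy as np

    # Alpha is stored as [1, C, 1, 1], but a 6-D input such as
    # (1, 100, 2, 2, 2, 3) needs an alpha of matching rank before
    # broadcasting, hence the reshape to [1, C] + [1] * len(x_shape[2:]).
    x_shape = (1, 100, 2, 2, 2, 3)
    alpha = np.random.uniform(-1, -0.5, (1, 100, 1, 1))
    reshaped_alpha = np.reshape(alpha, [1, x_shape[1]] + [1] * len(x_shape[2:]))
    print(reshaped_alpha.shape)  # (1, 100, 1, 1, 1, 1)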