提交 ac86d644 编写于 作者: M Megvii Engine Team

fix(mge/module): fix prelu error when use_symbolic_shape is true

GitOrigin-RevId: 25b9c4d41d0584922eaaf657b0b2d3c2d515b6e8
上级 4f6c5d8f
...@@ -239,12 +239,6 @@ class PReLU(Module): ...@@ -239,12 +239,6 @@ class PReLU(Module):
self.weight = Parameter(data=[init]) self.weight = Parameter(data=[init])
def forward(self, inputs): def forward(self, inputs):
assert self.weight.shape == (1,) or self.weight.shape == (
1,
int(inputs.shape[1]),
1,
1,
), "invalid weight's shape"
return prelu(inputs, self.weight) return prelu(inputs, self.weight)
......
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np
import pytest
import megengine as mge import megengine as mge
from megengine.module import LeakyReLU from megengine.jit.tracing import set_symbolic_shape
from megengine.module import LeakyReLU, PReLU
def test_leaky_relu(): def test_leaky_relu():
...@@ -21,3 +23,19 @@ def test_leaky_relu(): ...@@ -21,3 +23,19 @@ def test_leaky_relu():
np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data) np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data)
np.testing.assert_equal(output.numpy(), np_output) np.testing.assert_equal(output.numpy(), np_output)
@pytest.mark.parametrize("shape", [(1, 64, 15, 15), (64,)])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_prelu(shape, use_symbolic):
old_flag = set_symbolic_shape(use_symbolic)
data = np.random.random(size=shape)
num_channel = 1 if len(shape) == 1 else shape[1]
prelu = PReLU(num_parameters=num_channel, init=0.25)
output = prelu(mge.Tensor(data))
np_output = np.maximum(data, 0) + prelu.weight.numpy() * np.minimum(data, 0)
set_symbolic_shape(old_flag)
np.testing.assert_allclose(output.numpy(), np_output, atol=1e-5)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册