Unverified commit de46b159, authored by lilong12, committed by GitHub

Unify the rank of prelu alpha to 4, corresponding to [N, C, H, W], except for the 'all' mode

Parent 932aca16
@@ -8693,7 +8693,7 @@ def prelu(x, mode, param_attr=None, name=None):
     if mode == 'channel':
         alpha_shape = [1, x.shape[1], 1, 1]
     elif mode == 'element':
-        alpha_shape = x.shape[1:]
+        alpha_shape = [1, x.shape[1], x.shape[2], x.shape[3]]
     dtype = helper.input_dtype(input_param_name='x')
     alpha = helper.create_parameter(
         attr=helper.param_attr,
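For reference, a minimal NumPy sketch of the broadcasting this change relies on (prelu_ref, the small shapes, and the 0.25 alpha values are illustrative, not from the Paddle source): with input laid out as [N, C, H, W], the new rank-4 'element' alpha [1, C, H, W] broadcasts over the batch dimension exactly as the old [C, H, W] shape did, while now matching the rank of the 'channel' alpha [1, C, 1, 1].

import numpy as np

def prelu_ref(x, alpha):
    # Reference PReLU: negative inputs are scaled by alpha (broadcast against x).
    return np.where(x >= 0, x, alpha * x)

x = np.random.randn(2, 3, 4, 4).astype('float32')  # [N, C, H, W]

alpha_channel = np.full([1, 3, 1, 1], 0.25, dtype='float32')  # 'channel' mode
alpha_element = np.full([1, 3, 4, 4], 0.25, dtype='float32')  # 'element' mode, rank 4 after this commit
alpha_all = np.float32(0.25)                                  # 'all' mode keeps a single scalar alpha

for alpha in (alpha_channel, alpha_element, alpha_all):
    assert prelu_ref(x, alpha).shape == x.shape

The test diff below is refactored accordingly, so that all three modes are exercised.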
@@ -712,7 +712,7 @@ class TestLayer(LayerTest):
         self.assertTrue(
             np.array_equal(btp1.bias.numpy(), btp2.bias.numpy()))
 
-    def test_prelu(self):
+    def prelu_test(self, mode):
         inp_np = np.ones([5, 200, 100, 100]).astype('float32')
         with self.static_graph():
             data_t = layers.data(
@@ -720,7 +720,6 @@ class TestLayer(LayerTest):
                 shape=[5, 200, 100, 100],
                 dtype="float32",
                 append_batch_size=False)
-            mode = 'channel'
             out = layers.prelu(
                 data_t, mode, param_attr=ParamAttr(initializer=Constant(1.0)))
             static_rlt = self.get_static_graph_result(
@@ -732,7 +731,6 @@ class TestLayer(LayerTest):
                 shape=[5, 200, 100, 100],
                 dtype="float32",
                 append_batch_size=False)
-            mode = 'channel'
             prelu = nn.PRelu(
                 'prelu',
                 mode=mode,
@@ -742,7 +740,6 @@ class TestLayer(LayerTest):
                 feed={"input": inp_np}, fetch_list=[out])[0]
 
         with self.dynamic_graph():
-            mode = 'channel'
             prelu = nn.PRelu(
                 'prelu',
                 mode=mode,
@@ -756,7 +753,6 @@ class TestLayer(LayerTest):
         with self.dynamic_graph():
             inp_np = np.random.randn(5, 200, 100, 100).astype("float32")
             inp = base.to_variable(inp_np)
-            mode = 'channel'
             prelu1 = nn.PRelu(
                 'prelu1',
                 mode=mode,
@@ -779,6 +775,11 @@ class TestLayer(LayerTest):
         self.assertTrue(
             np.array_equal(prelu1.weight.numpy(), prelu2.weight.numpy()))
 
+    def test_prelu(self):
+        self.prelu_test("channel")
+        self.prelu_test("element")
+        self.prelu_test("all")
+
     def test_embeding(self):
         inp_word = np.array([[[1]]]).astype('int64')
         dict_size = 20
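For context, a hedged usage sketch mirroring the static-graph branch of the test above (the import paths assume the fluid 1.x layout used by this test file; this is not verbatim from the PR):

import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant

# Build a static-graph prelu in 'element' mode; after this commit the learned
# alpha parameter has shape [1, 200, 100, 100] rather than [200, 100, 100].
data = fluid.layers.data(
    name="input",
    shape=[5, 200, 100, 100],
    dtype="float32",
    append_batch_size=False)
out = fluid.layers.prelu(
    data, mode='element', param_attr=ParamAttr(initializer=Constant(1.0)))

Passing mode='channel' or mode='all' to the same call yields alpha shapes [1, 200, 1, 1] and a single element, respectively.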