未验证 提交 62fe3cf5 编写于 作者: L LoneRanger 提交者: GitHub

Fix Python IndexError of Case14: paddle.nn.functional.glu (#50016)

* 为split增加取值范围维度的判断

* 为glu的axis进行取值判断并添加单测

* 完善glu的单测

* fix glu
上级 3374600e
......@@ -72,5 +72,15 @@ class TestGLUV2(unittest.TestCase):
self.check_identity(fluid.CUDAPlace(0))
class TestGlu(unittest.TestCase):
def glu_axis_size(self):
paddle.enable_static()
x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='float32')
paddle.nn.functional.glu(x, axis=256)
def test_errors(self):
self.assertRaises(ValueError, self.glu_axis_size)
if __name__ == '__main__':
unittest.main()
......@@ -1622,6 +1622,13 @@ def glu(x, axis=-1, name=None):
check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], "glu"
)
rank = len(x.shape)
if not (-rank <= axis < rank):
raise ValueError(
"Expected value range of `axis` is [{}, {}), but received axis: {}".format(
-rank, rank, axis
)
)
a, b = chunk(x, 2, axis=axis, name=name)
gate = sigmoid(b, name=name)
out = paddle.multiply(a, gate, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册