未验证 提交 79eaac55 编写于 作者: W WuHaobo 提交者: GitHub

polish_tril_triu_docstring and add dygraph (#24055)

* Update creation.py
上级 d31a174f
......@@ -134,6 +134,14 @@ class TestTrilTriuOpAPI(unittest.TestCase):
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
def test_api_with_dygraph(self):
with fluid.dygraph.guard():
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.dygraph.to_variable(data)
tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu(x).numpy()
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
if __name__ == '__main__':
unittest.main()
......@@ -696,8 +696,6 @@ def tril(input, diagonal=0, name=None):
# [ 5, 6, 0, 0],
# [ 9, 10, 11, 0]])
.. code-block:: python
# example 2, positive diagonal value
tril = tensor.tril(x, diagonal=2)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
......@@ -706,8 +704,6 @@ def tril(input, diagonal=0, name=None):
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
.. code-block:: python
# example 3, negative diagonal value
tril = tensor.tril(x, diagonal=-1)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
......@@ -717,6 +713,9 @@ def tril(input, diagonal=0, name=None):
# [ 9, 10, 0, 0]])
"""
if in_dygraph_mode():
op = getattr(core.ops, 'tril_triu')
return op(input, 'diagonal', diagonal, "lower", True)
return _tril_triu_op(LayerHelper('tril', **locals()))
......@@ -771,8 +770,6 @@ def triu(input, diagonal=0, name=None):
# [ 0, 6, 7, 8],
# [ 0, 0, 11, 12]])
.. code-block:: python
# example 2, positive diagonal value
triu = tensor.triu(x, diagonal=2)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
......@@ -781,8 +778,6 @@ def triu(input, diagonal=0, name=None):
# [0, 0, 0, 8],
# [0, 0, 0, 0]])
.. code-block:: python
# example 3, negative diagonal value
triu = tensor.triu(x, diagonal=-1)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
......@@ -792,6 +787,9 @@ def triu(input, diagonal=0, name=None):
# [ 0, 10, 11, 12]])
"""
if in_dygraph_mode():
op = getattr(core.ops, 'tril_triu')
return op(input, 'diagonal', diagonal, "lower", False)
return _tril_triu_op(LayerHelper('triu', **locals()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册