From 0dab0fc23c0e7d0baa9ae713cc847f4d6419b90c Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 8 Sep 2020 18:35:37 +0800 Subject: [PATCH] add back triu in fluid (#27135) --- python/paddle/fluid/layers/tensor.py | 8 +++++++- .../fluid/tests/unittests/test_tril_triu_op.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index a90551c1b7b..89acfc6075b 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -36,7 +36,7 @@ __all__ = [ 'tensor_array_to_tensor', 'concat', 'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', - 'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye' + 'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye', 'triu' ] @@ -1725,3 +1725,9 @@ def ones_like(x, out=None): attrs={'value': 1.0}, outputs={'Out': [out]}) return out + + +@deprecated(since="2.0.0", update_to="paddle.triu") +def triu(input, diagonal=0, name=None): + import paddle + return paddle.tensor.triu(x=input, diagonal=diagonal, name=name) diff --git a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py index aed265b21b5..2cd2599f2ea 100644 --- a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py +++ b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py @@ -142,6 +142,18 @@ class TestTrilTriuOpAPI(unittest.TestCase): self.assertTrue(np.allclose(tril_out, np.tril(data))) self.assertTrue(np.allclose(triu_out, np.triu(data))) + def test_fluid_api(self): + data = np.random.random([1, 9, 9, 4]).astype('float32') + x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x') + triu_out = fluid.layers.triu(x) + + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + triu_out = exe.run(fluid.default_main_program(), + feed={"x": data}, + fetch_list=[triu_out]) + if __name__ == '__main__': unittest.main() -- GitLab