diff --git a/paddle/fluid/operators/triu_indices_op.cc b/paddle/fluid/operators/triu_indices_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d02b54f6083f8b80cabe21d284f42ffee17ebb2e --- /dev/null +++ b/paddle/fluid/operators/triu_indices_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/nullary.h" + +namespace paddle { +namespace operators { + +class TriuIndicesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class TriuIndicesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("out", + "Tensor, the output tensor, with the shape (2,x), x bounded by " + "[0,row*col])"); + AddAttr("row", + "int number, the input of triu_indices op" + "which describes the number of row of the matrix") + .SetDefault(0); + AddAttr("col", + "int number, the input of triu_indices op" + "which describes the number of col of the matrix") + .SetDefault(0); + AddAttr( + "offset", + "int number, the input of triu_indices op bounded by [1-rows,cols-1" + "which describes the dignalline index of the upper triangular part of " + "the matrix") + .SetDefault(0); + AddAttr("dtype", "data type ,the input of triu_indices op") + .SetDefault(framework::proto::VarType::INT64); + + AddComment(R"DOC( + TriuIndices Operator. + The triu_indices operator returns the indices of the upper triangular part of the matrix + whose rows and cols is known. It is a 2-by-x tensor, where the first row contains row coordinates + of all indices and the second row contains column coordinates. Indices are ordered based on + rows and then columns. The upper triangular part of the matrix is defined as the elements on + and below the diagonal. + The argument offset controls which diagonal to consider, default value is 0. + A positive value includes just as fewer diagonals above the main diagonal, + and similarly a negative value excludes just as fewer diagonals below the main diagonal + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(triu_indices, + TriuIndicesInferShapeFunctor, + PD_INFER_META(phi::TriuIndicesInferMeta)); + +REGISTER_OPERATOR( + triu_indices, + ops::TriuIndicesOp, + ops::TriuIndicesOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + TriuIndicesInferShapeFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 2c538c27bf8de1c2612ba84875f80289be91715b..b02ffd319ce050008b28345e75bb2411189d636f 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2710,6 +2710,18 @@ data_type : x backward : trilinear_interp_grad +- api : triu_indices + args : (int row, int col, int offset, DataType dtype, Place place={}) + output : Tensor(out) + infer_meta : + func : TriuIndicesInferMeta + param : [row, col, offset, dtype] + kernel : + func : triu_indices + param : [row, col, offset, dtype] + data_type : dtype + backend : place + # python API: paddle.nn.initializer.TruncatedNormal - api : truncated_gaussian_random args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={}) diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 069359bae92b2e76e18a571243cbec7664f20467..4d11c462743d361d0530ac7ffbad4f8eff9c9059 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -152,4 +152,33 @@ void TrilIndicesInferMeta( out->set_dims(out_dims); out->set_dtype(dtype); } + +void TriuIndicesInferMeta( + int row, int col, int offset, DataType dtype, MetaTensor* out) { + // number of elements in the first row of the tril,bounded by [0, cols] + // use total item number minus bottom rectangle item number to get + // the above rectangle item number + // triu_size = rows * cols - tril_size + // so the `offset` need to be set as `offset-1` in order to include + // the item on the diagonal line + offset = offset - 1; + auto n_first_row = + offset > 0 ? std::min(col, 1 + offset) : row + offset > 0; + // number of elements in the last row of the tril, bounded by [0, cols] + auto n_last_row = std::max(0, std::min(col, row + offset)); + // number of rows, bounded by [0, rows] + auto n_row_all = std::max(0, std::min(row, row + offset)); + auto n_row_trapezoid = (n_last_row - n_first_row + 1); + // calculate # of elements in the top trapezoid + auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1; + // calculate # of elements in the bottom rectangle if there is any + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * col; + } + std::vector tmp = {2, row * col - tril_size}; + auto out_dims = phi::make_ddim(tmp); + out->set_dims(out_dims); + out->set_dtype(dtype); +} } // namespace phi diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index a9f1818e319576d658d6f7794935f7177a7b26d9..3ac2b0a7cf393020b915e691567775c5f9312ea9 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -74,4 +74,7 @@ void UniformRandomInferMeta(const IntArray& shape, void TrilIndicesInferMeta( int rows, int cols, int offset, DataType dtype, MetaTensor* out); + +void TriuIndicesInferMeta( + int row, int col, int offset, DataType dtype, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/triu_indices_kernel.cc b/paddle/phi/kernels/cpu/triu_indices_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cac57da555fe4cb28c114c5875220161336002a1 --- /dev/null +++ b/paddle/phi/kernels/cpu/triu_indices_kernel.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/triu_indices_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void TriuIndicesKernel(const Context& dev_ctx, + int row, + int col, + int offset, + DataType dtype, + DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + const auto& out_dims = out->dims(); + int64_t triu_size = out_dims[1]; + int64_t i = 0; + T c = std::max(0, offset), r = 0; + while (i < triu_size) { + out_data[i] = r; + out_data[triu_size + i++] = c; + + // move to the next column and check if (r, c) is still in bound + c += 1; + if (c >= col) { + r += 1; + // not typing std::max with scalar_t as it could be an unsigned type + // NOTE: not necessary to check if c is less than col or overflows here, + // because i and triu_size act as a guard. + c = std::max(0, r + offset); + } + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + triu_indices, CPU, ALL_LAYOUT, phi::TriuIndicesKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/triu_indices_kernel.cu b/paddle/phi/kernels/gpu/triu_indices_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..cece4bee8a42c12609bfaf9d1fec59ca5bd7de45 --- /dev/null +++ b/paddle/phi/kernels/gpu/triu_indices_kernel.cu @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/triu_indices_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) { + int64_t bXb_cX4 = b * b - cX4; + double sr = ::sqrt(static_cast(bXb_cX4)); + T res = ::__double2ll_rd((-b + sign * sr) / 2); + if (bXb_cX4 != static_cast(sr * sr)) { + int llsr = ::__double2ll_rd(sr); + int diff = ::__double2ll_ru( + ::sqrt(::fabs(static_cast(bXb_cX4 - llsr * llsr)))); + auto l = res > diff ? res - diff : 0; + auto r = res + diff + 1; + x <<= 1; + while (l + 1 < r) { + auto m = (l + r) >> 1; + if (sign * (b + m) * m > x) { + r = m; + } else { + l = m; + } + } + res = l; + } + return res; +} + +template +__device__ inline void get_coordinate_in_triu_trapezoid(int f, + int x, + T* row, + T* col) { + f <<= 1; // all statements use 2f, so only calculate it once here. + auto b = -1 - f; + auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x; + *row = resolve_root_int(b, cX4, x, -1); + *col = (x - (((f - *row + 1) * *row) >> 1)) + *row; +} + +template +__global__ void triu_indices_kernel(T* out_data, + int col_offset, + int m_first_row, + int col, + int rectangle_size, + int triu_size) { + int linear_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_index < triu_size) { + T r, c; + if (linear_index < rectangle_size) { + // the coordinate is within the top rectangle + r = linear_index / col; + c = linear_index % col; + } else { + // the coordinate falls in the bottom trapezoid + get_coordinate_in_triu_trapezoid( + m_first_row, linear_index - rectangle_size, &r, &c); + r += rectangle_size / col; + } + + c += col_offset; + out_data[linear_index] = r; + out_data[linear_index + triu_size] = c; + } +} + +template +void TriuIndicesKernel(const Context& dev_ctx, + int row, + int col, + int offset, + DataType dtype, + DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + auto out_dims = out->dims(); + int triu_size = out_dims[1]; + // auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt, + // device_opt, pin_memory_opt); + + if (triu_size > 0) { + // # of triu elements in the first row + auto m_first_row = offset > 0 ? std::max(col - offset, 0) + : // upper bounded by col + col; + + // size of the top rectangle + int rectangle_size = 0; + if (offset < 0) { + rectangle_size = std::min(row, -offset) * col; + } + + // using gpu_launch_config to get grid_size and block_size + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, triu_size); + + triu_indices_kernel<<>>(out_data, + std::max(0, offset), + m_first_row, + col, + rectangle_size, + triu_size); + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + triu_indices, GPU, ALL_LAYOUT, phi::TriuIndicesKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/triu_indices_kernel.h b/paddle/phi/kernels/triu_indices_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5f1e09a8b65e4f43ef19ed98ec983262974416b2 --- /dev/null +++ b/paddle/phi/kernels/triu_indices_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TriuIndicesKernel(const Context& dev_ctx, + int row, + int col, + int offset, + DataType dtype, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6b99a91cdb307464e9d0dbfd3461aaaef2137f53..73135dfafbb86f857fb32e633ea2efddca873035 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -110,6 +110,7 @@ from .tensor.creation import assign # noqa: F401 from .tensor.creation import complex # noqa: F401 from .tensor.creation import clone # noqa: F401 from .tensor.creation import tril_indices #noqa: F401 +from .tensor.creation import triu_indices #noqa: F401 from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import norm # noqa: F401 @@ -654,4 +655,5 @@ __all__ = [ # noqa 'heaviside', 'tril_indices', 'sgn', + 'triu_indices', ] diff --git a/python/paddle/fluid/tests/unittests/test_triu_indices_op.py b/python/paddle/fluid/tests/unittests/test_triu_indices_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b3dd9a2861686eeac780a1fb19ef649d7e87fd15 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_triu_indices_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard + + +class TestTriuIndicesOp(OpTest): + + def setUp(self): + self.op_type = "triu_indices" + self.inputs = {} + self.init_config() + self.outputs = {'out': self.target} + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + def init_config(self): + self.attrs = {'row': 4, 'col': 4, 'offset': -1} + self.target = np.triu_indices(self.attrs['row'], self.attrs['offset'], + self.attrs['col']) + self.target = np.array(self.target) + + +class TestTriuIndicesOpCase1(TestTriuIndicesOp): + + def init_config(self): + self.attrs = {'row': 0, 'col': 0, 'offset': 0} + self.target = np.triu_indices(0, 0, 0) + self.target = np.array(self.target) + + +class TestTriuIndicesOpCase2(TestTriuIndicesOp): + + def init_config(self): + self.attrs = {'row': 4, 'col': 4, 'offset': 2} + self.target = np.triu_indices(self.attrs['row'], self.attrs['offset'], + self.attrs['col']) + self.target = np.array(self.target) + + +class TestTriuIndicesAPICaseStatic(unittest.TestCase): + + def test_static(self): + if fluid.core.is_compiled_with_cuda(): + place = paddle.fluid.CUDAPlace(0) + else: + place = paddle.CPUPlace() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data = paddle.triu_indices(4, 4, -1) + exe = paddle.static.Executor(place) + result = exe.run(feed={}, fetch_list=[data]) + expected_result = np.triu_indices(4, -1, 4) + np.testing.assert_array_equal(result[0], expected_result) + + +class TestTriuIndicesAPICaseDygraph(unittest.TestCase): + + def test_dygraph(self): + if fluid.core.is_compiled_with_cuda(): + place = paddle.fluid.CUDAPlace(0) + else: + place = paddle.CPUPlace() + with fluid.dygraph.base.guard(place=place): + out = paddle.triu_indices(4, 4, 2) + expected_result = np.triu_indices(4, 2, 4) + np.testing.assert_array_equal(out, expected_result) + + def test_dygraph_eager(self): + with _test_eager_guard(): + self.test_dygraph() + + +class TestTriuIndicesAPICaseError(unittest.TestCase): + + def test_case_error(self): + + def test_num_rows_type_check(): + out1 = paddle.triu_indices(1.0, 1, 2) + + self.assertRaises(TypeError, test_num_rows_type_check) + + def test_num_columns_type_check(): + out2 = paddle.triu_indices(4, -1, 2) + + self.assertRaises(TypeError, test_num_columns_type_check) + + def test_num_offset_type_check(): + out3 = paddle.triu_indices(4, 4, 2.0) + + self.assertRaises(TypeError, test_num_offset_type_check) + + +class TestTriuIndicesAPICaseDefault(unittest.TestCase): + + def test_default_CPU(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data = paddle.triu_indices(4, None, 2) + exe = paddle.static.Executor(paddle.CPUPlace()) + result = exe.run(feed={}, fetch_list=[data]) + expected_result = np.triu_indices(4, 2) + np.testing.assert_array_equal(result[0], expected_result) + + with fluid.dygraph.base.guard(paddle.CPUPlace()): + out = paddle.triu_indices(4, None, 2) + expected_result = np.triu_indices(4, 2) + np.testing.assert_array_equal(out, expected_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index a1a8f2937d7514dcdfaf2ed8574be48850512a08..17393db9b4cce649e503e6afde6539e58c54261a 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1917,3 +1917,88 @@ def tril_indices(row, col, offset=0, dtype='int64'): 'dtype': dtype }) return out + + +def triu_indices(row, col=None, offset=0, dtype='int64'): + """ + Return the indices of the upper triangular part of the 2-D matrix + whose row and col is known. Indices are ordered based on row and then columns. + The upper triangular part of the matrix is defined as the elements on + and above the diagonal. + + Args: + row (int): The input x which is a int number describe the number of row of the matrix. + col (int, optional): The input x which is a int number describe the number of col of the matrix. + default value for col is None, then it will be set equal to row, indicting a square matix. + offset (int, optional): The offset to consider, default value is 0. + + - If offset = 0, all elements on and above the main diagonal are retained. + - If offset > 0, include just as few diagonals above the main diagonal. + - If offset < 0, excludes just as few diagonals below the main diagonal. + + dtype (str|np.dtype|paddle.dtype, optional): the data type of the output tensor, + can be int32, int64, default value is int64. + Returns: + Tensor: Results of the indices of upper triangular part of a row * col matrix, + where the first row contains row coordinates of and the second row contains column coordinates. + + Examples: + .. code-block:: python + + import paddle + # example 1, default offset value + data1 = paddle.triu_indices(4,4,0) + print(data1) + # [[0, 0, 0, 0, 1, 1, 1, 2, 2, 3], + # [0, 1, 2, 3, 1, 2, 3, 2, 3, 3]] + # example 2, positive offset value + data2 = paddle.triu_indices(4,4,2) + print(data2) + # [[0, 0, 1], + # [2, 3, 3]] + # example 3, negative offset value + data3 = paddle.triu_indices(4,4,-1) + print(data3) + # [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3], + # [0, 1, 2, 3, 0, 1, 2, 3, 1, 2, 3, 2, 3]] + """ + if not isinstance(row, int) or row < 0: + raise TypeError("row should be a non-negative int") + + if col is not None: + if not isinstance(col, int) or col < 0: + raise TypeError("col should be a non-negative int") + else: + col = row + + if not isinstance(offset, int): + raise TypeError("offset should be a int") + + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dygraph_mode(): + out = _C_ops.final_state_triu_indices(row, col, offset, dtype, + _current_expected_place()) + return out + + if _in_legacy_dygraph(): + out = _C_ops.triu_indices('row', row, 'col', col, 'offset', offset, + "dtype", dtype) + return out + + else: + helper = LayerHelper("triu_indices", **locals()) + + out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op(type='triu_indices', + inputs={}, + outputs={'out': [out]}, + attrs={ + 'row': row, + 'col': col, + 'offset': offset, + 'dtype': dtype + }) + return out