diff --git a/paddle/fluid/operators/tril_indices_op.cc b/paddle/fluid/operators/tril_indices_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..be42f53dd23440001bf617140dd07b5b7c3110c9 --- /dev/null +++ b/paddle/fluid/operators/tril_indices_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/nullary.h" + +namespace paddle { +namespace operators { + +class TrilIndicesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class TrilIndicesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("out", + "Tensor, the output tensor, with the shape (2,x),x bounded by " + "[0,rows*cols])"); + AddAttr("rows", + "int number, the input of tril_indices op" + "which describes the number of row of the matrix") + .SetDefault(0); + AddAttr("cols", + "int number, the input of tril_indices op" + "which describes the number of col of the matrix") + .SetDefault(0); + AddAttr( + "offset", + "int number, the input of tril_indices op bounded by [1-rows,cols-1" + "which describes the dignalline index of the lower triangular part of " + "the matrix") + .SetDefault(0); + AddAttr("dtype", "data type ,the input of tril_indices op") + .SetDefault(framework::proto::VarType::INT64); + + AddComment(R"DOC( + TrilIndices Operator. + + The tril_indices operator returns the indices of the lower triangular part of the matrix + whose rows and cols is knowed. It is a 2-by-x tensor,where the first row contains row coordinates + of all indices and the second row contains column coordinates. Indices are ordered based on + rows and then columns. The lower triangular part of the matrix is defined as the elements on + and below the diagonal. + + The argument offset controls which diagonal to consider, default value is 0. + A positive valueincludes just as many diagonals above the main diagonal, + and similarly a negative value excludes just as many diagonals below the main diagonal + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(tril_indices, TrilIndicesInferShapeFunctor, + PD_INFER_META(phi::TrilIndicesInferMeta)); + +REGISTER_OPERATOR( + tril_indices, ops::TrilIndicesOp, ops::TrilIndicesOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + TrilIndicesInferShapeFunctor); diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 3a99103eda5c23aebbdd1def5343a9e8bfd28347..c3ded621718ce6ae48ce6462c9ea9a7be6bc36af 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -115,4 +115,27 @@ void TruncatedGaussianRandomInferMeta(const std::vector& shape, out->set_layout(DataLayout::NCHW); } +void TrilIndicesInferMeta( + int rows, int cols, int offset, DataType dtype, MetaTensor* out) { + // number of elements in the first row of the tril,bounded by [0, cols] + auto n_first_row = + offset > 0 ? std::min(cols, 1 + offset) : rows + offset > 0; + // number of elements in the last row of the tril, bounded by [0, cols] + auto n_last_row = + std::max(0, std::min(cols, rows + offset)); + // number of rows, bounded by [0, rows] + auto n_row_all = std::max(0, std::min(rows, rows + offset)); + auto n_row_trapezoid = (n_last_row - n_first_row + 1); + // calculate # of elements in the top trapezoid + auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1; + // calculate # of elements in the bottom rectangle if there is any + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * cols; + } + std::vector tmp = {2, tril_size}; + auto out_dims = phi::make_ddim(tmp); + out->set_dims(out_dims); + out->set_dtype(dtype); +} } // namespace phi diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 8d952d842c0c4428cabda5a492e5c37a98289121..a9f1818e319576d658d6f7794935f7177a7b26d9 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -72,4 +72,6 @@ void UniformRandomInferMeta(const IntArray& shape, int seed, MetaTensor* out); +void TrilIndicesInferMeta( + int rows, int cols, int offset, DataType dtype, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/tril_indices_kernel.cc b/paddle/phi/kernels/cpu/tril_indices_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c515a69f011d540417ece87cc705fdd8896f1dac --- /dev/null +++ b/paddle/phi/kernels/cpu/tril_indices_kernel.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/tril_indices_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void TrilIndicesKernel(const Context& dev_ctx, + int rows, + int cols, + int offset, + DataType dtype, + DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + auto out_dims = out->dims(); + int64_t tril_size = out_dims[1]; + int64_t i = 0; + T r = std::max(0, -offset), c = 0; + while (i < tril_size) { + out_data[i] = r; + out_data[tril_size + i++] = c; + + // move to the next column and check if (r, c) is still in bound + c += 1; + if (c > r + offset || c >= cols) { + r += 1; + c = 0; + // NOTE: not necessary to check if r is less than row here, because i + // and tril_size provide the guarantee + } + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + tril_indices, CPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/tril_indices_kernel.cu b/paddle/phi/kernels/gpu/tril_indices_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..be83f28451166b1193098ab746abc2a28cf4c028 --- /dev/null +++ b/paddle/phi/kernels/gpu/tril_indices_kernel.cu @@ -0,0 +1,142 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/tril_indices_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) { + int bXb_cX4 = b * b - cX4; + double sr = ::sqrt(static_cast(bXb_cX4)); + T res = ::__double2ll_rd((-b + sign * sr) / 2); + if (bXb_cX4 != static_cast(sr * sr)) { + int llsr = ::__double2ll_rd(sr); + int diff = ::__double2ll_ru( + ::sqrt(::fabs(static_cast(bXb_cX4 - llsr * llsr)))); + auto l = res > diff ? res - diff : 0; + auto r = res + diff + 1; + x <<= 1; + while (l + 1 < r) { + auto m = (l + r) >> 1; + if (sign * (b + m) * m > x) { + r = m; + } else { + l = m; + } + } + res = l; + } + return res; +} + +template +__device__ inline void get_coordinate_in_tril_trapezoid(int f, + int x, + T* row, + T* col) { + f <<= 1; // all statements use 2f, so only calculate it once here. + auto b = f - 1; + auto cX4 = -(x << 3); // 4 * c = 4 * (-2x) = -8x; + *row = resolve_root_int(b, cX4, x, 1); + *col = x - ((f + *row - 1) * *row >> 1); +} + +template +__global__ void tril_indices_kernel(T* out_data, + int row_offset, + int m_first_row, + int col, + int trapezoid_size, + int tril_size) { + int linear_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_index < tril_size) { + T r, c; + if (linear_index < trapezoid_size) { + // the coordinate is within the top trapezoid + get_coordinate_in_tril_trapezoid(m_first_row, linear_index, &r, &c); + } else { + // the coordinate falls in the bottom rectangle + auto surplus = linear_index - trapezoid_size; + // add the height of trapezoid: m_last_row (col) - m_first_row + 1 + r = surplus / col + col - m_first_row + 1; + c = surplus % col; + } + r += row_offset; + + out_data[linear_index] = r; + out_data[linear_index + tril_size] = c; + } +} + +template +void TrilIndicesKernel(const Context& dev_ctx, + int rows, + int cols, + int offset, + DataType dtype, + DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + auto out_dims = out->dims(); + int tril_size = out_dims[1]; + + if (tril_size > 0) { + auto m_first_row = offset > 0 + ? std::min(cols, 1 + offset) + : rows + offset > 0; // the number of first row + auto trapezoid_row_offset = + std::max(0, -offset); // index of the first row who has number + auto rectangle_row_offset = trapezoid_row_offset + cols - m_first_row + + 1; // the length of the right-up rest matrix + int rectangle_size = 0; + if (rectangle_row_offset < rows) { + rectangle_size = (rows - rectangle_row_offset) * cols; + } // the rectangle part of lowertriangle matrix + + auto GetBlockGridSize = [&dev_ctx](int size) { + const int block_size = + std::min(size, static_cast(dev_ctx.GetMaxThreadsPerBlock())); + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int grid_size = + std::min(max_blocks, (size + block_size - 1) / block_size); + return std::tuple{grid_size, block_size}; + }; + + std::tuple block_grid_size = GetBlockGridSize(tril_size); + + tril_indices_kernel<<(block_grid_size), + std::get<1>(block_grid_size), + 0, + dev_ctx.stream()>>>(out_data, + trapezoid_row_offset, + m_first_row, + cols, + tril_size - rectangle_size, + tril_size); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + tril_indices, GPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/tril_indices_kernel.h b/paddle/phi/kernels/tril_indices_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1132a539ee6d11ec5cbc44cbb75a1303d406b407 --- /dev/null +++ b/paddle/phi/kernels/tril_indices_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TrilIndicesKernel(const Context& dev_ctx, + int rows, + int cols, + int offset, + DataType dtype, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 8c2ec1acf072a192807a17b8377162cd1bb66dd7..132105fb2b689f2aea80d3f814926102489d6411 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -105,6 +105,7 @@ from .tensor.creation import empty_like # noqa: F401 from .tensor.creation import assign # noqa: F401 from .tensor.creation import complex # noqa: F401 from .tensor.creation import clone # noqa: F401 +from .tensor.creation import tril_indices #noqa: F401 from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import norm # noqa: F401 @@ -637,4 +638,5 @@ __all__ = [ # noqa 'take_along_axis', 'put_along_axis', 'heaviside', + 'tril_indices', ] diff --git a/python/paddle/fluid/tests/unittests/test_tril_indices_op.py b/python/paddle/fluid/tests/unittests/test_tril_indices_op.py new file mode 100644 index 0000000000000000000000000000000000000000..29b07a5fb8463b143dad4fe055909568fce9059d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tril_indices_op.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard + + +class TestTrilIndicesOp(OpTest): + def setUp(self): + self.op_type = "tril_indices" + self.inputs = {} + self.init_config() + self.outputs = {'out': self.target} + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + def init_config(self): + self.attrs = {'rows': 4, 'cols': 4, 'offset': -1} + self.target = np.tril_indices(self.attrs['rows'], self.attrs['offset'], + self.attrs['cols']) + self.target = np.array(self.target) + + +class TestTrilIndicesOpCase1(TestTrilIndicesOp): + def init_config(self): + self.attrs = {'rows': 0, 'cols': 0, 'offset': 0} + self.target = np.tril_indices(0, 0, 0) + self.target = np.array(self.target) + + +class TestTrilIndicesOpCase2(TestTrilIndicesOp): + def init_config(self): + self.attrs = {'rows': 4, 'cols': 4, 'offset': 2} + self.target = np.tril_indices(self.attrs['rows'], self.attrs['offset'], + self.attrs['cols']) + self.target = np.array(self.target) + + +class TestTrilIndicesAPICaseStatic(unittest.TestCase): + def test_static(self): + places = [ + paddle.CPUPlace(), paddle.fluid.CUDAPlace(0) + ] if fluid.core.is_compiled_with_cuda() else [paddle.CPUPlace()] + paddle.enable_static() + for place in places: + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data1 = paddle.tril_indices(4, 4, -1) + exe1 = paddle.static.Executor(place) + result1 = exe1.run(feed={}, fetch_list=[data1]) + expected_result1 = np.tril_indices(4, -1, 4) + self.assertTrue(np.allclose(result1, expected_result1)) + + +class TestTrilIndicesAPICaseDygraph(unittest.TestCase): + def test_dygraph(self): + places = [ + paddle.CPUPlace(), paddle.fluid.CUDAPlace(0) + ] if fluid.core.is_compiled_with_cuda() else [paddle.CPUPlace()] + for place in places: + with fluid.dygraph.base.guard(place=place): + out1 = paddle.tril_indices(4, 4, 2) + expected_result1 = np.tril_indices(4, 2, 4) + self.assertEqual((out1.numpy() == expected_result1).all(), True) + + def test_dygraph_eager(self): + with _test_eager_guard(): + self.test_dygraph() + + +class TestTrilIndicesAPICaseError(unittest.TestCase): + def test_case_error(self): + def test_num_rows_type_check(): + out1 = paddle.tril_indices(1.0, 1, 2) + + self.assertRaises(TypeError, test_num_rows_type_check) + + def test_num_columns_type_check(): + out2 = paddle.tril_indices(4, -1, 2) + + self.assertRaises(TypeError, test_num_columns_type_check) + + def test_num_offset_type_check(): + out3 = paddle.tril_indices(4, 4, 2.0) + + self.assertRaises(TypeError, test_num_offset_type_check) + + +class TestTrilIndicesAPICaseDefault(unittest.TestCase): + def test_default_CPU(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data = paddle.tril_indices(4, None, 2) + exe = paddle.static.Executor(paddle.CPUPlace()) + result = exe.run(feed={}, fetch_list=[data]) + expected_result = np.tril_indices(4, 2) + self.assertTrue(np.allclose(result, expected_result)) + + with fluid.dygraph.base.guard(paddle.CPUPlace()): + out = paddle.tril_indices(4, None, 2) + expected_result = np.tril_indices(4, 2) + self.assertEqual((out.numpy() == expected_result).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 4580ff708e9f16a5d2d748b133e846f73760dc7d..c7e73cec47bead602e4a93b9e436c1a90dca9fa0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1713,3 +1713,90 @@ def complex(real, imag, name=None): attrs = {} helper.append_op(type=op_type, inputs=inputs, attrs=attrs, outputs=outputs) return out + + +def tril_indices(row, col, offset=0, dtype='int64'): + """ + Return the indices of the lower triangular part of the 2-D matrix + whose row and col is knowed.Indices are ordered based on row and then columns. + The lower triangular part of the matrix is defined as the elements on + and below the diagonal. + + Args: + row (int): The input x which is a int number describe the number of row of the matrix. + col (int): The input x which is a int number describe the number of col of the matrix. + offset (int, optional): The offset to consider, default value is 0. + + - If offset = 0, all elements on and below the main diagonal are retained. + - If offset > 0, include just as many diagonals above the main diagonal. + - If offset < 0, excludes just as many diagonals below the main diagonal. + + dtype (int, optional): the data type of the output tensor, can be int32, int64. + + Returns: + Tensor: Results of the indices of lower triangular part of a row * col matrix, + where the first row contains row coordinates of and the second row contains column coordinates. + + Examples: + .. code-block:: python + :name: tril_indices-example + + import paddle + + # example 1, default offset value + data1 = paddle.tril_indices(4,4,0) + print(data1) + # [[0, 1, 1, 2, 2, 2, 3, 3, 3, 3], + # [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]] + + # example 2, positive offset value + data2 = paddle.tril_indices(4,4,2) + print(data2) + # [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + # [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]] + + # example 3, negative offset value + data3 = paddle.tril_indices(4,4,-1) + print(data3) + # [[ 1, 2, 2, 3, 3, 3], + # [ 0, 0, 1, 0, 1, 2]] + """ + if not isinstance(row, int) or row < 0: + raise TypeError("row should be a non-negative int") + + if col is not None: + if not isinstance(col, int) or col < 0: + raise TypeError("col should be a non-negative int") + else: + col = row + + if not isinstance(offset, int): + raise TypeError("offset should be a int") + + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dygraph_mode(): + out = _C_ops.final_state_tril_indices(row, col, offset, dtype, + _current_expected_place()) + return out + + if _in_legacy_dygraph(): + out = _C_ops.tril_indices('rows', row, 'cols', col, 'offset', offset, + "dtype", dtype) + return out + + else: + helper = LayerHelper("tril_indices", **locals()) + + out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='tril_indices', + inputs={}, + outputs={'out': [out]}, + attrs={'rows': row, + 'cols': col, + 'offset': offset, + 'dtype': dtype}) + return out diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 05a00390205f98f9b8b07264f8b1c0441beb91ce..54a5100c892fce6489c31ab0cf9536c5a0fa2797 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -2140,6 +2140,18 @@ func : triangular_solve backward : triangular_solve_grad +- api : tril_indices + args : (int rows, int cols, int offset, DataType dtype, Place place={}) + output : Tensor(out) + infer_meta : + func : TrilIndicesInferMeta + param : [rows, cols, offset, dtype] + kernel : + func : tril_indices + param : [rows, cols, offset, dtype] + data_type : dtype + backend : place + - api : tril_triu args : (Tensor x, int diagonal, bool lower) output : Tensor(out)