未验证 提交 75db5b86 编写于 作者: X xiaoguoguo626807 提交者: GitHub

[Hackathon No.5] tril_indices OP (#41639)

* add tril_indices cpu kernal

* modify tril_indice cpu op

* modify bug

* modify bug

* add tril_indices python api

* add tril_indices python api

* resolve conflict

* add tril_indices test

* modify details

* add tril_indices.cu

* pythonapi pass

* save tril_indices

* CPU tril_indices pass

* delete vlog

* modify test_tril_indices_op.py

* delete tril_indices_kernel.cc.swp

* delete tril_indice.cu

* modify code style

* add newline in creation.py

* modify creation.py linux newline

* delete annotation

* check code style

* check .py style add final_state??

* modify code style

* add gpu_tril_indices

* modify gpu_compiled_juage

* modify gpu judge

* code style

* add test example

* modify english document

modify english document

modify english document

modify document

modify document

* modify pram name

* modify pram name

* modify pram

* reduce test ex
上级 1f76eabf
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
class TrilIndicesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class TrilIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("out",
"Tensor, the output tensor, with the shape (2,x),x bounded by "
"[0,rows*cols])");
AddAttr<int>("rows",
"int number, the input of tril_indices op"
"which describes the number of row of the matrix")
.SetDefault(0);
AddAttr<int>("cols",
"int number, the input of tril_indices op"
"which describes the number of col of the matrix")
.SetDefault(0);
AddAttr<int>(
"offset",
"int number, the input of tril_indices op bounded by [1-rows,cols-1"
"which describes the dignalline index of the lower triangular part of "
"the matrix")
.SetDefault(0);
AddAttr<int>("dtype", "data type ,the input of tril_indices op")
.SetDefault(framework::proto::VarType::INT64);
AddComment(R"DOC(
TrilIndices Operator.
The tril_indices operator returns the indices of the lower triangular part of the matrix
whose rows and cols is knowed. It is a 2-by-x tensor,where the first row contains row coordinates
of all indices and the second row contains column coordinates. Indices are ordered based on
rows and then columns. The lower triangular part of the matrix is defined as the elements on
and below the diagonal.
The argument offset controls which diagonal to consider, default value is 0.
A positive valueincludes just as many diagonals above the main diagonal,
and similarly a negative value excludes just as many diagonals below the main diagonal
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(tril_indices, TrilIndicesInferShapeFunctor,
PD_INFER_META(phi::TrilIndicesInferMeta));
REGISTER_OPERATOR(
tril_indices, ops::TrilIndicesOp, ops::TrilIndicesOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
TrilIndicesInferShapeFunctor);
......@@ -115,4 +115,27 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
out->set_layout(DataLayout::NCHW);
}
void TrilIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out) {
// number of elements in the first row of the tril,bounded by [0, cols]
auto n_first_row =
offset > 0 ? std::min<int64_t>(cols, 1 + offset) : rows + offset > 0;
// number of elements in the last row of the tril, bounded by [0, cols]
auto n_last_row =
std::max<int64_t>(0, std::min<int64_t>(cols, rows + offset));
// number of rows, bounded by [0, rows]
auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(rows, rows + offset));
auto n_row_trapezoid = (n_last_row - n_first_row + 1);
// calculate # of elements in the top trapezoid
auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1;
// calculate # of elements in the bottom rectangle if there is any
auto diff_row = n_row_all - n_row_trapezoid;
if (diff_row > 0) {
tril_size += diff_row * cols;
}
std::vector<int64_t> tmp = {2, tril_size};
auto out_dims = phi::make_ddim(tmp);
out->set_dims(out_dims);
out->set_dtype(dtype);
}
} // namespace phi
......@@ -72,4 +72,6 @@ void UniformRandomInferMeta(const IntArray& shape,
int seed,
MetaTensor* out);
void TrilIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tril_indices_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto out_dims = out->dims();
int64_t tril_size = out_dims[1];
int64_t i = 0;
T r = std::max<int64_t>(0, -offset), c = 0;
while (i < tril_size) {
out_data[i] = r;
out_data[tril_size + i++] = c;
// move to the next column and check if (r, c) is still in bound
c += 1;
if (c > r + offset || c >= cols) {
r += 1;
c = 0;
// NOTE: not necessary to check if r is less than row here, because i
// and tril_size provide the guarantee
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
tril_indices, CPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tril_indices_kernel.h"
#include <algorithm>
#include <tuple>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) {
int bXb_cX4 = b * b - cX4;
double sr = ::sqrt(static_cast<double>(bXb_cX4));
T res = ::__double2ll_rd((-b + sign * sr) / 2);
if (bXb_cX4 != static_cast<int>(sr * sr)) {
int llsr = ::__double2ll_rd(sr);
int diff = ::__double2ll_ru(
::sqrt(::fabs(static_cast<double>(bXb_cX4 - llsr * llsr))));
auto l = res > diff ? res - diff : 0;
auto r = res + diff + 1;
x <<= 1;
while (l + 1 < r) {
auto m = (l + r) >> 1;
if (sign * (b + m) * m > x) {
r = m;
} else {
l = m;
}
}
res = l;
}
return res;
}
template <typename T>
__device__ inline void get_coordinate_in_tril_trapezoid(int f,
int x,
T* row,
T* col) {
f <<= 1; // all statements use 2f, so only calculate it once here.
auto b = f - 1;
auto cX4 = -(x << 3); // 4 * c = 4 * (-2x) = -8x;
*row = resolve_root_int<T>(b, cX4, x, 1);
*col = x - ((f + *row - 1) * *row >> 1);
}
template <typename T>
__global__ void tril_indices_kernel(T* out_data,
int row_offset,
int m_first_row,
int col,
int trapezoid_size,
int tril_size) {
int linear_index = blockIdx.x * blockDim.x + threadIdx.x;
if (linear_index < tril_size) {
T r, c;
if (linear_index < trapezoid_size) {
// the coordinate is within the top trapezoid
get_coordinate_in_tril_trapezoid<T>(m_first_row, linear_index, &r, &c);
} else {
// the coordinate falls in the bottom rectangle
auto surplus = linear_index - trapezoid_size;
// add the height of trapezoid: m_last_row (col) - m_first_row + 1
r = surplus / col + col - m_first_row + 1;
c = surplus % col;
}
r += row_offset;
out_data[linear_index] = r;
out_data[linear_index + tril_size] = c;
}
}
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto out_dims = out->dims();
int tril_size = out_dims[1];
if (tril_size > 0) {
auto m_first_row = offset > 0
? std::min<int>(cols, 1 + offset)
: rows + offset > 0; // the number of first row
auto trapezoid_row_offset =
std::max<int>(0, -offset); // index of the first row who has number
auto rectangle_row_offset = trapezoid_row_offset + cols - m_first_row +
1; // the length of the right-up rest matrix
int rectangle_size = 0;
if (rectangle_row_offset < rows) {
rectangle_size = (rows - rectangle_row_offset) * cols;
} // the rectangle part of lowertriangle matrix
auto GetBlockGridSize = [&dev_ctx](int size) {
const int block_size =
std::min(size, static_cast<int>(dev_ctx.GetMaxThreadsPerBlock()));
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int>(1));
const int grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int, int>{grid_size, block_size};
};
std::tuple<int, int> block_grid_size = GetBlockGridSize(tril_size);
tril_indices_kernel<T><<<std::get<0>(block_grid_size),
std::get<1>(block_grid_size),
0,
dev_ctx.stream()>>>(out_data,
trapezoid_row_offset,
m_first_row,
cols,
tril_size - rectangle_size,
tril_size);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
tril_indices, GPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out);
} // namespace phi
......@@ -105,6 +105,7 @@ from .tensor.creation import empty_like # noqa: F401
from .tensor.creation import assign # noqa: F401
from .tensor.creation import complex # noqa: F401
from .tensor.creation import clone # noqa: F401
from .tensor.creation import tril_indices #noqa: F401
from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
......@@ -637,4 +638,5 @@ __all__ = [ # noqa
'take_along_axis',
'put_along_axis',
'heaviside',
'tril_indices',
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
class TestTrilIndicesOp(OpTest):
def setUp(self):
self.op_type = "tril_indices"
self.inputs = {}
self.init_config()
self.outputs = {'out': self.target}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def init_config(self):
self.attrs = {'rows': 4, 'cols': 4, 'offset': -1}
self.target = np.tril_indices(self.attrs['rows'], self.attrs['offset'],
self.attrs['cols'])
self.target = np.array(self.target)
class TestTrilIndicesOpCase1(TestTrilIndicesOp):
def init_config(self):
self.attrs = {'rows': 0, 'cols': 0, 'offset': 0}
self.target = np.tril_indices(0, 0, 0)
self.target = np.array(self.target)
class TestTrilIndicesOpCase2(TestTrilIndicesOp):
def init_config(self):
self.attrs = {'rows': 4, 'cols': 4, 'offset': 2}
self.target = np.tril_indices(self.attrs['rows'], self.attrs['offset'],
self.attrs['cols'])
self.target = np.array(self.target)
class TestTrilIndicesAPICaseStatic(unittest.TestCase):
def test_static(self):
places = [
paddle.CPUPlace(), paddle.fluid.CUDAPlace(0)
] if fluid.core.is_compiled_with_cuda() else [paddle.CPUPlace()]
paddle.enable_static()
for place in places:
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data1 = paddle.tril_indices(4, 4, -1)
exe1 = paddle.static.Executor(place)
result1 = exe1.run(feed={}, fetch_list=[data1])
expected_result1 = np.tril_indices(4, -1, 4)
self.assertTrue(np.allclose(result1, expected_result1))
class TestTrilIndicesAPICaseDygraph(unittest.TestCase):
def test_dygraph(self):
places = [
paddle.CPUPlace(), paddle.fluid.CUDAPlace(0)
] if fluid.core.is_compiled_with_cuda() else [paddle.CPUPlace()]
for place in places:
with fluid.dygraph.base.guard(place=place):
out1 = paddle.tril_indices(4, 4, 2)
expected_result1 = np.tril_indices(4, 2, 4)
self.assertEqual((out1.numpy() == expected_result1).all(), True)
def test_dygraph_eager(self):
with _test_eager_guard():
self.test_dygraph()
class TestTrilIndicesAPICaseError(unittest.TestCase):
def test_case_error(self):
def test_num_rows_type_check():
out1 = paddle.tril_indices(1.0, 1, 2)
self.assertRaises(TypeError, test_num_rows_type_check)
def test_num_columns_type_check():
out2 = paddle.tril_indices(4, -1, 2)
self.assertRaises(TypeError, test_num_columns_type_check)
def test_num_offset_type_check():
out3 = paddle.tril_indices(4, 4, 2.0)
self.assertRaises(TypeError, test_num_offset_type_check)
class TestTrilIndicesAPICaseDefault(unittest.TestCase):
def test_default_CPU(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.tril_indices(4, None, 2)
exe = paddle.static.Executor(paddle.CPUPlace())
result = exe.run(feed={}, fetch_list=[data])
expected_result = np.tril_indices(4, 2)
self.assertTrue(np.allclose(result, expected_result))
with fluid.dygraph.base.guard(paddle.CPUPlace()):
out = paddle.tril_indices(4, None, 2)
expected_result = np.tril_indices(4, 2)
self.assertEqual((out.numpy() == expected_result).all(), True)
if __name__ == "__main__":
unittest.main()
......@@ -1713,3 +1713,90 @@ def complex(real, imag, name=None):
attrs = {}
helper.append_op(type=op_type, inputs=inputs, attrs=attrs, outputs=outputs)
return out
def tril_indices(row, col, offset=0, dtype='int64'):
"""
Return the indices of the lower triangular part of the 2-D matrix
whose row and col is knowed.Indices are ordered based on row and then columns.
The lower triangular part of the matrix is defined as the elements on
and below the diagonal.
Args:
row (int): The input x which is a int number describe the number of row of the matrix.
col (int): The input x which is a int number describe the number of col of the matrix.
offset (int, optional): The offset to consider, default value is 0.
- If offset = 0, all elements on and below the main diagonal are retained.
- If offset > 0, include just as many diagonals above the main diagonal.
- If offset < 0, excludes just as many diagonals below the main diagonal.
dtype (int, optional): the data type of the output tensor, can be int32, int64.
Returns:
Tensor: Results of the indices of lower triangular part of a row * col matrix,
where the first row contains row coordinates of and the second row contains column coordinates.
Examples:
.. code-block:: python
:name: tril_indices-example
import paddle
# example 1, default offset value
data1 = paddle.tril_indices(4,4,0)
print(data1)
# [[0, 1, 1, 2, 2, 2, 3, 3, 3, 3],
# [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]]
# example 2, positive offset value
data2 = paddle.tril_indices(4,4,2)
print(data2)
# [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
# [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]
# example 3, negative offset value
data3 = paddle.tril_indices(4,4,-1)
print(data3)
# [[ 1, 2, 2, 3, 3, 3],
# [ 0, 0, 1, 0, 1, 2]]
"""
if not isinstance(row, int) or row < 0:
raise TypeError("row should be a non-negative int")
if col is not None:
if not isinstance(col, int) or col < 0:
raise TypeError("col should be a non-negative int")
else:
col = row
if not isinstance(offset, int):
raise TypeError("offset should be a int")
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
out = _C_ops.final_state_tril_indices(row, col, offset, dtype,
_current_expected_place())
return out
if _in_legacy_dygraph():
out = _C_ops.tril_indices('rows', row, 'cols', col, 'offset', offset,
"dtype", dtype)
return out
else:
helper = LayerHelper("tril_indices", **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='tril_indices',
inputs={},
outputs={'out': [out]},
attrs={'rows': row,
'cols': col,
'offset': offset,
'dtype': dtype})
return out
......@@ -2140,6 +2140,18 @@
func : triangular_solve
backward : triangular_solve_grad
- api : tril_indices
args : (int rows, int cols, int offset, DataType dtype, Place place={})
output : Tensor(out)
infer_meta :
func : TrilIndicesInferMeta
param : [rows, cols, offset, dtype]
kernel :
func : tril_indices
param : [rows, cols, offset, dtype]
data_type : dtype
backend : place
- api : tril_triu
args : (Tensor x, int diagonal, bool lower)
output : Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册