Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a410c397
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a410c397
编写于
8月 25, 2022
作者:
R
Rayman
提交者:
GitHub
8月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[triu_indices] add triu_indices_op (#45168)
上级
20d38664
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
564 addition
and
0 deletion
+564
-0
paddle/fluid/operators/triu_indices_op.cc
paddle/fluid/operators/triu_indices_op.cc
+86
-0
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+12
-0
paddle/phi/infermeta/nullary.cc
paddle/phi/infermeta/nullary.cc
+29
-0
paddle/phi/infermeta/nullary.h
paddle/phi/infermeta/nullary.h
+3
-0
paddle/phi/kernels/cpu/triu_indices_kernel.cc
paddle/phi/kernels/cpu/triu_indices_kernel.cc
+51
-0
paddle/phi/kernels/gpu/triu_indices_kernel.cu
paddle/phi/kernels/gpu/triu_indices_kernel.cu
+133
-0
paddle/phi/kernels/triu_indices_kernel.h
paddle/phi/kernels/triu_indices_kernel.h
+29
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_triu_indices_op.py
python/paddle/fluid/tests/unittests/test_triu_indices_op.py
+134
-0
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+85
-0
未找到文件。
paddle/fluid/operators/triu_indices_op.cc
0 → 100644
浏览文件 @
a410c397
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
namespace
paddle
{
namespace
operators
{
class
TriuIndicesOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
Type
(
ctx
.
Attr
<
int
>
(
"dtype"
)),
ctx
.
GetPlace
());
}
};
class
TriuIndicesOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"out"
,
"Tensor, the output tensor, with the shape (2,x), x bounded by "
"[0,row*col])"
);
AddAttr
<
int
>
(
"row"
,
"int number, the input of triu_indices op"
"which describes the number of row of the matrix"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"col"
,
"int number, the input of triu_indices op"
"which describes the number of col of the matrix"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"offset"
,
"int number, the input of triu_indices op bounded by [1-rows,cols-1"
"which describes the dignalline index of the upper triangular part of "
"the matrix"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"dtype"
,
"data type ,the input of triu_indices op"
)
.
SetDefault
(
framework
::
proto
::
VarType
::
INT64
);
AddComment
(
R"DOC(
TriuIndices Operator.
The triu_indices operator returns the indices of the upper triangular part of the matrix
whose rows and cols is known. It is a 2-by-x tensor, where the first row contains row coordinates
of all indices and the second row contains column coordinates. Indices are ordered based on
rows and then columns. The upper triangular part of the matrix is defined as the elements on
and below the diagonal.
The argument offset controls which diagonal to consider, default value is 0.
A positive value includes just as fewer diagonals above the main diagonal,
and similarly a negative value excludes just as fewer diagonals below the main diagonal
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
triu_indices
,
TriuIndicesInferShapeFunctor
,
PD_INFER_META
(
phi
::
TriuIndicesInferMeta
));
REGISTER_OPERATOR
(
triu_indices
,
ops
::
TriuIndicesOp
,
ops
::
TriuIndicesOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
TriuIndicesInferShapeFunctor
);
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
a410c397
...
...
@@ -2710,6 +2710,18 @@
data_type
:
x
backward
:
trilinear_interp_grad
-
api
:
triu_indices
args
:
(int row, int col, int offset, DataType dtype, Place place={})
output
:
Tensor(out)
infer_meta
:
func
:
TriuIndicesInferMeta
param
:
[
row
,
col
,
offset
,
dtype
]
kernel
:
func
:
triu_indices
param
:
[
row
,
col
,
offset
,
dtype
]
data_type
:
dtype
backend
:
place
# python API: paddle.nn.initializer.TruncatedNormal
-
api
:
truncated_gaussian_random
args
:
(int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
...
...
paddle/phi/infermeta/nullary.cc
浏览文件 @
a410c397
...
...
@@ -152,4 +152,33 @@ void TrilIndicesInferMeta(
out
->
set_dims
(
out_dims
);
out
->
set_dtype
(
dtype
);
}
void
TriuIndicesInferMeta
(
int
row
,
int
col
,
int
offset
,
DataType
dtype
,
MetaTensor
*
out
)
{
// number of elements in the first row of the tril,bounded by [0, cols]
// use total item number minus bottom rectangle item number to get
// the above rectangle item number
// triu_size = rows * cols - tril_size
// so the `offset` need to be set as `offset-1` in order to include
// the item on the diagonal line
offset
=
offset
-
1
;
auto
n_first_row
=
offset
>
0
?
std
::
min
<
int64_t
>
(
col
,
1
+
offset
)
:
row
+
offset
>
0
;
// number of elements in the last row of the tril, bounded by [0, cols]
auto
n_last_row
=
std
::
max
<
int64_t
>
(
0
,
std
::
min
<
int64_t
>
(
col
,
row
+
offset
));
// number of rows, bounded by [0, rows]
auto
n_row_all
=
std
::
max
<
int64_t
>
(
0
,
std
::
min
<
int64_t
>
(
row
,
row
+
offset
));
auto
n_row_trapezoid
=
(
n_last_row
-
n_first_row
+
1
);
// calculate # of elements in the top trapezoid
auto
tril_size
=
(
n_first_row
+
n_last_row
)
*
n_row_trapezoid
>>
1
;
// calculate # of elements in the bottom rectangle if there is any
auto
diff_row
=
n_row_all
-
n_row_trapezoid
;
if
(
diff_row
>
0
)
{
tril_size
+=
diff_row
*
col
;
}
std
::
vector
<
int64_t
>
tmp
=
{
2
,
row
*
col
-
tril_size
};
auto
out_dims
=
phi
::
make_ddim
(
tmp
);
out
->
set_dims
(
out_dims
);
out
->
set_dtype
(
dtype
);
}
}
// namespace phi
paddle/phi/infermeta/nullary.h
浏览文件 @
a410c397
...
...
@@ -74,4 +74,7 @@ void UniformRandomInferMeta(const IntArray& shape,
void
TrilIndicesInferMeta
(
int
rows
,
int
cols
,
int
offset
,
DataType
dtype
,
MetaTensor
*
out
);
void
TriuIndicesInferMeta
(
int
row
,
int
col
,
int
offset
,
DataType
dtype
,
MetaTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/triu_indices_kernel.cc
0 → 100644
浏览文件 @
a410c397
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/triu_indices_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
TriuIndicesKernel
(
const
Context
&
dev_ctx
,
int
row
,
int
col
,
int
offset
,
DataType
dtype
,
DenseTensor
*
out
)
{
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
const
auto
&
out_dims
=
out
->
dims
();
int64_t
triu_size
=
out_dims
[
1
];
int64_t
i
=
0
;
T
c
=
std
::
max
<
int64_t
>
(
0
,
offset
),
r
=
0
;
while
(
i
<
triu_size
)
{
out_data
[
i
]
=
r
;
out_data
[
triu_size
+
i
++
]
=
c
;
// move to the next column and check if (r, c) is still in bound
c
+=
1
;
if
(
c
>=
col
)
{
r
+=
1
;
// not typing std::max with scalar_t as it could be an unsigned type
// NOTE: not necessary to check if c is less than col or overflows here,
// because i and triu_size act as a guard.
c
=
std
::
max
<
int64_t
>
(
0
,
r
+
offset
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
triu_indices
,
CPU
,
ALL_LAYOUT
,
phi
::
TriuIndicesKernel
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/triu_indices_kernel.cu
0 → 100644
浏览文件 @
a410c397
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/triu_indices_kernel.h"
#include <algorithm>
#include <tuple>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
>
__device__
inline
int
resolve_root_int
(
int
b
,
int
cX4
,
int
x
,
int32_t
sign
)
{
int64_t
bXb_cX4
=
b
*
b
-
cX4
;
double
sr
=
::
sqrt
(
static_cast
<
double
>
(
bXb_cX4
));
T
res
=
::
__double2ll_rd
((
-
b
+
sign
*
sr
)
/
2
);
if
(
bXb_cX4
!=
static_cast
<
int
>
(
sr
*
sr
))
{
int
llsr
=
::
__double2ll_rd
(
sr
);
int
diff
=
::
__double2ll_ru
(
::
sqrt
(
::
fabs
(
static_cast
<
double
>
(
bXb_cX4
-
llsr
*
llsr
))));
auto
l
=
res
>
diff
?
res
-
diff
:
0
;
auto
r
=
res
+
diff
+
1
;
x
<<=
1
;
while
(
l
+
1
<
r
)
{
auto
m
=
(
l
+
r
)
>>
1
;
if
(
sign
*
(
b
+
m
)
*
m
>
x
)
{
r
=
m
;
}
else
{
l
=
m
;
}
}
res
=
l
;
}
return
res
;
}
template
<
typename
T
>
__device__
inline
void
get_coordinate_in_triu_trapezoid
(
int
f
,
int
x
,
T
*
row
,
T
*
col
)
{
f
<<=
1
;
// all statements use 2f, so only calculate it once here.
auto
b
=
-
1
-
f
;
auto
cX4
=
x
<<
3
;
// 4 * c = 4 * (2x) = 8x;
*
row
=
resolve_root_int
<
T
>
(
b
,
cX4
,
x
,
-
1
);
*
col
=
(
x
-
(((
f
-
*
row
+
1
)
*
*
row
)
>>
1
))
+
*
row
;
}
template
<
typename
T
>
__global__
void
triu_indices_kernel
(
T
*
out_data
,
int
col_offset
,
int
m_first_row
,
int
col
,
int
rectangle_size
,
int
triu_size
)
{
int
linear_index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
linear_index
<
triu_size
)
{
T
r
,
c
;
if
(
linear_index
<
rectangle_size
)
{
// the coordinate is within the top rectangle
r
=
linear_index
/
col
;
c
=
linear_index
%
col
;
}
else
{
// the coordinate falls in the bottom trapezoid
get_coordinate_in_triu_trapezoid
<
T
>
(
m_first_row
,
linear_index
-
rectangle_size
,
&
r
,
&
c
);
r
+=
rectangle_size
/
col
;
}
c
+=
col_offset
;
out_data
[
linear_index
]
=
r
;
out_data
[
linear_index
+
triu_size
]
=
c
;
}
}
template
<
typename
T
,
typename
Context
>
void
TriuIndicesKernel
(
const
Context
&
dev_ctx
,
int
row
,
int
col
,
int
offset
,
DataType
dtype
,
DenseTensor
*
out
)
{
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
out_dims
=
out
->
dims
();
int
triu_size
=
out_dims
[
1
];
// auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt,
// device_opt, pin_memory_opt);
if
(
triu_size
>
0
)
{
// # of triu elements in the first row
auto
m_first_row
=
offset
>
0
?
std
::
max
<
int
>
(
col
-
offset
,
0
)
:
// upper bounded by col
col
;
// size of the top rectangle
int
rectangle_size
=
0
;
if
(
offset
<
0
)
{
rectangle_size
=
std
::
min
<
int
>
(
row
,
-
offset
)
*
col
;
}
// using gpu_launch_config to get grid_size and block_size
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
triu_size
);
triu_indices_kernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
std
::
max
<
int
>
(
0
,
offset
),
m_first_row
,
col
,
rectangle_size
,
triu_size
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
triu_indices
,
GPU
,
ALL_LAYOUT
,
phi
::
TriuIndicesKernel
,
int
,
int64_t
)
{}
paddle/phi/kernels/triu_indices_kernel.h
0 → 100644
浏览文件 @
a410c397
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
TriuIndicesKernel
(
const
Context
&
dev_ctx
,
int
row
,
int
col
,
int
offset
,
DataType
dtype
,
DenseTensor
*
out
);
}
// namespace phi
python/paddle/__init__.py
浏览文件 @
a410c397
...
...
@@ -110,6 +110,7 @@ from .tensor.creation import assign # noqa: F401
from
.tensor.creation
import
complex
# noqa: F401
from
.tensor.creation
import
clone
# noqa: F401
from
.tensor.creation
import
tril_indices
#noqa: F401
from
.tensor.creation
import
triu_indices
#noqa: F401
from
.tensor.linalg
import
matmul
# noqa: F401
from
.tensor.linalg
import
dot
# noqa: F401
from
.tensor.linalg
import
norm
# noqa: F401
...
...
@@ -654,4 +655,5 @@ __all__ = [ # noqa
'heaviside'
,
'tril_indices'
,
'sgn'
,
'triu_indices'
,
]
python/paddle/fluid/tests/unittests/test_triu_indices_op.py
0 → 100644
浏览文件 @
a410c397
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
from
paddle.fluid.framework
import
_test_eager_guard
class
TestTriuIndicesOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"triu_indices"
self
.
inputs
=
{}
self
.
init_config
()
self
.
outputs
=
{
'out'
:
self
.
target
}
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
()
def
init_config
(
self
):
self
.
attrs
=
{
'row'
:
4
,
'col'
:
4
,
'offset'
:
-
1
}
self
.
target
=
np
.
triu_indices
(
self
.
attrs
[
'row'
],
self
.
attrs
[
'offset'
],
self
.
attrs
[
'col'
])
self
.
target
=
np
.
array
(
self
.
target
)
class
TestTriuIndicesOpCase1
(
TestTriuIndicesOp
):
def
init_config
(
self
):
self
.
attrs
=
{
'row'
:
0
,
'col'
:
0
,
'offset'
:
0
}
self
.
target
=
np
.
triu_indices
(
0
,
0
,
0
)
self
.
target
=
np
.
array
(
self
.
target
)
class
TestTriuIndicesOpCase2
(
TestTriuIndicesOp
):
def
init_config
(
self
):
self
.
attrs
=
{
'row'
:
4
,
'col'
:
4
,
'offset'
:
2
}
self
.
target
=
np
.
triu_indices
(
self
.
attrs
[
'row'
],
self
.
attrs
[
'offset'
],
self
.
attrs
[
'col'
])
self
.
target
=
np
.
array
(
self
.
target
)
class
TestTriuIndicesAPICaseStatic
(
unittest
.
TestCase
):
def
test_static
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
paddle
.
fluid
.
CUDAPlace
(
0
)
else
:
place
=
paddle
.
CPUPlace
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
data
=
paddle
.
triu_indices
(
4
,
4
,
-
1
)
exe
=
paddle
.
static
.
Executor
(
place
)
result
=
exe
.
run
(
feed
=
{},
fetch_list
=
[
data
])
expected_result
=
np
.
triu_indices
(
4
,
-
1
,
4
)
np
.
testing
.
assert_array_equal
(
result
[
0
],
expected_result
)
class
TestTriuIndicesAPICaseDygraph
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
paddle
.
fluid
.
CUDAPlace
(
0
)
else
:
place
=
paddle
.
CPUPlace
()
with
fluid
.
dygraph
.
base
.
guard
(
place
=
place
):
out
=
paddle
.
triu_indices
(
4
,
4
,
2
)
expected_result
=
np
.
triu_indices
(
4
,
2
,
4
)
np
.
testing
.
assert_array_equal
(
out
,
expected_result
)
def
test_dygraph_eager
(
self
):
with
_test_eager_guard
():
self
.
test_dygraph
()
class
TestTriuIndicesAPICaseError
(
unittest
.
TestCase
):
def
test_case_error
(
self
):
def
test_num_rows_type_check
():
out1
=
paddle
.
triu_indices
(
1.0
,
1
,
2
)
self
.
assertRaises
(
TypeError
,
test_num_rows_type_check
)
def
test_num_columns_type_check
():
out2
=
paddle
.
triu_indices
(
4
,
-
1
,
2
)
self
.
assertRaises
(
TypeError
,
test_num_columns_type_check
)
def
test_num_offset_type_check
():
out3
=
paddle
.
triu_indices
(
4
,
4
,
2.0
)
self
.
assertRaises
(
TypeError
,
test_num_offset_type_check
)
class
TestTriuIndicesAPICaseDefault
(
unittest
.
TestCase
):
def
test_default_CPU
(
self
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
data
=
paddle
.
triu_indices
(
4
,
None
,
2
)
exe
=
paddle
.
static
.
Executor
(
paddle
.
CPUPlace
())
result
=
exe
.
run
(
feed
=
{},
fetch_list
=
[
data
])
expected_result
=
np
.
triu_indices
(
4
,
2
)
np
.
testing
.
assert_array_equal
(
result
[
0
],
expected_result
)
with
fluid
.
dygraph
.
base
.
guard
(
paddle
.
CPUPlace
()):
out
=
paddle
.
triu_indices
(
4
,
None
,
2
)
expected_result
=
np
.
triu_indices
(
4
,
2
)
np
.
testing
.
assert_array_equal
(
out
,
expected_result
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/tensor/creation.py
浏览文件 @
a410c397
...
...
@@ -1917,3 +1917,88 @@ def tril_indices(row, col, offset=0, dtype='int64'):
'dtype'
:
dtype
})
return
out
def
triu_indices
(
row
,
col
=
None
,
offset
=
0
,
dtype
=
'int64'
):
"""
Return the indices of the upper triangular part of the 2-D matrix
whose row and col is known. Indices are ordered based on row and then columns.
The upper triangular part of the matrix is defined as the elements on
and above the diagonal.
Args:
row (int): The input x which is a int number describe the number of row of the matrix.
col (int, optional): The input x which is a int number describe the number of col of the matrix.
default value for col is None, then it will be set equal to row, indicting a square matix.
offset (int, optional): The offset to consider, default value is 0.
- If offset = 0, all elements on and above the main diagonal are retained.
- If offset > 0, include just as few diagonals above the main diagonal.
- If offset < 0, excludes just as few diagonals below the main diagonal.
dtype (str|np.dtype|paddle.dtype, optional): the data type of the output tensor,
can be int32, int64, default value is int64.
Returns:
Tensor: Results of the indices of upper triangular part of a row * col matrix,
where the first row contains row coordinates of and the second row contains column coordinates.
Examples:
.. code-block:: python
import paddle
# example 1, default offset value
data1 = paddle.triu_indices(4,4,0)
print(data1)
# [[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
# [0, 1, 2, 3, 1, 2, 3, 2, 3, 3]]
# example 2, positive offset value
data2 = paddle.triu_indices(4,4,2)
print(data2)
# [[0, 0, 1],
# [2, 3, 3]]
# example 3, negative offset value
data3 = paddle.triu_indices(4,4,-1)
print(data3)
# [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3],
# [0, 1, 2, 3, 0, 1, 2, 3, 1, 2, 3, 2, 3]]
"""
if
not
isinstance
(
row
,
int
)
or
row
<
0
:
raise
TypeError
(
"row should be a non-negative int"
)
if
col
is
not
None
:
if
not
isinstance
(
col
,
int
)
or
col
<
0
:
raise
TypeError
(
"col should be a non-negative int"
)
else
:
col
=
row
if
not
isinstance
(
offset
,
int
):
raise
TypeError
(
"offset should be a int"
)
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
if
in_dygraph_mode
():
out
=
_C_ops
.
final_state_triu_indices
(
row
,
col
,
offset
,
dtype
,
_current_expected_place
())
return
out
if
_in_legacy_dygraph
():
out
=
_C_ops
.
triu_indices
(
'row'
,
row
,
'col'
,
col
,
'offset'
,
offset
,
"dtype"
,
dtype
)
return
out
else
:
helper
=
LayerHelper
(
"triu_indices"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'triu_indices'
,
inputs
=
{},
outputs
=
{
'out'
:
[
out
]},
attrs
=
{
'row'
:
row
,
'col'
:
col
,
'offset'
:
offset
,
'dtype'
:
dtype
})
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录