未验证 提交 c4d03052 编写于 作者: W WuHaobo 提交者: GitHub

add tril op and triu op (#23469)

add tril op and  triu op
上级 3eb12bd1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"
#include <memory>
namespace paddle {
namespace operators {
class TrilTriuOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of TrilTriuOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of TrilTriuOp is not found."));
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Tensor, the input of tril_triu op");
AddOutput("Out",
"Tensor, the output tensor, with the same shape and data type as "
"input(x)");
AddAttr<int>("diagonal", "int number, the diagonal to consider.")
.SetDefault(0);
AddAttr<bool>("lower", "boolnumber, lower triangular or upper triangular.");
AddComment(R"DOC(
TrilTriu Operator.
The tril operator returns the lower triangular part of the matrix (2-D tensor)
or batch of matrices $input$. The lower triangular part of the matrix is defined
as the elements on and below the diagonal.
The triu operator returns the upper triangular part of a matrix (2-D tensor)
or batch of matrices $input$. The upper triangular part of the matrix is defined
as the elements on and above the diagonal.
The other elements of the result tensor out are set to 0.
The argument diagonal controls which diagonal to consider, default value is 0.
)DOC");
}
};
class TrilTriuGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@GRAD) of TrilTriuOp should not be null"));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Output(X@Grad) of TrilTriuOp should not be null"));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
};
template <typename T>
class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("tril_triu_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
tril_triu,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class TrilTriuCompute {
public:
HOSTDEVICE TrilTriuCompute(const T* in, const int diagonal, const bool lower,
const int64_t H, const int64_t W, T* out)
: in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}
HOSTDEVICE void operator()(int64_t idx) {
const int64_t row = (idx / W_) % H_;
const int64_t col = idx % W_;
const bool mask =
lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
out_[idx] = mask ? static_cast<T>(0) : in_[idx];
}
private:
const T* in_;
const int diagonal_;
const bool lower_;
const int64_t H_;
const int64_t W_;
T* out_;
};
template <typename DeviceContext, typename T>
class TrilTriuOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<framework::Tensor>("X");
const auto* x_data = x->data<T>();
auto* out = context.Output<framework::Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
const int diagonal = context.Attr<int>("diagonal");
const bool lower = context.Attr<bool>("lower");
const auto& dims = x->dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
paddle::operators::TrilTriuCompute<T> tril_triu_computer(
x_data, diagonal, lower, H, W, out_data);
for_range(tril_triu_computer);
}
};
template <typename DeviceContext, typename T>
class TrilTriuGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
const auto* dout_data = d_out->data<T>();
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dx_data = d_x->mutable_data<T>(context.GetPlace());
const int diagonal = context.Attr<int>("diagonal");
const bool lower = context.Attr<bool>("lower");
const auto& dims = d_out->dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_out->numel()));
paddle::operators::TrilTriuCompute<T> tril_triu_grad_computer(
dout_data, diagonal, lower, H, W, dx_data);
for_range(tril_triu_grad_computer);
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at #
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.tensor as tensor
class TrilTriuOpDefaultTest(OpTest):
""" the base class of other op testcases
"""
def setUp(self):
self.initTestCase()
self.real_np_op = getattr(np, self.real_op_type)
self.op_type = "tril_triu"
self.inputs = {'X': self.X}
self.attrs = {
'diagonal': self.diagonal,
'lower': True if self.real_op_type == 'tril' else False,
}
self.outputs = {
'Out': self.real_np_op(self.X, self.diagonal)
if self.diagonal else self.real_np_op(self.X)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None
self.X = np.arange(1, 101, dtype="float64").reshape([10, -1])
def case_generator(op_type, Xshape, diagonal, expected):
"""
Generate testcases with the params shape of X, diagonal and op_type.
If arg`expercted` is 'success', it will register an Optest case and expect to pass.
Otherwise, it will register an API case and check the expect failure.
"""
cls_name = "{0}_{1}_shape_{2}_diag_{3}".format(expected, op_type, Xshape,
diagonal)
errmsg = {
"diagonal: TypeError":
"diagonal in {} must be a python Int".format(op_type),
"input: ValueError":
"input shape in {} must be at least 2-D".format(op_type),
}
class FailureCase(unittest.TestCase):
def test_failure(self):
data = fluid.data(shape=Xshape, dtype='float64', name=cls_name)
with self.assertRaisesRegexp(
eval(expected.split(':')[-1]), errmsg[expected]):
getattr(tensor, op_type)(input=data, diagonal=diagonal)
class SuccessCase(TrilTriuOpDefaultTest):
def initTestCase(self):
self.real_op_type = op_type
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float64")
CLASS = locals()['SuccessCase' if expected == "success" else 'FailureCase']
CLASS.__name__ = cls_name
globals()[cls_name] = CLASS
### NOTE: meaningful diagonal is [1 - min(H, W), max(H, W) -1]
### test the diagonal just at the border, upper/lower the border,
### negative/positive integer within range and a zero
cases = {
'success': {
(2, 2, 3, 4, 5): [-100, -3, -1, 0, 2, 4, 100], # normal shape
(10, 10, 1, 1): [-100, -1, 0, 1, 100], # small size of matrix
},
'diagonal: TypeError': {
(20, 20): [
'2020',
[20],
{
20: 20
},
(20, 20),
20.20,
], # str, list, dict, tuple, float
},
'input: ValueError': {
(2020, ): [None],
},
}
for _op_type in ['tril', 'triu']:
for _expected, _params in cases.items():
for _Xshape, _diaglist in _params.items():
list(
map(lambda _diagonal: case_generator(_op_type, _Xshape, _diagonal, _expected),
_diaglist))
class TestTrilTriuOpAPI(unittest.TestCase):
""" test case by using API and has -1 dimension
"""
def test_api(self):
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
tril_out, triu_out = tensor.tril(x), tensor.triu(x)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
tril_out, triu_out = exe.run(
fluid.default_main_program(),
feed={"x": data},
fetch_list=[tril_out, triu_out], )
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
if __name__ == '__main__':
unittest.main()
...@@ -32,8 +32,8 @@ from .creation import linspace #DEFINE_ALIAS ...@@ -32,8 +32,8 @@ from .creation import linspace #DEFINE_ALIAS
from .creation import full #DEFINE_ALIAS from .creation import full #DEFINE_ALIAS
# from .creation import linspace #DEFINE_ALIAS # from .creation import linspace #DEFINE_ALIAS
# from .creation import full_like #DEFINE_ALIAS # from .creation import full_like #DEFINE_ALIAS
# from .creation import triu #DEFINE_ALIAS from .creation import triu #DEFINE_ALIAS
# from .creation import tril #DEFINE_ALIAS from .creation import tril #DEFINE_ALIAS
# from .creation import meshgrid #DEFINE_ALIAS # from .creation import meshgrid #DEFINE_ALIAS
# from .stat import mean #DEFINE_ALIAS # from .stat import mean #DEFINE_ALIAS
# from .stat import reduce_mean #DEFINE_ALIAS # from .stat import reduce_mean #DEFINE_ALIAS
......
...@@ -34,8 +34,8 @@ __all__ = [ ...@@ -34,8 +34,8 @@ __all__ = [
# 'eye', # 'eye',
'full', 'full',
# 'full_like', # 'full_like',
# 'triu', 'triu',
# 'tril', 'tril',
# 'meshgrid' # 'meshgrid'
] ]
...@@ -404,3 +404,188 @@ def full(shape, ...@@ -404,3 +404,188 @@ def full(shape,
out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out) out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out)
return out return out
def _tril_triu_op(helper):
"""Base op of tril_op and triu_op
"""
op_type = helper.layer_type
x = helper.kwargs.get('input', None)
assert x is not None, 'x cannot be None in {}'.format(op_type)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
op_type)
if len(x.shape) < 2:
raise ValueError("input shape in {} must be at least 2-D".format(
op_type))
diagonal = helper.kwargs.get('diagonal', 0)
if not isinstance(diagonal, (int, )):
raise TypeError("diagonal in {} must be a python Int".format(op_type))
name = helper.kwargs.get('name', None)
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="tril_triu",
inputs={"X": x},
attrs={
"diagonal": diagonal,
"lower": True if op_type == 'tril' else False,
},
outputs={"Out": out}, )
return out
def tril(input, diagonal=0, name=None):
"""
This op returns the lower triangular part of a matrix (2-D tensor) or batch
of matrices :attr:`input`, the other elements of the result tensor are set
to 0. The lower triangular part of the matrix is defined as the elements
on and below the diagonal.
Args:
input (Variable): The input variable which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main
diagonal, and similarly a negative value excludes just as many diagonals below
the main diagonal. The main diagonal are the set of indices
:math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where
:math:`d_{1}, d_{2}` are the dimensions of the matrix.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor, results of lower triangular operation by the specified diagonal of input tensor,
it's data type is the same as input's Tensor.
Raises:
TypeError: diagonal is not a int type.
ValueError: dimension of :attr:`input` is less than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle.tensor as tensor
import paddle.fluid as fluid
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
# example 1, default diagonal
tril = tensor.tril(x)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 1, 0, 0, 0],
# [ 5, 6, 0, 0],
# [ 9, 10, 11, 0]])
.. code-block:: python
# example 2, positive diagonal value
tril = tensor.tril(x, diagonal=2)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 1, 2, 3, 0],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
.. code-block:: python
# example 3, negative diagonal value
tril = tensor.tril(x, diagonal=-1)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 0, 0, 0, 0],
# [ 5, 0, 0, 0],
# [ 9, 10, 0, 0]])
"""
return _tril_triu_op(LayerHelper('tril', **locals()))
def triu(input, diagonal=0, name=None):
"""
This op returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
:attr:`input`, the other elements of the result tensor are set to 0.
The upper triangular part of the matrix is defined as the elements on and
above the diagonal.
Args:
input (Variable): The input variable which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and above the main diagonal are
retained. A positive value excludes just as many diagonals above the main
diagonal, and similarly a negative value includes just as many diagonals below
the main diagonal. The main diagonal are the set of indices
:math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where
:math:`d_{1}, d_{2}` are the dimensions of the matrix.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor, results of upper triangular operation by the specified diagonal of input tensor,
it's data type is the same as input's Tensor.
Raises:
TypeError: diagonal is not a int type.
ValueError: dimension of :attr:`input` is less than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle.tensor as tensor
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
# example 1, default diagonal
triu = tensor.triu(x)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[ 1, 2, 3, 4],
# [ 0, 6, 7, 8],
# [ 0, 0, 11, 12]])
.. code-block:: python
# example 2, positive diagonal value
triu = tensor.triu(x, diagonal=2)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[0, 0, 3, 4],
# [0, 0, 0, 8],
# [0, 0, 0, 0]])
.. code-block:: python
# example 3, negative diagonal value
triu = tensor.triu(x, diagonal=-1)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 0, 10, 11, 12]])
"""
return _tril_triu_op(LayerHelper('triu', **locals()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册