From 741811e02347c814ad95eff7552af7e3edc02c46 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Wed, 9 Jun 2021 14:30:53 +0800 Subject: [PATCH] add bool type for tril api (#33402) --- paddle/fluid/operators/tril_triu_op.cc | 4 +++- paddle/fluid/operators/tril_triu_op.cu | 3 ++- python/paddle/tensor/creation.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc index 8fb0b380950..3e943c62e1c 100644 --- a/paddle/fluid/operators/tril_triu_op.cc +++ b/paddle/fluid/operators/tril_triu_op.cc @@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ops::TrilTriuGradOpMaker); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); REGISTER_OP_CPU_KERNEL( - tril_triu, ops::TrilTriuOpKernel, + tril_triu, ops::TrilTriuOpKernel, + ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel); REGISTER_OP_CPU_KERNEL( tril_triu_grad, + ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, diff --git a/paddle/fluid/operators/tril_triu_op.cu b/paddle/fluid/operators/tril_triu_op.cu index d04acd34059..9cbbdeeb2ce 100644 --- a/paddle/fluid/operators/tril_triu_op.cu +++ b/paddle/fluid/operators/tril_triu_op.cu @@ -18,7 +18,7 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - tril_triu, + tril_triu, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, @@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL( ops::TrilTriuOpKernel); REGISTER_OP_CUDA_KERNEL( tril_triu_grad, + ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 7f37ab488f6..51a5e8df0fc 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -585,7 +585,7 @@ def tril(x, diagonal=0, name=None): Args: x (Tensor): The input x which is a Tensor. - Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``. diagonal (int, optional): The diagonal to consider, default value is 0. If :attr:`diagonal` = 0, all elements on and below the main diagonal are retained. A positive value includes just as many diagonals above the main -- GitLab