diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc index 8fb0b3809503ecc86e33796a4bc7f7cb2d21f8bb..3e943c62e1ce17857e78e140efeb50e627e80a4e 100644 --- a/paddle/fluid/operators/tril_triu_op.cc +++ b/paddle/fluid/operators/tril_triu_op.cc @@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ops::TrilTriuGradOpMaker); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); REGISTER_OP_CPU_KERNEL( - tril_triu, ops::TrilTriuOpKernel, + tril_triu, ops::TrilTriuOpKernel, + ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel); REGISTER_OP_CPU_KERNEL( tril_triu_grad, + ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, diff --git a/paddle/fluid/operators/tril_triu_op.cu b/paddle/fluid/operators/tril_triu_op.cu index d04acd340597928ba0fbbbebf2dfc7eda1d698ac..9cbbdeeb2ce28453f2c22d063975fa82aae5d3b3 100644 --- a/paddle/fluid/operators/tril_triu_op.cu +++ b/paddle/fluid/operators/tril_triu_op.cu @@ -18,7 +18,7 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - tril_triu, + tril_triu, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, @@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL( ops::TrilTriuOpKernel); REGISTER_OP_CUDA_KERNEL( tril_triu_grad, + ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 7f37ab488f6bd91dae3cdfc05c8d7207dede5281..51a5e8df0fc2f2f6f0853148d3ff9cc3a347e1d6 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -585,7 +585,7 @@ def tril(x, diagonal=0, name=None): Args: x (Tensor): The input x which is a Tensor. - Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``. diagonal (int, optional): The diagonal to consider, default value is 0. If :attr:`diagonal` = 0, all elements on and below the main diagonal are retained. A positive value includes just as many diagonals above the main