未验证 提交 741811e0 编写于 作者: L liuyuhui 提交者: GitHub

add bool type for tril api (#33402)

上级 e08fdd16
...@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ...@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>); ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>, tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>); ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
tril_triu_grad, tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
......
...@@ -18,7 +18,7 @@ namespace ops = paddle::operators; ...@@ -18,7 +18,7 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
tril_triu, tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
...@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>); ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
tril_triu_grad, tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
......
...@@ -585,7 +585,7 @@ def tril(x, diagonal=0, name=None): ...@@ -585,7 +585,7 @@ def tril(x, diagonal=0, name=None):
Args: Args:
x (Tensor): The input x which is a Tensor. x (Tensor): The input x which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``. Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0. diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and below the main diagonal are If :attr:`diagonal` = 0, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main retained. A positive value includes just as many diagonals above the main
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册