未验证 提交 741811e0 编写于 作者: L liuyuhui 提交者: GitHub

add bool type for tril api (#33402)

上级 e08fdd16
......@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
REGISTER_OP_CPU_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
......
......@@ -18,7 +18,7 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
tril_triu,
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
......@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
......
......@@ -585,7 +585,7 @@ def tril(x, diagonal=0, name=None):
Args:
x (Tensor): The input x which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册