Unverified · Commit 9567cbd7 authored by liuyuhui, committed by GitHub

[cherry-pick 2.1.1]2.1/fix concat (#33383)

* add uint8 for concat (#32850)

* add bool type for tril api (#33402)
Parent 14440905
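In user-facing terms, the patch widens dtype coverage: paddle.concat gains uint8 kernels, and tril/triu (along with reduce_mean and reduce_sum) gain bool kernels on both CPU and CUDA. A minimal sketch of what this enables, assuming a build that includes this patch (2.1.1 or later); the tensor values are illustrative only:

    import numpy as np
    import paddle

    # uint8 concat: previously no uint8 kernel was registered
    a = paddle.to_tensor(np.array([[1, 2]], dtype='uint8'))
    b = paddle.to_tensor(np.array([[3, 4]], dtype='uint8'))
    print(paddle.concat([a, b], axis=0))  # result keeps dtype uint8

    # bool tril: previously bool inputs were rejected
    m = paddle.to_tensor(np.ones((3, 3), dtype='bool'))
    print(paddle.tril(m))                 # lower triangle of a bool matrix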
@@ -233,7 +233,8 @@ REGISTER_OP_CPU_KERNEL(
     ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ConcatKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::float16>,
-    ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
 REGISTER_OP_CPU_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
@@ -242,4 +243,5 @@ REGISTER_OP_CPU_KERNEL(
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::float16>,
-    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>);
@@ -23,7 +23,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
 REGISTER_OP_CUDA_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
@@ -31,4 +32,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);
+    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>);
@@ -100,6 +100,8 @@ REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
                   ops::ReduceMeanDoubleGradOpBaseMaker,
                   ops::ReduceMeanGradNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(reduce_mean,
+                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
+                                         bool, ops::MeanFunctor>,
                        ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                                          float, ops::MeanFunctor>,
                        ops::ReduceKernel<paddle::platform::CPUDeviceContext,
@@ -110,5 +112,6 @@ using CPUReduceMeanGradKernel =
     ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
                           ops::MeanGradFunctor, true>;
-REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
+REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<bool>,
+                       CPUReduceMeanGradKernel<float>,
                        CPUReduceMeanGradKernel<double>);
@@ -65,5 +65,6 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
+REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
+                        ops::ReduceMeanKernel<float>,
                         ops::ReduceMeanKernel<double>);
@@ -20,5 +20,6 @@ using CUDAReduceMeanGradKernel =
     ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                           ops::MeanGradFunctor, true>;
-REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>,
+REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
+                        CUDAReduceMeanGradKernel<float>,
                         CUDAReduceMeanGradKernel<double>);
@@ -109,7 +109,9 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                   ops::ReduceSumGradNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(
-    reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
+    reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, bool,
+                                  ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
                                   ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
                       ops::SumFunctor>,
@@ -128,7 +130,8 @@ using CPUReduceSumGradKernel =
     ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
                              ops::SumGradFunctor, true>;
-REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
+REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<bool>,
+                       CPUReduceSumGradKernel<float>,
                        CPUReduceSumGradKernel<double>,
                        CPUReduceSumGradKernel<int>,
                        CPUReduceSumGradKernel<int64_t>,
...
@@ -70,7 +70,8 @@ class ReduceSumKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
+REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<bool>,
+                        ops::ReduceSumKernel<float>,
                         ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
                         ops::ReduceSumKernel<int64_t>,
                         ops::ReduceSumKernel<paddle::platform::complex64>,
...
@@ -20,7 +20,8 @@ using CUDAReduceSumGradKernel =
     ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                           ops::SumGradFunctor, true>;
-REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>,
+REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
+                        CUDAReduceSumGradKernel<float>,
                         CUDAReduceSumGradKernel<double>,
                         CUDAReduceSumGradKernel<int>,
                         CUDAReduceSumGradKernel<int64_t>,
...
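The four reduce hunks above register bool forward and gradient kernels for reduce_mean and reduce_sum on CPU and CUDA. A hedged usage sketch (it assumes the Python layer in this release passes bool inputs through to these kernels, as it does in dygraph mode where the static-graph dtype checks are bypassed):

    import paddle

    mask = paddle.to_tensor([[True, False], [True, True]])
    print(paddle.sum(mask))   # sum over a bool tensor now has a registered kernel
    print(paddle.mean(mask))  # likewise for reduce_mean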
@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
                   ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
 REGISTER_OP_CPU_KERNEL(
-    tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
+    tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
 REGISTER_OP_CPU_KERNEL(
     tril_triu_grad,
+    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
...
@@ -18,7 +18,7 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    tril_triu,
+    tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     tril_triu_grad,
+    ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
...
@@ -576,7 +576,7 @@ def tril(x, diagonal=0, name=None):
     Args:
         x (Tensor): The input x which is a Tensor.
-            Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
+            Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``.
         diagonal (int, optional): The diagonal to consider, default value is 0.
             If :attr:`diagonal` = 0, all elements on and below the main diagonal are
             retained. A positive value includes just as many diagonals above the main
...
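With bool added to the docstring above, tril accepts boolean matrices directly. A small hedged example (values illustrative; diagonal behaves as documented above):

    import paddle

    adj = paddle.to_tensor([[True, True, True],
                            [True, True, True],
                            [True, True, True]])
    lower = paddle.tril(adj)                # main diagonal and below
    strict = paddle.tril(adj, diagonal=-1)  # strictly below the diagonal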
@@ -80,7 +80,7 @@ def concat(x, axis=0, name=None):
     Args:
         x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
-            float32, float64, int32, int64. All the Tensors in ``x`` must have same data type.
+            float32, float64, int32, int64, uint8. All the Tensors in ``x`` must have same data type.
         axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
             It's a scalar with data type int or a Tensor with shape [1] and data type int32
             or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
...
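And the matching uint8 case for concat: per the Args above, all tensors in x must share a dtype, and a negative axis counts from the end (effective range [-R, R)):

    import numpy as np
    import paddle

    imgs = [paddle.to_tensor(np.zeros((2, 2), dtype='uint8')),
            paddle.to_tensor(np.ones((2, 2), dtype='uint8'))]
    out0 = paddle.concat(imgs, axis=0)   # shape [4, 2], dtype uint8
    out1 = paddle.concat(imgs, axis=-1)  # same as axis=1 for rank-2 inputs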