未验证 提交 9567cbd7 编写于 作者: L liuyuhui 提交者: GitHub

[cherry-pick 2.1.1]2.1/fix concat (#33383)

* add unit8 for concat (#32850)

* add bool type for tril api (#33402)
上级 14440905
...@@ -233,7 +233,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -233,7 +233,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>); ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -242,4 +243,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -242,4 +243,5 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>); ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
...@@ -23,7 +23,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -23,7 +23,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>); ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
...@@ -31,4 +32,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -31,4 +32,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>); ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
...@@ -100,6 +100,8 @@ REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp, ...@@ -100,6 +100,8 @@ REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradOpBaseMaker, ops::ReduceMeanDoubleGradOpBaseMaker,
ops::ReduceMeanGradNoNeedBufferVarInferer); ops::ReduceMeanGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(reduce_mean, REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
bool, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>, float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
...@@ -110,5 +112,6 @@ using CPUReduceMeanGradKernel = ...@@ -110,5 +112,6 @@ using CPUReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
ops::MeanGradFunctor, true>; ops::MeanGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>, REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<bool>,
CPUReduceMeanGradKernel<float>,
CPUReduceMeanGradKernel<double>); CPUReduceMeanGradKernel<double>);
...@@ -65,5 +65,6 @@ class ReduceMeanKernel : public framework::OpKernel<T> { ...@@ -65,5 +65,6 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
ops::ReduceMeanKernel<float>,
ops::ReduceMeanKernel<double>); ops::ReduceMeanKernel<double>);
...@@ -20,5 +20,6 @@ using CUDAReduceMeanGradKernel = ...@@ -20,5 +20,6 @@ using CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::MeanGradFunctor, true>; ops::MeanGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>); CUDAReduceMeanGradKernel<double>);
...@@ -109,7 +109,9 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, ...@@ -109,7 +109,9 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumGradNoNeedBufferVarInferer); ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float, reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, bool,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::SumFunctor>, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double, ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::SumFunctor>, ops::SumFunctor>,
...@@ -128,7 +130,8 @@ using CPUReduceSumGradKernel = ...@@ -128,7 +130,8 @@ using CPUReduceSumGradKernel =
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T, ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
ops::SumGradFunctor, true>; ops::SumGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>, REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<bool>,
CPUReduceSumGradKernel<float>,
CPUReduceSumGradKernel<double>, CPUReduceSumGradKernel<double>,
CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int>,
CPUReduceSumGradKernel<int64_t>, CPUReduceSumGradKernel<int64_t>,
......
...@@ -70,7 +70,8 @@ class ReduceSumKernel : public framework::OpKernel<T> { ...@@ -70,7 +70,8 @@ class ReduceSumKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<bool>,
ops::ReduceSumKernel<float>,
ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>, ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
ops::ReduceSumKernel<int64_t>, ops::ReduceSumKernel<int64_t>,
ops::ReduceSumKernel<paddle::platform::complex64>, ops::ReduceSumKernel<paddle::platform::complex64>,
......
...@@ -20,7 +20,8 @@ using CUDAReduceSumGradKernel = ...@@ -20,7 +20,8 @@ using CUDAReduceSumGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::SumGradFunctor, true>; ops::SumGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<float>,
CUDAReduceSumGradKernel<double>, CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int>,
CUDAReduceSumGradKernel<int64_t>, CUDAReduceSumGradKernel<int64_t>,
......
...@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ...@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>); ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>, tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>); ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
tril_triu_grad, tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>, ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
......
...@@ -18,7 +18,7 @@ namespace ops = paddle::operators; ...@@ -18,7 +18,7 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
tril_triu, tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
...@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -26,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>); ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
tril_triu_grad, tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>, ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
......
...@@ -576,7 +576,7 @@ def tril(x, diagonal=0, name=None): ...@@ -576,7 +576,7 @@ def tril(x, diagonal=0, name=None):
Args: Args:
x (Tensor): The input x which is a Tensor. x (Tensor): The input x which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``. Support data types: ``bool``, ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0. diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and below the main diagonal are If :attr:`diagonal` = 0, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main retained. A positive value includes just as many diagonals above the main
......
...@@ -80,7 +80,7 @@ def concat(x, axis=0, name=None): ...@@ -80,7 +80,7 @@ def concat(x, axis=0, name=None):
Args: Args:
x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16, x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
float32, float64, int32, int64. All the Tensors in ``x`` must have same data type. float32, float64, int32, int64, uint8. All the Tensors in ``x`` must have same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors. axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 It's a scalar with data type int or a Tensor with shape [1] and data type int32
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册