From 1187c610fcdb383d861b12a7d350dbd4d72d58db Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Fri, 28 May 2021 17:39:06 +0800 Subject: [PATCH] modify to complex template types for fill_constant op (#33179) * modify to complex template types for fill_constant op * modify to complex template types for py_layer, strided_slice and reduce_sum_op.part --- paddle/fluid/operators/fill_constant_op.cc | 19 +++++++++---------- paddle/fluid/operators/fill_constant_op.cu.cc | 17 ++++++++--------- .../fluid/operators/fill_constant_op_xpu.cc | 13 ++++++------- paddle/fluid/operators/py_layer_op.cc | 8 ++++---- .../reduce_ops/reduce_sum_op.part.cu | 13 ++++++------- paddle/fluid/operators/strided_slice_op.cc | 8 ++++---- paddle/fluid/operators/strided_slice_op.cu | 11 +++++------ 7 files changed, 42 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index f35d8b6bbf8..d465e77ea18 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -147,16 +147,15 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_CPU_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); REGISTER_OP_VERSION(fill_constant) .AddCheckpoint( diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index e784c20b8b8..a862cda1388 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -15,12 +15,11 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_constant_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_CUDA_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); diff --git a/paddle/fluid/operators/fill_constant_op_xpu.cc b/paddle/fluid/operators/fill_constant_op_xpu.cc index 16dd4c9292f..d55b8e2b81b 100644 --- a/paddle/fluid/operators/fill_constant_op_xpu.cc +++ b/paddle/fluid/operators/fill_constant_op_xpu.cc @@ -15,11 +15,10 @@ limitations under the License. */ namespace ops = paddle::operators; #ifdef PADDLE_WITH_XPU -REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_XPU_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); #endif diff --git a/paddle/fluid/operators/py_layer_op.cc b/paddle/fluid/operators/py_layer_op.cc index f91496eeab1..c2f68675beb 100644 --- a/paddle/fluid/operators/py_layer_op.cc +++ b/paddle/fluid/operators/py_layer_op.cc @@ -199,9 +199,9 @@ REGISTER_OP_CPU_KERNEL( ops::PyLayerOpKernel, ops::PyLayerOpKernel, ops::PyLayerOpKernel, + ::paddle::platform::complex>, ops::PyLayerOpKernel); + ::paddle::platform::complex>); #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL( py_layer, ops::PyLayerOpKernel, @@ -218,7 +218,7 @@ REGISTER_OP_CUDA_KERNEL( ops::PyLayerOpKernel, ops::PyLayerOpKernel, ops::PyLayerOpKernel, + ::paddle::platform::complex>, ops::PyLayerOpKernel); + ::paddle::platform::complex>); #endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 67de8bb9a0c..230bae0cdd4 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -20,10 +20,9 @@ using CUDAReduceSumGradKernel = ops::ReduceGradKernel; -REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_sum_grad, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel>, + CUDAReduceSumGradKernel>); diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index e49476e4dc7..d71be60e1f5 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -329,9 +329,9 @@ REGISTER_OP_CPU_KERNEL( ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, + paddle::platform::complex>, ops::StridedSliceKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( strided_slice_grad, @@ -340,6 +340,6 @@ REGISTER_OP_CPU_KERNEL( ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, + paddle::platform::complex>, ops::StridedSliceGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/strided_slice_op.cu b/paddle/fluid/operators/strided_slice_op.cu index b85403b1c5b..68a8312f081 100644 --- a/paddle/fluid/operators/strided_slice_op.cu +++ b/paddle/fluid/operators/strided_slice_op.cu @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/strided_slice_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( @@ -24,9 +23,9 @@ REGISTER_OP_CUDA_KERNEL( ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, + paddle::platform::complex>, ops::StridedSliceKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( strided_slice_grad, @@ -35,6 +34,6 @@ REGISTER_OP_CUDA_KERNEL( ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, + paddle::platform::complex>, ops::StridedSliceGradKernel); + paddle::platform::complex>); -- GitLab