未验证 提交 1187c610 编写于 作者: C chentianyu03 提交者: GitHub

modify to complex template types for fill_constant op (#33179)

* modify to complex template types for fill_constant op

* modify to complex template types for py_layer, strided_slice and reduce_sum_op.part
上级 cf08bab2
...@@ -147,16 +147,15 @@ REGISTER_OPERATOR( ...@@ -147,16 +147,15 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_CPU_KERNEL(
ops::FillConstantKernel<double>, fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<uint8_t>, ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>, ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>, ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex64>, ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex128>); ops::FillConstantKernel<paddle::platform::complex<double>>);
REGISTER_OP_VERSION(fill_constant) REGISTER_OP_VERSION(fill_constant)
.AddCheckpoint( .AddCheckpoint(
......
...@@ -15,12 +15,11 @@ limitations under the License. */ ...@@ -15,12 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/fill_constant_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_CUDA_KERNEL(
ops::FillConstantKernel<double>, fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<uint8_t>, ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>, ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::complex64>, ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex128>); ops::FillConstantKernel<paddle::platform::complex<double>>);
...@@ -15,11 +15,10 @@ limitations under the License. */ ...@@ -15,11 +15,10 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_XPU_KERNEL(
ops::FillConstantKernel<int64_t>, fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<bool>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int>, ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex64>, ops::FillConstantKernel<paddle::platform::complex<double>>);
ops::FillConstantKernel<paddle::platform::complex128>);
#endif #endif
...@@ -199,9 +199,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -199,9 +199,9 @@ REGISTER_OP_CPU_KERNEL(
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int16_t>, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex64>, ::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex128>); ::paddle::platform::complex<double>>);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
py_layer, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, float>, py_layer, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -218,7 +218,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -218,7 +218,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int16_t>, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex64>, ::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex128>); ::paddle::platform::complex<double>>);
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
...@@ -20,10 +20,9 @@ using CUDAReduceSumGradKernel = ...@@ -20,10 +20,9 @@ using CUDAReduceSumGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::SumGradFunctor, true>; ops::SumGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>, REGISTER_OP_CUDA_KERNEL(
CUDAReduceSumGradKernel<float>, reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<double>, CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
CUDAReduceSumGradKernel<int64_t>, CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
CUDAReduceSumGradKernel<paddle::platform::complex64>, CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
CUDAReduceSumGradKernel<paddle::platform::complex128>);
...@@ -329,9 +329,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -329,9 +329,9 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>, ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>, ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>, paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>); paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
strided_slice_grad, strided_slice_grad,
...@@ -340,6 +340,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -340,6 +340,6 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>, ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>, ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>, paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>); paddle::platform::complex<double>>);
...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h" #include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex64.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
...@@ -24,9 +23,9 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -24,9 +23,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>, ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>, ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>, paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>); paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
strided_slice_grad, strided_slice_grad,
...@@ -35,6 +34,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -35,6 +34,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>, ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>, ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>, paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>); paddle::platform::complex<double>>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册