未验证 提交 1187c610 编写于 作者: C chentianyu03 提交者: GitHub

modify to complex template types for fill_constant op (#33179)

* modify to complex template types for fill_constant op

* modify to complex template types for py_layer, strided_slice and reduce_sum_op.part
上级 cf08bab2
......@@ -147,16 +147,15 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex64>,
ops::FillConstantKernel<paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
REGISTER_OP_VERSION(fill_constant)
.AddCheckpoint(
......
......@@ -15,12 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::complex64>,
ops::FillConstantKernel<paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
......@@ -15,11 +15,10 @@ limitations under the License. */
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<paddle::platform::complex64>,
ops::FillConstantKernel<paddle::platform::complex128>);
REGISTER_OP_XPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
#endif
......@@ -199,9 +199,9 @@ REGISTER_OP_CPU_KERNEL(
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex64>,
::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex128>);
::paddle::platform::complex<double>>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
py_layer, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -218,7 +218,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex64>,
::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex128>);
::paddle::platform::complex<double>>);
#endif // PADDLE_WITH_CUDA
......@@ -20,10 +20,9 @@ using CUDAReduceSumGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::SumGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<float>,
CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<int>,
CUDAReduceSumGradKernel<int64_t>,
CUDAReduceSumGradKernel<paddle::platform::complex64>,
CUDAReduceSumGradKernel<paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
......@@ -329,9 +329,9 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
......@@ -340,6 +340,6 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......@@ -24,9 +23,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
......@@ -35,6 +34,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册