提交 3bd3cc0c 作者: chengduoZH

Add double type support for the conv, conv-transpose, and pooling kernels (CPU and GPU)

上级 dec61ab6
...@@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, ...@@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv_cudnn, REGISTER_OP_CPU_KERNEL(conv_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>); ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv_cudnn_grad, conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
...@@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>); REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv_cudnn_grad, REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>); paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
...@@ -61,10 +61,12 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp, ...@@ -61,10 +61,12 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn, conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad, conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad, ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
...@@ -72,7 +74,9 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, ...@@ -72,7 +74,9 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn, conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad, conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
...@@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>); ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>); ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn, REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>); ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad, REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>); ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
...@@ -20,14 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad, ...@@ -20,14 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
ops::PoolOpGrad); ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn, REGISTER_OP_CPU_KERNEL(pool2d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>); ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad, REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>) ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad, REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad,
ops::PoolOpGrad); ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn, REGISTER_OP_CPU_KERNEL(pool3d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>); ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad, REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>) ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
...@@ -162,8 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -162,8 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>); REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>); ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>); ops::PoolCudnnGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
...@@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad, ...@@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
ops::PoolOpGrad); ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d, REGISTER_OP_CPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::CPUPlace, float>); ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_grad, REGISTER_OP_CPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>) ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad, REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
ops::PoolOpGrad); ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d, REGISTER_OP_CPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::CPUPlace, float>); ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_grad, REGISTER_OP_CPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>); ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>);
...@@ -17,11 +17,15 @@ limitations under the License. */ ...@@ -17,11 +17,15 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d, REGISTER_OP_GPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::GPUPlace, float>); ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool2d_grad, REGISTER_OP_GPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>); ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d, REGISTER_OP_GPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::GPUPlace, float>); ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d_grad, REGISTER_OP_GPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>); ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
...@@ -250,10 +250,12 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, ...@@ -250,10 +250,12 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index, max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
...@@ -261,7 +263,9 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ...@@ -261,7 +263,9 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index, max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double>)
...@@ -18,14 +18,18 @@ namespace ops = paddle::operators; ...@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index, max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double>)
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index, max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double>)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册