提交 f302c6a3 编写于 作者: C chengduoZH

write conv2d and conv3d together

上级 ba7db29d
...@@ -41,8 +41,8 @@ namespace ops = paddle::operators; ...@@ -41,8 +41,8 @@ namespace ops = paddle::operators;
REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(conv_cudnn,
conv_cudnn, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>); ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv_cudnn_grad, conv_cudnn_grad,
ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>); ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
...@@ -198,12 +198,12 @@ namespace ops = paddle::operators; ...@@ -198,12 +198,12 @@ namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d, ops::GemmConv3DKernel<paddle::platform::CPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::CPUPlace, float>);
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv3d, ops::GemmConv3DKernel<paddle::platform::GPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::GPUPlace, float>);
此差异已折叠。
...@@ -61,25 +61,23 @@ class TestConv2dOp(OpTest): ...@@ -61,25 +61,23 @@ class TestConv2dOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
# self.groups = 1
# self.op_type = "conv2d"
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.dilations = [1, 1] self.dilations = [1, 1]
......
...@@ -64,20 +64,20 @@ class TestConv3dOp(OpTest): ...@@ -64,20 +64,20 @@ class TestConv3dOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.03,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.03,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册