Commit b720f282 authored by zchen0211

deconv modify

Parent 4e228021
@@ -38,13 +38,13 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(conv2dtranspose_cudnn, ops::Conv2DTransposeOp,
-            ops::CudnnConv2DTransposeOpMaker, conv2dtranspose_cudnn_grad,
+REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp,
+            ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad,
             ops::Conv2DTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    conv2dtranspose_cudnn,
+    conv2d_transpose_cudnn,
     ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    conv2dtranspose_cudnn_grad,
+    conv2d_transpose_cudnn_grad,
     ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
@@ -15,7 +15,7 @@
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memory.h"
-#include "paddle/operators/conv2d_op.h"
+#include "paddle/operators/conv2dtranspose_op.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cudnn_helper.h"
@@ -76,7 +76,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
       workspace_size_limit = user_workspace_size * 1024 * 1024;
     }
     // ------------------- cudnn conv algorithm ---------------------
-    // cudnnConvolutionBwdAlgo_t algo;
     cudnnConvolutionBwdDataAlgo_t algo;
     auto handle = ctx.cuda_device_context().cudnn_handle();
     // Get the algorithm
@@ -92,7 +91,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
         platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
             handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
             cudnn_output_desc, algo, &workspace_size_in_bytes));
-    // workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     // Allocate on GPU memory
     platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
@@ -234,7 +232,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn,
+REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
                        ops::CudnnConvTransposeOpKernel<float>);
-REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad,
+REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
                        ops::CudnnConvTransposeGradOpKernel<float>);
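
Editor's note: the kernels re-registered above implement the transposed-convolution ("deconv") forward pass via cuDNN's backward-data path, which is what the algorithm-selection and workspace-size hunks earlier are querying. The following standalone sketch shows that call sequence. It is not the PaddlePaddle kernel itself: the function name and shape handling are illustrative, error checking is elided, and the cudnnSetConvolution2dDescriptor call uses the cuDNN 6+ nine-argument signature.

// A minimal sketch of the cuDNN call sequence the kernel follows.
// Assumptions (not from the diff): cuDNN 6+ signatures, float32 NCHW
// tensors, equal stride/padding in both dimensions, dilation 1, and
// device buffers already allocated; error checking is elided.
#include <cudnn.h>

void Conv2DTransposeForward(cudnnHandle_t handle, const float* input,
                            const float* filter, float* output,
                            void* workspace, size_t workspace_size_limit,
                            int n, int c_in, int h_in, int w_in,
                            int c_out, int kh, int kw, int pad, int stride) {
  cudnnTensorDescriptor_t input_desc, output_desc;
  cudnnFilterDescriptor_t filter_desc;
  cudnnConvolutionDescriptor_t conv_desc;
  cudnnCreateTensorDescriptor(&input_desc);
  cudnnCreateTensorDescriptor(&output_desc);
  cudnnCreateFilterDescriptor(&filter_desc);
  cudnnCreateConvolutionDescriptor(&conv_desc);

  // Transposed convolution runs the data-gradient of an ordinary
  // convolution, so the op's input plays the role of dy and its
  // (larger) output the role of dx.
  int h_out = (h_in - 1) * stride - 2 * pad + kh;
  int w_out = (w_in - 1) * stride - 2 * pad + kw;
  cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                             n, c_in, h_in, w_in);
  cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                             n, c_out, h_out, w_out);
  // Filter is laid out (c_in, c_out, kh, kw), matching the op's
  // (in_channels, out_channels, ...) filter shape.
  cudnnSetFilter4dDescriptor(filter_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
                             c_in, c_out, kh, kw);
  cudnnSetConvolution2dDescriptor(conv_desc, pad, pad, stride, stride,
                                  /*dilation_h=*/1, /*dilation_w=*/1,
                                  CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);

  // Pick an algorithm under the workspace limit, then query how much
  // scratch memory it needs: the same two steps as the hunks above.
  cudnnConvolutionBwdDataAlgo_t algo;
  cudnnGetConvolutionBackwardDataAlgorithm(
      handle, filter_desc, input_desc, conv_desc, output_desc,
      CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
      workspace_size_limit, &algo);
  size_t workspace_size_in_bytes = 0;
  cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, filter_desc, input_desc, conv_desc, output_desc, algo,
      &workspace_size_in_bytes);

  // Run the backward-data pass to produce the transposed-conv output.
  const float alpha = 1.0f, beta = 0.0f;
  cudnnConvolutionBackwardData(handle, &alpha, filter_desc, filter,
                               input_desc, input, conv_desc, algo, workspace,
                               workspace_size_in_bytes, &beta, output_desc,
                               output);

  cudnnDestroyConvolutionDescriptor(conv_desc);
  cudnnDestroyFilterDescriptor(filter_desc);
  cudnnDestroyTensorDescriptor(output_desc);
  cudnnDestroyTensorDescriptor(input_desc);
}

Note the descriptor order in the workspace query (filter first, then the op's input in the dy position); it matches the cudnnGetConvolutionBackwardDataWorkspaceSize call in the kernel hunk above.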
@@ -45,13 +45,12 @@ class TestConv2dTransposeOp(OpTest):
         filter_ = np.random.random(self.filter_size).astype("float32")
         output = conv2dtranspose_forward_naive(
             input_, filter_, conv2dtranspose_param).astype('float32')
-        # print 'deconv output py', output, output.shape
 
         self.inputs = {'Input': input_, 'Filter': filter_}
 
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
-            # 'dilations': self.dilations
+            'dilations': self.dilations
         }
         self.outputs = {'Output': output}
@@ -91,7 +90,7 @@ class TestConv2dTransposeOp(OpTest):
 class TestCudnn(TestConv2dTransposeOp):
     def init_op_type(self):
-        self.op_type = "conv2dtranspose_cudnn"
+        self.op_type = "conv2d_transpose_cudnn"
 
 
 if __name__ == '__main__':