From 6fb4bb8efea3c21ef33b8568069c1cbc2a38a381 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 13 Nov 2017 17:58:44 +0800 Subject: [PATCH] add conv3d_trans_cudnn_op unit test --- paddle/operators/conv_transpose_cudnn_op.cc | 19 ++++++++++++++++++- .../tests/test_conv3d_transpose_op.py | 6 ++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/paddle/operators/conv_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc index 7ec3319cd0c..dbd1bc3c3bc 100644 --- a/paddle/operators/conv_transpose_cudnn_op.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cc @@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { framework::OpAttrChecker* op_checker) : Conv2DTransposeOpMaker(proto, op_checker) { AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault(std::vector{1, 1}); + .SetDefault({1, 1}); + AddAttr("workspace_size_MB", + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardward. This size should be carefully setted.") + .SetDefault(4096); + } +}; + +class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker { + public: + CudnnConv3DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : Conv3DTransposeOpMaker(proto, op_checker) { + AddAttr>("dilations", "dilations of convolution operator.") + .SetDefault({1, 1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " diff --git a/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py b/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py index 132fe793143..73ee260c5ab 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py @@ -93,5 +93,11 @@ class TestConv3dTransposeOp(OpTest): self.op_type = "conv3d_transpose" +# ------------ test_cudnn ------------ +class TestCudnn(TestConv3dTransposeOp): + def init_op_type(self): + self.op_type = "conv3d_transpose_cudnn" + + if __name__ == '__main__': unittest.main() -- GitLab