Unverified commit a06bec12, authored by 武毅, committed by GitHub

Conv cudnn 3d (#5783)

* conv cudnn 3d

* update test case

* update

* update

* follow comments and remove groups from helper

* update

* refine

* update

* follow comments2

* update

* fix compile
Parent 52a73587
@@ -73,6 +73,13 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
   endif()
 
+  # conv_cudnn_op contains several operators
+  if ("${TARGET}" STREQUAL "conv_cudnn_op")
+    set(pybind_flag 1)
+    # It's enough to add just one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(conv2d_cudnn);\n")
+  endif()
+
   # pool_op contains several operators
   if ("${TARGET}" STREQUAL "pool_op")
     set(pybind_flag 1)
...
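A note on the USE_OP(conv2d_cudnn) line this hunk appends: it is what forces the operator's static registration to be linked into the Python extension. A minimal sketch of the extern-symbol trick such a macro typically expands to (the names below are illustrative, not Paddle's exact macro):

// Illustrative sketch only (hypothetical names, not Paddle's exact macro).
// The registering .cc file defines a "touch" function next to its static
// registrar; referencing that function from the pybind side forces the
// linker to keep the whole translation unit, and with it the registration.

// Emitted where REGISTER_OP(conv2d_cudnn, ...) runs:
extern "C" int TouchOpRegistrar_conv2d_cudnn() { return 0; }

// Emitted by USE_OP(conv2d_cudnn) at the use site:
static int use_op_conv2d_cudnn = TouchOpRegistrar_conv2d_cudnn();

int main() { return use_op_conv2d_cudnn; }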
@@ -17,10 +17,10 @@
 namespace paddle {
 namespace operators {
 
-class CudnnConvOpMaker : public Conv2DOpMaker {
+class CudnnConv2DOpMaker : public Conv2DOpMaker {
  public:
-  CudnnConvOpMaker(framework::OpProto* proto,
-                   framework::OpAttrChecker* op_checker)
+  CudnnConv2DOpMaker(framework::OpProto* proto,
+                     framework::OpAttrChecker* op_checker)
       : Conv2DOpMaker(proto, op_checker) {
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
@@ -32,16 +32,43 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
   }
 };
 
+class CudnnConv3DOpMaker : public Conv3DOpMaker {
+ public:
+  CudnnConv3DOpMaker(framework::OpProto* proto,
+                     framework::OpAttrChecker* op_checker)
+      : Conv3DOpMaker(proto, op_checker) {
+    AddAttr<int>("workspace_size_MB",
+                 "workspace size for cudnn, in MB, "
+                 "workspace is a section of GPU memory which will be "
+                 "allocated/freed each time the operator runs, larger "
+                 "workspace size can increase performance but also requires "
+                 "better hardware. This size should be chosen carefully.")
+        .SetDefault(4096);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
-            ops::ConvOpGrad);
+REGISTER_OP(conv2d_cudnn, ops::ConvOp, ops::CudnnConv2DOpMaker,
+            conv2d_cudnn_grad, ops::ConvOpGrad);
+
+REGISTER_OP(conv3d_cudnn, ops::ConvOp, ops::CudnnConv3DOpMaker,
+            conv3d_cudnn_grad, ops::ConvOpGrad);
 
-REGISTER_OP_CPU_KERNEL(conv_cudnn,
+REGISTER_OP_CPU_KERNEL(conv2d_cudnn,
                        ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
                        ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
-    conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
+    conv2d_cudnn_grad,
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
     ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
+
+REGISTER_OP_CPU_KERNEL(conv3d_cudnn,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
+REGISTER_OP_CPU_KERNEL(
+    conv3d_cudnn_grad,
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
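Both makers expose the same workspace_size_MB attribute; it only caps the scratch memory cuDNN may use. Inside the kernels, the attribute is converted to bytes and passed as a limit to cuDNN's algorithm-selection calls. A minimal sketch of that pattern against the cuDNN v5/v6 API, assuming the handle and descriptors are initialized elsewhere (error handling elided):

#include <cudnn.h>

cudnnConvolutionFwdAlgo_t PickForwardAlgo(
    cudnnHandle_t handle, cudnnTensorDescriptor_t x_desc,
    cudnnFilterDescriptor_t w_desc, cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t y_desc, int workspace_size_MB) {
  // Convert the attribute to bytes and let cuDNN pick the fastest
  // algorithm whose workspace fits under the limit.
  size_t limit = static_cast<size_t>(workspace_size_MB) * 1024 * 1024;
  cudnnConvolutionFwdAlgo_t algo;
  cudnnGetConvolutionForwardAlgorithm(
      handle, x_desc, w_desc, conv_desc, y_desc,
      CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, limit, &algo);
  return algo;
}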
@@ -56,6 +56,21 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
     ScopedFilterDescriptor filter_desc;
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
+    if (input->dims().size() == 5) {
+      layout = DataLayout::kNCDHW;
+    }
+
+    cudnnConvolutionDescriptor_t cudnn_conv_desc =
+        conv_desc.descriptor<T>(paddings, strides, dilations);
+
+#if CUDNN_VERSION_MIN(7, 0, 0)
+    // cudnn 7 supports groups natively, no need to handle them manually
+    // FIXME(typhoonzero): find a better way to disable groups
+    // rather than setting it to 1.
+    PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
+        cudnn_conv_desc, groups));
+    groups = 1;
+#endif
 
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()), groups);
@@ -63,19 +78,34 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
         layout, framework::vectorize2int(output->dims()), groups);
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()), groups);
-    cudnnConvolutionDescriptor_t cudnn_conv_desc =
-        conv_desc.descriptor<T>(paddings, strides, dilations);
 
     int input_channels = input->dims()[1];
-    int input_height = input->dims()[2];
-    int input_width = input->dims()[3];
-    int output_channels = output->dims()[1];
-    int output_height = output->dims()[2];
-    int output_width = output->dims()[3];
+    int input_height, input_width, input_depth;
+    if (input->dims().size() == 5) {
+      input_depth = input->dims()[2];
+      input_height = input->dims()[3];
+      input_width = input->dims()[4];
+    } else {  // dim size is enforced in InferShape
+      input_depth = 1;
+      input_height = input->dims()[2];
+      input_width = input->dims()[3];
+    }
+    int output_channels = filter->dims()[0];
+    int output_height, output_width, output_depth;
+    if (output->dims().size() == 5) {
+      output_depth = output->dims()[2];
+      output_height = output->dims()[3];
+      output_width = output->dims()[4];
+    } else {
+      output_depth = 1;
+      output_height = output->dims()[2];
+      output_width = output->dims()[3];
+    }
 
-    int group_offset_in = input_channels / groups * input_height * input_width;
-    int group_offset_out =
-        output_channels / groups * output_height * output_width;
+    int group_offset_in =
+        input_channels / groups * input_height * input_width * input_depth;
+    int group_offset_out =
+        output_channels / groups * output_height * output_width * output_depth;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn conv workspace ---------------------
     void* cudnn_workspace = nullptr;
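Below cuDNN 7 the kernel emulates grouped convolution itself: the descriptors are built with C/groups channels, and the launch loop further down advances the data pointers by one group's worth of elements per iteration. A small self-contained check of the offset arithmetic, using hypothetical NCDHW shapes:

#include <cassert>

int main() {
  // Hypothetical shapes: input 2x8x4x4x4 (NCDHW), filter 16x4x3x3x3,
  // output 2x16x2x2x2, groups = 2.
  const int groups = 2;
  const int input_channels = 8, input_depth = 4, input_height = 4,
            input_width = 4;
  const int output_channels = 16, output_depth = 2, output_height = 2,
            output_width = 2;
  const int filter_numel = 16 * 4 * 3 * 3 * 3;  // 1728

  int group_offset_in =
      input_channels / groups * input_height * input_width * input_depth;
  int group_offset_out =
      output_channels / groups * output_height * output_width * output_depth;
  int group_offset_filter = filter_numel / groups;

  assert(group_offset_in == 256);      // 4 channels * 4*4*4 voxels
  assert(group_offset_out == 64);      // 8 channels * 2*2*2 voxels
  assert(group_offset_filter == 864);  // half of the filter elements
  return 0;
}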
@@ -138,12 +168,26 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn descriptors ---------------------
     ScopedTensorDescriptor input_desc;
     ScopedTensorDescriptor output_grad_desc;
-    ScopedTensorDescriptor input_grad_desc;
     ScopedFilterDescriptor filter_desc;
     ScopedFilterDescriptor filter_grad_desc;
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
+    if (input->dims().size() == 5) {
+      layout = DataLayout::kNCDHW;
+    }
+
+    cudnnConvolutionDescriptor_t cudnn_conv_desc =
+        conv_desc.descriptor<T>(paddings, strides, dilations);
+
+#if CUDNN_VERSION_MIN(7, 0, 0)
+    // cudnn 7 supports groups natively, no need to handle them manually
+    // FIXME(typhoonzero): find a better way to disable groups
+    // rather than setting it to 1.
+    PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
+        cudnn_conv_desc, groups));
+    groups = 1;
+#endif
 
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()), groups);
@@ -152,22 +196,35 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
         layout, framework::vectorize2int(output_grad->dims()), groups);
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr;
-    cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr;
-
-    cudnnConvolutionDescriptor_t cudnn_conv_desc =
-        conv_desc.descriptor<T>(paddings, strides, dilations);
 
     int input_channels = input->dims()[1];
-    int input_height = input->dims()[2];
-    int input_width = input->dims()[3];
+    int input_height, input_width, input_depth;
+    if (input->dims().size() == 5) {
+      input_depth = input->dims()[2];
+      input_height = input->dims()[3];
+      input_width = input->dims()[4];
+    } else {  // dim size is enforced in InferShape
+      input_depth = 1;
+      input_height = input->dims()[2];
+      input_width = input->dims()[3];
+    }
+
     int output_grad_channels = filter->dims()[0];
-    int output_grad_height = output_grad->dims()[2];
-    int output_grad_width = output_grad->dims()[3];
+    int output_grad_height, output_grad_width, output_grad_depth;
+    if (input->dims().size() == 5) {
+      output_grad_depth = output_grad->dims()[2];
+      output_grad_height = output_grad->dims()[3];
+      output_grad_width = output_grad->dims()[4];
+    } else {
+      output_grad_depth = 1;
+      output_grad_height = output_grad->dims()[2];
+      output_grad_width = output_grad->dims()[3];
+    }
 
-    int group_offset_in = input_channels / groups * input_height * input_width;
-    int group_offset_out =
-        output_grad_channels / groups * output_grad_height * output_grad_width;
+    int group_offset_in =
+        input_channels / groups * input_height * input_width * input_depth;
+    int group_offset_out = output_grad_channels / groups * output_grad_height *
+                           output_grad_width * output_grad_depth;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn backward algorithm ---------------------
     cudnnConvolutionBwdDataAlgo_t data_algo;
@@ -180,8 +237,6 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     auto handle = ctx.cuda_device_context().cudnn_handle();
     if (input_grad) {
-      cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
-          layout, framework::vectorize2int(input_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
               handle, cudnn_filter_desc,
@@ -190,19 +245,17 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
               cudnn_output_grad_desc, cudnn_conv_desc,
               // dxDesc: Handle to the previously initialized output tensor
               // descriptor.
-              cudnn_input_grad_desc,
+              cudnn_input_desc,
               CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
               workspace_size_limit, &data_algo));
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
               handle, cudnn_filter_desc, cudnn_output_grad_desc,
-              cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size));
+              cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
       workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
 
     if (filter_grad) {
-      cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>(
-          layout, framework::vectorize2int(filter_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
               handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
@@ -222,7 +275,6 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
-    // FIXME(typhoonzero): template type T may not be the same as cudnn call.
     T alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
@@ -233,21 +285,20 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
             handle, &alpha, cudnn_filter_desc,
             filter_data + i * group_offset_filter, cudnn_output_grad_desc,
             output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
-            cudnn_workspace, workspace_size_in_bytes, &beta,
-            cudnn_input_grad_desc, input_grad_data + i * group_offset_in));
+            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
+            input_grad_data + i * group_offset_in));
       }
     }
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset filter_grad.
       for (int i = 0; i < groups; i++) {
         PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
             handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
             cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
             cudnn_conv_desc, filter_algo, cudnn_workspace,
-            workspace_size_in_bytes, &beta, cudnn_filter_grad_desc,
+            workspace_size_in_bytes, &beta, cudnn_filter_desc,
             filter_grad_data + i * group_offset_filter));
       }
     }
@@ -259,8 +310,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
+REGISTER_OP_GPU_KERNEL(conv2d_cudnn,
+                       paddle::operators::CudnnConvOpKernel<float>,
                        paddle::operators::CudnnConvOpKernel<double>);
-REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
+REGISTER_OP_GPU_KERNEL(conv2d_cudnn_grad,
+                       paddle::operators::CudnnConvGradOpKernel<float>,
+                       paddle::operators::CudnnConvGradOpKernel<double>);
+REGISTER_OP_GPU_KERNEL(conv3d_cudnn,
+                       paddle::operators::CudnnConvOpKernel<float>,
+                       paddle::operators::CudnnConvOpKernel<double>);
+REGISTER_OP_GPU_KERNEL(conv3d_cudnn_grad,
                        paddle::operators::CudnnConvGradOpKernel<float>,
                        paddle::operators::CudnnConvGradOpKernel<double>);
@@ -116,7 +116,7 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat(
     case DataLayout::kNCHW:
       return CUDNN_TENSOR_NCHW;
     case DataLayout::kNCDHW:
-      return CUDNN_TENSOR_NCHW;  // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
+      return CUDNN_TENSOR_NCHW;  // NOTE: cudnn treats Nd tensors all the same
     default:
       PADDLE_THROW("Unknown cudnn equivalent for order");
   }
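Returning CUDNN_TENSOR_NCHW for kNCDHW is harmless because cuDNN's format enum simply means "channels first, fully packed" regardless of rank, and 5-D tensor descriptors are set through cudnnSetTensorNdDescriptor with explicit per-dimension strides anyway. A sketch of the fully packed strides an NCDHW tensor gets, mirroring the stride loop in ScopedTensorDescriptor in the next hunk:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical NCDHW dims: N=2, C=3, D=4, H=5, W=6.
  std::vector<int> dims = {2, 3, 4, 5, 6};
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;  // innermost axis (W) is contiguous
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];
  }
  // Prints 360 120 30 6 1: dims plus strides fully describe the layout,
  // so no dedicated NCDHW format enum is needed.
  for (size_t i = 0; i < strides.size(); ++i) printf("%d ", strides[i]);
  printf("\n");
  return 0;
}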
@@ -143,7 +143,7 @@ class ScopedTensorDescriptor {
       strides[i] = dims[i + 1] * strides[i + 1];
     }
     // Update tensor descriptor dims setting if groups > 1
-    // FIXME(typhoonzero): Assume using NCHW or NCDHW order
+    // NOTE: Assume using NCHW or NCDHW order
     std::vector<int> dims_with_group(dims.begin(), dims.end());  // copy
     if (groups > 1) {
       dims_with_group[1] = dims_with_group[1] / groups;
@@ -186,7 +186,6 @@ class ScopedFilterDescriptor {
     // width of the filter.
     std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
     if (groups > 1) {
-      // M /= groups
       kernel_with_group[0] /= groups;
       // NOTE: input filter(C) of the filter is already asserted to be C/groups.
     }
...
@@ -16,8 +16,8 @@ def conv2d_forward_naive(input, filter, group, conv_param):
     out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1]
     out = np.zeros((in_n, out_c, out_h, out_w))
 
-    d_bolck_w = (dilation[0] * (f_h - 1) + 1)
-    d_bolck_h = (dilation[1] * (f_w - 1) + 1)
+    d_bolck_h = (dilation[0] * (f_h - 1) + 1)
+    d_bolck_w = (dilation[1] * (f_w - 1) + 1)
 
     input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )),
                        mode='constant',
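The naive reference computes output extents with the standard dilated-convolution formula, out = 1 + (in + 2*pad - (dilation*(k-1) + 1)) / stride, where dilation*(k-1) + 1 is the effective kernel extent that the swapped d_bolck_h/d_bolck_w assignments above now compute correctly per axis. A quick self-contained check of the formula:

#include <cassert>

// Output extent of a (possibly dilated) convolution along one axis.
int ConvOutSize(int in, int pad, int dilation, int kernel, int stride) {
  int dilated_kernel = dilation * (kernel - 1) + 1;  // effective extent
  return 1 + (in + 2 * pad - dilated_kernel) / stride;
}

int main() {
  // in=5, pad=1, dilation=2, k=3: effective extent 5, so out = 3.
  assert(ConvOutSize(5, 1, 2, 3, 1) == 3);
  // dilation=1 reduces to the familiar 1 + (in + 2*pad - k) / stride.
  assert(ConvOutSize(7, 0, 1, 3, 2) == 3);
  return 0;
}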
@@ -167,27 +167,27 @@ class TestWithDilation(TestConv2dOp):
 #----------------Conv2dCudnn----------------
 class TestCudnn(TestConv2dOp):
     def init_op_type(self):
-        self.op_type = "conv_cudnn"
+        self.op_type = "conv2d_cudnn"
 
 
 class TestCudnnWithPad(TestWithPad):
     def init_op_type(self):
-        self.op_type = "conv_cudnn"
+        self.op_type = "conv2d_cudnn"
 
 
 class TestCudnnWithStride(TestWithStride):
     def init_op_type(self):
-        self.op_type = "conv_cudnn"
+        self.op_type = "conv2d_cudnn"
 
 
 class TestCudnnWithGroup(TestWithGroup):
     def init_op_type(self):
-        self.op_type = "conv_cudnn"
+        self.op_type = "conv2d_cudnn"
 
 
 class TestCudnnWith1x1(TestWith1x1):
     def init_op_type(self):
-        self.op_type = "conv_cudnn"
+        self.op_type = "conv2d_cudnn"
 
 
 # cudnn v5 does not support dilation conv.
...
@@ -169,5 +169,31 @@ class TestWithDilation(TestConv3dOp):
         self.groups = 3
 
 
+class TestCudnn(TestConv3dOp):
+    def init_op_type(self):
+        self.op_type = "conv3d_cudnn"
+
+
+class TestWithGroup1Cudnn(TestWithGroup1):
+    def init_op_type(self):
+        self.op_type = "conv3d_cudnn"
+
+
+class TestWithGroup2Cudnn(TestWithGroup2):
+    def init_op_type(self):
+        self.op_type = "conv3d_cudnn"
+
+
+class TestWith1x1Cudnn(TestWith1x1):
+    def init_op_type(self):
+        self.op_type = "conv3d_cudnn"
+
+
+# FIXME(typhoonzero): find a way to determine if
+# using cudnn > 6 in python
+# class TestWithDilationCudnn(TestWithDilation):
+#     def init_op_type(self):
+#         self.op_type = "conv3d_cudnn"
+
+
 if __name__ == '__main__':
     unittest.main()
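The commented-out dilation tests trace back to the FIXME above: the Python suite cannot query the cuDNN version, while the C++ kernels gate group and dilation support at compile time. For reference, the CUDNN_VERSION_MIN macro used in the kernels is presumably defined along these lines (cudnn.h encodes the version as major*1000 + minor*100 + patchlevel):

#include <cudnn.h>

// Presumed definition (mirroring paddle/platform/cudnn_helper.h).
#ifndef CUDNN_VERSION_MIN
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#endif

#if CUDNN_VERSION_MIN(7, 0, 0)
// cudnnSetConvolutionGroupCount is available here.
#endif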