提交 6e13c86d 编写于 作者: Y Yibing Liu

Enable multiple groups for cudnn conv transpose

上级 669c0df6
......@@ -44,6 +44,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const T* input_data = input->data<T>();
......@@ -64,13 +65,13 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
layout, framework::vectorize2int(input->dims()), groups);
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
layout, framework::vectorize2int(output->dims()), groups);
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
layout, framework::vectorize2int(filter->dims()), groups);
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
......@@ -104,11 +105,17 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv transpose forward ---------------------
int input_offset = input->numel() / input->dims()[0] / groups;
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc,
input_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
......@@ -134,6 +141,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
// ------------------- cudnn descriptors ---------------------
......@@ -145,13 +153,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
layout, framework::vectorize2int(input->dims()), groups);
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
layout, framework::vectorize2int(output_grad->dims()), groups);
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
layout, framework::vectorize2int(filter->dims()), groups);
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
......@@ -205,15 +213,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int input_offset = input->numel() / input->dims()[0] / groups;
int output_grad_offset =
output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data));
for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
}
}
// ------------------- cudnn conv backward filter ---------------------
......@@ -221,11 +236,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
input_data, cudnn_conv_desc, filter_algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data));
for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + filter_offset * g));
}
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
......
......@@ -227,6 +227,21 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv2d_transpose"
class TestCUDNNWithGroups(TestWithGroups):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册