未验证 提交 6d649d9e 编写于 作者: Z Zhuoyuan 提交者: GitHub

Merge pull request #5235 from zchen0211/develop

deconv cudnn forward passed
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2d_transpose_op.h"
namespace paddle {
namespace operators {
class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
public:
CudnnConv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp,
ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad,
ops::Conv2DTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/conv2d_transpose_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext;
static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;
template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
// N, M, H, W
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// N, C, O_h, O_w
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// M, C, K_h, K_w
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle();
// Get the algorithm
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
// get workspace size able to allocate
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv transpose forward ---------------------
T alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc,
input_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
template <typename T>
class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const T* input_data = input->data<T>();
const T* output_grad_data = output_grad->data<T>();
const T* filter_data = filter->data<T>();
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
// Input: (N, M, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_H, O_W)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter (M, C, K_H, K_W)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
// ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionFwdAlgo_t data_algo;
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t bwd_filter_ws_size, fwd_ws_size;
size_t workspace_size_in_bytes = 0;
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
auto handle = ctx.cuda_device_context().cudnn_handle();
if (input_grad) {
// choose backward algorithm for data
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, data_algo, &fwd_ws_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
}
if (filter_grad) {
// choose backward algorithm for filter
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
// get workspace for backwards filter algorithm
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc, filter_algo, &bwd_filter_ws_size));
workspace_size_in_bytes =
std::max(workspace_size_in_bytes, bwd_filter_ws_size);
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data));
}
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
input_data, cudnn_conv_desc, filter_algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data));
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
......@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
#include "paddle/operators/conv2d_transpose_op.h"
namespace paddle {
namespace operators {
......@@ -95,13 +95,13 @@ void Conv2DTransposeOpGrad::InferShape(
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
ops::Conv2DTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose,
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose_grad,
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
......@@ -12,13 +12,13 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
#include "paddle/operators/conv2d_transpose_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2dtranspose,
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2dtranspose_grad,
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
......@@ -62,7 +62,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
// groups will alway be disabled in conv2d_transpose.
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
......
......@@ -45,23 +45,36 @@ class TestConv2dTransposeOp(OpTest):
filter_ = np.random.random(self.filter_size).astype("float32")
output = conv2dtranspose_forward_naive(
input_, filter_, conv2dtranspose_param).astype('float32')
# print 'deconv output py', output, output.shape
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
# 'dilations': self.dilations
'dilations': self.dilations
}
self.outputs = {'Output': output}
def test_check_output(self):
print 'check output here'
print 'check output here for', self.op_type
self.check_output()
def test_check_grad(self):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose"
def test_check_grad_no_input(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
['Filter'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Input']))
def test_check_grad_no_filter(self):
self.check_grad(
......@@ -70,33 +83,15 @@ class TestConv2dTransposeOp(OpTest):
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
def test_check_grad(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Input']))
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestCudnn(TestConv2dTransposeOp):
def init_op_type(self):
self.op_type = "conv2dtranspose"
self.op_type = "conv2d_transpose_cudnn"
"""
class TestCudnn(TestConv2dOp):
def init_group(self):
self.groups = 1
def init_op_type(self):
self.op_type = "conv_cudnn"
"""
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册