Unverified commit 0f83674c, authored by chengduo, committed by GitHub

Merge pull request #5603 from chengduoZH/Add_conv3d_transpose_cudnn_op

add conv3d_trans_cudnn_op
...@@ -61,6 +61,18 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
...@@ -68,9 +80,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif()
# pool_with_index_op contains several operators
...@@ -80,25 +94,18 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
# conv_transpose_op contains several operators
if ("${TARGET}" STREQUAL "conv_transpose_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
endif()
# conv_transpose_cudnn_op contains two operators
if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
endif()
# save_restore_op contains several operators
......
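As the comments above note, one USE_OP anchor per multi-operator target is enough to expose the whole target to the Python bindings. For reference, the pybind registration file that op_library assembles would therefore contain lines like the following sketch (an illustrative excerpt pieced together from the file(APPEND ...) calls above, not the actual generated file):

// Illustrative excerpt of the generated pybind file; one USE_OP anchor
// per target pulls in every operator that the target registers.
USE_OP(less_than);
USE_OP(equal);
USE_OP(conv2d);
USE_OP(pool2d);
USE_OP(pool2d_cudnn);
USE_OP(max_pool2d_with_index);
USE_OP(conv2d_transpose);
USE_OP(conv2d_transpose_cudnn);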
...@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
...@@ -241,9 +240,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
......
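A note on why the explicit zero-fill of the gradient buffers could be dropped in this hunk: cuDNN's backward routines blend their result into the destination following the usual alpha/beta convention (this is the standard cuDNN scaling semantics, not something specific to this patch):

dst = alpha * conv_backward(src) + beta * dst

With beta = 0, the prior contents of input_grad and filter_grad never contribute to the output, so resetting them before the call was redundant work.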
...@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
...@@ -17,11 +17,15 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
...@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
framework::OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};
class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<int>("workspace_size_MB", AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, " "workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be " "workspace is a section of GPU memory which will be "
...@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL( ...@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad, conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
...@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
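// Note: the strides attribute has two elements for conv2d_transpose_cudnn and
// three for conv3d_transpose_cudnn, so its size selects NCHW vs. NCDHW here.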
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
...@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter: (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
...@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
...@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
...@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
...@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
...@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
...@@ -135,8 +135,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
......
...@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} while (false)
enum class DataLayout { // Not used
kNHWC,
kNCHW,
kNCDHW,
kNCHW_VECT_C,
};
...@@ -107,12 +108,15 @@ class CudnnDataType<double> {
}
};
inline cudnnTensorFormat_t GetCudnnTensorFormat(
const DataLayout& order) { // Not used
switch (order) {
case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
...@@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end());  // copy
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
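As a quick sanity check of the stride recurrence above, the standalone sketch below (not part of the patch; it assumes the innermost stride is 1, as a contiguous tensor would have) reproduces it for a hypothetical NCDHW tensor of shape 2 x 3 x 4 x 5 x 6:

#include <cstdio>
#include <vector>

int main() {
  // Dims of a hypothetical contiguous NCDHW tensor: N=2, C=3, D=4, H=5, W=6.
  std::vector<int> dims = {2, 3, 4, 5, 6};
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;
  // Same recurrence as in ScopedTensorDescriptor::descriptor above.
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];
  }
  for (int s : strides) std::printf("%d ", s);  // prints: 360 120 30 6 1
  std::printf("\n");
  return 0;
}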
...@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
const cudnnDataType_t type,
const std::vector<int>& kernel,
const int groups = 1) {
// filter layout: MCHW(MCDHW), where M is the number of
// output image channels, C is the number of input image channels,
// D is the depth of the filter, H is the height of the filter, and W is the
// width of the filter.
std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
if (groups > 1) {
// M /= groups
......
...@@ -108,5 +108,11 @@ class TestWithStride(TestConv3dTransposeOp):
self.filter_size = [f_c, 6, 3, 3, 3]
# ------------ test_cudnn ------------
class TestCudnn(TestConv3dTransposeOp):
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
if __name__ == '__main__':
unittest.main()