diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index e3dcff949f43c3438efdd7a2349168a6867339ad..599be6912b760ec97586bf821725781f68fa8385 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -78,4 +78,7 @@ if (WITH_GPU OR WITH_ROCM)
     nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
     nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
   endif()
+  if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
+    cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
+  endif()
 endif()
diff --git a/paddle/fluid/operators/fused/cudnn_fusion_helper.h b/paddle/fluid/operators/fused/cudnn_fusion_helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..4434681e60b3b144fee3c7d8d477c49aac7e252e
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_fusion_helper.h
@@ -0,0 +1,162 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/platform/cudnn_desc.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/dynload/cudnn.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+
+namespace dynload = platform::dynload;
+
+#if CUDNN_VERSION >= 8000
+
+// A wrapper for cuDNN fused_op API.
+class CudnnFusionOp {
+ public:
+  explicit CudnnFusionOp(cudnnFusedOps_t op_id) : plan_created_(false) {
+    // New 'fused op' descriptor creation
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsPlan(&op_, op_id));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnCreateFusedOpsConstParamPack(&op_const_params_, op_id));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsVariantParamPack(
+        &op_variant_params_, op_id));
+  }
+
+  ~CudnnFusionOp() {
+    // New 'fused op' descriptor destruction
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
+  }
+
+  // Execute fused op
+  void Execute(cudnnHandle_t cudnn_handle) {
+    PADDLE_ENFORCE_EQ(
+        plan_created_, true,
+        platform::errors::Fatal(
+            "CudnnFusionOp exec requested without a valid 'plan', need: "
+            "<set const params>, GetWorkspaceSizeBytes(), Execute()."));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnFusedOpsExecute(cudnn_handle, op_, op_variant_params_));
+  }
+
+  // Set const param pack attribute given a descriptor.
+  template <typename T>
+  void SetOpConstParamDesc(cudnnFusedOpsConstParamLabel_t param_label,
+                           T *param_ptr) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnSetFusedOpsConstParamPackAttribute(
+            op_const_params_, param_label, param_ptr));
+    plan_created_ = false;
+  }
+
+  // Set multiple const param pack attributes given a descriptor.
+  template <typename T>
+  void SetOpConstParamDesc(
+      const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
+      T *param_ptr) {
+    for (auto param_label : param_labels) {
+      SetOpConstParamDesc(param_label, param_ptr);
+    }
+  }
+
+  // Set const param pack attribute given a value of param.
+  template <typename T>
+  void SetOpConstParamAttr(cudnnFusedOpsConstParamLabel_t param_label,
+                           T param) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnSetFusedOpsConstParamPackAttribute(op_const_params_,
+                                                         param_label, &param));
+    plan_created_ = false;
+  }
+
+  // Set multiple const param pack attributes given a value of param.
+  template <typename T>
+  void SetOpConstParamAttr(
+      const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
+      T param) {
+    for (auto param_label : param_labels) {
+      SetOpConstParamAttr(param_label, param);
+    }
+  }
+
+  // Set a variant param pack attribute given a reference to a param.
+  template <typename T>
+  void SetOpVariantParamAttrPtr(cudnnFusedOpsVariantParamLabel_t param_label,
+                                T *param_ptr) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnSetFusedOpsVariantParamPackAttribute(
+            op_variant_params_, param_label, param_ptr));
+  }
+
+  // Set multiple variant param pack attributes given a reference to a param.
+  template <typename T>
+  void SetOpVariantParamAttrPtr(
+      const std::vector<cudnnFusedOpsVariantParamLabel_t> &param_labels,
+      const T *param_ptr) {
+    for (auto param_label : param_labels) {
+      SetOpVariantParamAttrPtr(param_label, param_ptr);
+    }
+  }
+
+  // Get the workspace, which is required before Execute().
+  size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) {
+    size_t workspace_bytes = 0U;
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
+        cudnn_handle, op_, op_const_params_, &workspace_bytes));
+    plan_created_ = true;
+    return workspace_bytes;
+  }
+
+ private:
+  bool plan_created_;
+
+  cudnnFusedOpsPlan_t op_;
+  cudnnFusedOpsConstParamPack_t op_const_params_;
+  cudnnFusedOpsVariantParamPack_t op_variant_params_;
+};
+
+static inline std::vector<int> GetStrides(const std::vector<int> &shape) {
+  if (shape.size() < 1) {
+    return {};
+  }
+  int dim = static_cast<int>(shape.size());
+  std::vector<int> pro_shape(shape);
+  std::vector<int> strides(dim);
+  int temp = pro_shape[1];
+  pro_shape.erase(pro_shape.begin() + 1);
+  pro_shape.push_back(temp);
+  strides.back() = 1;
+  for (int i = dim - 2; i >= 0; --i) {
+    strides[i] = strides[i + 1] * pro_shape[i + 1];
+  }
+  strides.pop_back();
+  strides.insert(strides.begin() + 1, 1);
+  return strides;
+}
+
+static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }
+
+#endif  // CUDNN_VERSION >= 8000
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ead78b8b64e1836e9e1de66b756cadf2435ddc9
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h
@@ -0,0 +1,139 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+namespace dynload = platform::dynload;
+
+#if CUDNN_VERSION >= 8000
+template <typename T>
+class CudnnNormConvolutionOp {
+ public:
+  CudnnNormConvolutionOp()
+      : fwd_op_(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS) {}
+  ~CudnnNormConvolutionOp() {}
+
+  void Init(const platform::CUDADeviceContext &ctx,
+            const std::vector<int> &input_shape,
+            const std::vector<int> &filter_shape,
+            const std::vector<int> &output_shape, const int &pad,
+            const int &stride, const int &dilate, const int &group) {
+    cudnn_fwd_compute_type_ = platform::CudnnDataType<float>::type;
+    dtype_ = platform::CudnnDataType<T>::type;
+    format_ = CUDNN_TENSOR_NHWC;
+
+    InitDescriptors(ctx, input_shape, filter_shape, output_shape, pad, stride,
+                    dilate, group);
+    GetWorkspaceSize(ctx);
+  }
+
+  void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
+               T *filter_ptr, T *output_ptr, float *sum_ptr,
+               float *sum_of_squares_ptr) {
+    auto handle = ctx.cudnn_handle();
+    auto workspace_handle = ctx.cudnn_workspace_handle();
+    // Set variant_param
+    // input ptr
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(
+        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
+    // output ptr
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
+    workspace_handle.RunFunc(
+        [&](void *workspace_ptr) {
+          // workspace ptr
+          fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
+          // fused op execute
+          fwd_op_.Execute(handle);
+        },
+        fwd_workspace_byte_);
+  }
+
+  // TBD
+  void Backward(const platform::CUDADeviceContext &ctx) {}
+
+ private:
+  void InitDescriptors(const platform::CUDADeviceContext &ctx,
+                       const std::vector<int> &input_shape,
+                       const std::vector<int> &filter_shape,
+                       const std::vector<int> &output_shape, const int &pad,
+                       const int &stride, const int &dilate, const int &group) {
+    // Set constant_param
+    fwd_op_.SetOpConstParamAttr(
+        {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
+         CUDNN_PARAM_YDATA_PLACEHOLDER},
+        CUDNN_PTR_16B_ALIGNED);
+    fwd_op_.SetOpConstParamAttr(
+        {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
+        CUDNN_PTR_16B_ALIGNED);
+
+    std::vector<int> pad_vec = {pad, pad};
+    std::vector<int> stride_vec = {stride, stride};
+    std::vector<int> dilate_vec = {dilate, dilate};
+    int output_channel = filter_shape[0];
+    std::vector<int> stats_shape = {1, 1, 1, output_channel};
+
+    // set conv desc
+    conv_desc_.set(dtype_, pad_vec, stride_vec, dilate_vec, false, group);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC, conv_desc_.desc());
+
+    // set input desc
+    in_desc_.set(input_shape, format_, dtype_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, in_desc_.desc());
+
+    // set filter desc
+    filter_desc_.set(filter_shape, format_, dtype_, group);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_WDESC, filter_desc_.desc());
+
+    // set output desc
+    out_desc_.set(output_shape, format_, dtype_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, out_desc_.desc());
+
+    // set output_stats desc
+    out_stats_desc_.set(stats_shape, format_, cudnn_fwd_compute_type_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
+                                out_stats_desc_.desc());
+
+    fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, CUDNN_BATCHNORM_SPATIAL);
+  }
+
+  void GetWorkspaceSize(const platform::CUDADeviceContext &ctx) {
+    auto handle = ctx.cudnn_handle();
+    fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
+  }
+
+  size_t fwd_workspace_byte_ = 0;
+
+  cudnnDataType_t dtype_;
+  cudnnDataType_t cudnn_fwd_compute_type_;
+  platform::TensorDescriptor in_desc_;
+  platform::FilterDescriptor filter_desc_;
+  platform::TensorDescriptor out_desc_;
+  platform::TensorDescriptor out_stats_desc_;
+  platform::ConvolutionDescriptor conv_desc_;
+  cudnnTensorFormat_t format_;
+
+  CudnnFusionOp fwd_op_;
+};
+#endif
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..125ed85642292050ee85d7d3b9745dedcd75368a
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
@@ -0,0 +1,262 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+namespace op = paddle::operators;
+using Tensor = paddle::framework::Tensor;
+
+USE_OP(conv2d);
+USE_OP_DEVICE_KERNEL(conv2d, CUDNN);
+
+// get paddle conv2d op results as baseline
+template <typename T>
+void Conv2DForwardCompute(const Tensor &x, const Tensor &w, Tensor *y,
+                          const platform::CUDADeviceContext &ctx) {
+  framework::Scope scope;
+  auto var_x = scope.Var("Input");
+  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
+  auto var_w = scope.Var("Filter");
+  auto tensor_w = var_w->GetMutable<framework::LoDTensor>();
+  auto var_y = scope.Var("Output");
+  auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
+
+  auto place = ctx.GetPlace();
+  TensorCopySync(x, place, tensor_x);
+  TensorCopySync(w, place, tensor_w);
+
+  framework::AttributeMap attrs;
+  bool use_cudnn = true;
+  std::string data_format = "NHWC";
+  std::string padding_algorithm = "SAME";
+  attrs.insert({"use_cudnn", use_cudnn});
+  attrs.insert({"data_format", data_format});
+  attrs.insert({"padding_algorithm", padding_algorithm});
+
+  auto op = framework::OpRegistry::CreateOp(
+      "conv2d", {{"Input", {"Input"}}, {"Filter", {"Filter"}}},
+      {{"Output", {"Output"}}}, attrs);
+  op->Run(scope, ctx.GetPlace());
+
+  TensorCopySync(*tensor_y, place, y);
+  ctx.Wait();
+}
+
+template <typename T>
+class TestCudnnNormConvOpForward {
+ public:
+  TestCudnnNormConvOpForward() {
+    batch_size_ = 2;
+    height_ = 8;
+    width_ = 8;
+    input_channels_ = 8;
+    output_channels_ = 32;
+    kernel_size_ = 1;
+    stride_ = 1;
+    pad_ = 0;
+  }
+
+  TestCudnnNormConvOpForward(int batch_size, int height, int width,
+                             int input_channels, int output_channels,
+                             int kernel_size, int stride) {
+    batch_size_ = batch_size;
+    height_ = height;
+    width_ = width;
+    input_channels_ = input_channels;
+    output_channels_ = output_channels;
+    kernel_size_ = kernel_size;
+    stride_ = stride;
+    pad_ = (kernel_size_ - 1) / 2;
+  }
+
+  ~TestCudnnNormConvOpForward() {}
+
+  void SetUp() {
+    input_size_ = batch_size_ * height_ * width_ * input_channels_;
+    filter_size_ =
+        output_channels_ * input_channels_ * kernel_size_ * kernel_size_;
+    output_size_ = batch_size_ * height_ * width_ * output_channels_;
+    param_size_ = output_channels_;
+
+    input_vec_.resize(input_size_);
+    filter_raw_vec_.resize(filter_size_);
+    filter_pro_vec_.resize(filter_size_);
+
+    std::default_random_engine random(0);
+    std::uniform_real_distribution<float> dis(0.0, 1.0);
+    for (int i = 0; i < input_size_; ++i) {
+      input_vec_[i] = static_cast<T>(dis(random));
+    }
+    for (int i = 0; i < filter_size_; ++i) {
+      filter_raw_vec_[i] = static_cast<T>(dis(random));
+    }
+    // transpose for filter
+    // NCHW->NHWC
+    for (int oc = 0; oc < output_channels_; ++oc) {
+      for (int kh = 0; kh < kernel_size_; ++kh) {
+        for (int kw = 0; kw < kernel_size_; ++kw) {
+          for (int ic = 0; ic < input_channels_; ++ic) {
+            int dst_idx = oc * kernel_size_ * kernel_size_ * input_channels_ +
+                          kh * kernel_size_ * input_channels_ +
+                          kw * input_channels_ + ic;
+            int src_idx = oc * kernel_size_ * kernel_size_ * input_channels_ +
+                          ic * kernel_size_ * kernel_size_ + kh * kernel_size_ +
+                          kw;
+            filter_pro_vec_[dst_idx] = filter_raw_vec_[src_idx];
+          }
+        }
+      }
+    }
+
+    framework::TensorFromVector<T>(input_vec_, *ctx_, &input_);
+    input_.Resize({batch_size_, height_, width_, input_channels_});
+    framework::TensorFromVector<T>(filter_raw_vec_, *ctx_, &filter_raw_);
+    filter_raw_.Resize(
+        {output_channels_, input_channels_, kernel_size_, kernel_size_});
+    framework::TensorFromVector<T>(filter_pro_vec_, *ctx_, &filter_pro_);
+    filter_pro_.Resize(
+        {output_channels_, kernel_size_, kernel_size_, input_channels_});
+    output_.Resize({batch_size_, height_, width_, output_channels_});
+    base_output_.Resize({batch_size_, height_, width_, output_channels_});
+    sum_.Resize({1, 1, 1, output_channels_});
+    sum_of_squares_.Resize({1, 1, 1, output_channels_});
+    ctx_->Wait();
+  }
+
+  void BaselineForward() {
+    Conv2DForwardCompute<T>(input_, filter_raw_, &base_output_, *ctx_);
+    ctx_->Wait();
+  }
+
+  // get forward results of cudnn_norm_conv
+  void FusedForward() {
+    auto input_shape = framework::vectorize<int>(input_.dims());
+    auto filter_shape = framework::vectorize<int>(filter_pro_.dims());
+    auto output_shape = framework::vectorize<int>(output_.dims());
+    T *input_ptr = input_.data<T>();
+    T *filter_ptr = filter_pro_.data<T>();
+    T *output_ptr = output_.mutable_data<T>(place_);
+    float *sum_ptr = sum_.mutable_data<float>(place_);
+    float *sum_of_squares_ptr = sum_of_squares_.mutable_data<float>(place_);
+
+    std::shared_ptr<op::CudnnNormConvolutionOp<T>> conv_op(
+        new op::CudnnNormConvolutionOp<T>());
+    conv_op->Init(*ctx_, input_shape, filter_shape, output_shape, pad_, stride_,
+                  dilate_, group_);
+    conv_op->Forward(*ctx_, input_ptr, filter_ptr, output_ptr, sum_ptr,
+                     sum_of_squares_ptr);
+    ctx_->Wait();
+  }
+
+  void Run() {
+    SetUp();
+    BaselineForward();
+    FusedForward();
+  }
+
+  // check forward correctness between baseline and results of normconv.
+  void CheckOut(const T diff, bool is_relative_atol = false) {
+    std::vector<T> base_output_vec, output_vec;
+    output_vec.resize(output_size_);
+    base_output_vec.resize(output_size_);
+    TensorToVector(base_output_, *ctx_, &base_output_vec);
+    TensorToVector(output_, *ctx_, &output_vec);
+    ctx_->Wait();
+
+    for (int i = 0; i < output_size_; ++i) {
+      if (is_relative_atol) {
+        EXPECT_LT(
+            std::abs((output_vec[i] - base_output_vec[i]) / base_output_vec[i]),
+            diff);
+      } else {
+        EXPECT_LT(std::abs(output_vec[i] - base_output_vec[i]), diff);
+      }
+    }
+  }
+
+ private:
+  int batch_size_, height_, width_, input_channels_, output_channels_;
+  int kernel_size_, stride_, pad_;
+  const int dilate_ = 1;
+  const int group_ = 1;
+  int input_size_, filter_size_, output_size_, param_size_;
+
+  framework::Tensor input_, filter_raw_, filter_pro_, output_, base_output_;
+  framework::Tensor sum_, sum_of_squares_;
+  std::vector<T> input_vec_, filter_raw_vec_, filter_pro_vec_;
+
+  platform::CUDAPlace place_ = platform::CUDAPlace(0);
+  platform::CUDADeviceContext *ctx_ =
+      static_cast<platform::CUDADeviceContext *>(
+          platform::DeviceContextPool::Instance().Get(place_));
+};
+
+// test for fp16, kernel = 1, output_channels = input_channels
+TEST(CudnnNormConvForward, GPUCudnnNormConvForward1Fp16) {
+  int batch_size = 4;
+  int height = 56;
+  int width = 56;
+  int input_channels = 32;
+  int output_channels = 32;
+  int kernel_size = 1;
+  int stride = 1;
+  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+      batch_size, height, width, input_channels, output_channels, kernel_size,
+      stride);
+  test.Run();
+  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+}
+
+// test for fp16, kernel = 3, output_channels = input_channels
+TEST(CudnnNormConvForward, GPUCudnnNormConvForward2Fp16) {
+  int batch_size = 4;
+  int height = 56;
+  int width = 56;
+  int input_channels = 32;
+  int output_channels = 32;
+  int kernel_size = 3;
+  int stride = 1;
+  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+      batch_size, height, width, input_channels, output_channels, kernel_size,
+      stride);
+  test.Run();
+  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+}
+
+// test for fp16, kernel = 1, output_channels = input_channels * 4
+TEST(CudnnNormConvForward, GPUCudnnNormConvForward3Fp16) {
+  int batch_size = 4;
+  int height = 56;
+  int width = 56;
+  int input_channels = 32;
+  int output_channels = 128;
+  int kernel_size = 1;
+  int stride = 1;
+  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+      batch_size, height, width, input_channels, output_channels, kernel_size,
+      stride);
+  test.Run();
+  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+}
diff --git a/paddle/fluid/platform/cudnn_desc.h b/paddle/fluid/platform/cudnn_desc.h
index 486b3346c37607a86504e2002c19f438804c351f..318c85ee484bef17833e48b068bf01f014eccd79 100644
--- a/paddle/fluid/platform/cudnn_desc.h
+++ b/paddle/fluid/platform/cudnn_desc.h
@@ -44,6 +44,9 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {
 
 inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
   std::vector<int> transformed_dims(dims.begin(), dims.end());
+  if (dims.size() < 4) {
+    return transformed_dims;
+  }
   int H, W, D, C;
   if (dims.size() == 4) {
     H = dims[1];
@@ -155,8 +158,8 @@ class TensorDescriptor {
         dims_with_group.data(), strides.data()));
   }
 
-  void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
-    auto dims = framework::vectorize<int>(tensor.dims());
+  void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
+           const cudnnDataType_t dtype) {
     std::vector<int> transformed_dims;
     if (format == CUDNN_TENSOR_NHWC) {
       transformed_dims = TransformDimOrder(dims);
@@ -164,8 +167,14 @@ class TensorDescriptor {
       transformed_dims = dims;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx(
-        desc_.get(), format, ToCudnnDataType(tensor.type()),
-        transformed_dims.size(), transformed_dims.data()));
+        desc_.get(), format, dtype, transformed_dims.size(),
+        transformed_dims.data()));
+  }
+
+  void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
+    auto dims = framework::vectorize<int>(tensor.dims());
+    auto dtype = ToCudnnDataType(tensor.type());
+    set(dims, format, dtype);
   }
 
  private:
@@ -191,9 +200,8 @@ class FilterDescriptor {
   T* desc() { return desc_.get(); }
   T* desc() const { return desc_.get(); }
 
-  void set(const Tensor& tensor, const cudnnTensorFormat_t format,
-           const int groups = 1) {
-    auto dims = framework::vectorize<int>(tensor.dims());
+  void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
+           const cudnnDataType_t dtype, const int groups = 1) {
     std::vector<int> transformed_dims;
     if (format == CUDNN_TENSOR_NHWC) {
       transformed_dims = TransformDimOrder(dims);
@@ -204,8 +212,15 @@ class FilterDescriptor {
       transformed_dims[1] = transformed_dims[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetFilterNdDescriptor(
-        desc_.get(), ToCudnnDataType(tensor.type()), format,
-        transformed_dims.size(), transformed_dims.data()));
+        desc_.get(), dtype, format, transformed_dims.size(),
+        transformed_dims.data()));
+  }
+
+  void set(const Tensor& tensor, const cudnnTensorFormat_t format,
+           const int groups = 1) {
+    auto dims = framework::vectorize<int>(tensor.dims());
+    auto dtype = ToCudnnDataType(tensor.type());
+    set(dims, format, dtype, groups);
   }
 
  private:
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h
index 4828a97e4df4d54000739adff28bc861d2da2213..3420c38fe963956813ce2cd18ba5c63d366d217c 100644
--- a/paddle/fluid/platform/dynload/cudnn.h
+++ b/paddle/fluid/platform/dynload/cudnn.h
@@ -180,7 +180,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif
 
 #if CUDNN_VERSION >= 8000
-#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) __macro(cudnnSetRNNDescriptor_v8);
+#define CUDNN_DNN_ROUTINE_EACH_R8(__macro)              \
+  __macro(cudnnSetRNNDescriptor_v8);                    \
+  __macro(cudnnCreateFusedOpsPlan);                     \
+  __macro(cudnnCreateFusedOpsConstParamPack);           \
+  __macro(cudnnCreateFusedOpsVariantParamPack);         \
+  __macro(cudnnDestroyFusedOpsPlan);                    \
+  __macro(cudnnDestroyFusedOpsConstParamPack);          \
+  __macro(cudnnDestroyFusedOpsVariantParamPack);        \
+  __macro(cudnnFusedOpsExecute);                        \
+  __macro(cudnnSetFusedOpsConstParamPackAttribute);     \
+  __macro(cudnnSetFusedOpsVariantParamPackAttribute);   \
+  __macro(cudnnMakeFusedOpsPlan);
 CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif