diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index bf51ebd1d48d761a8e3249e764c32ab18ef5cf29..4ab7cf1c494791e2cd7c6bde1d0e388fb2416abb 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -133,6 +133,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode( auto* dev_ctx = pool.Get(place_); auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context); auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (op_with_kernel->CanCUDNNBeUsed(exec_ctx, + expected_kernel_key.data_type_)) { + expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN; + } +#endif VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n"; VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n"; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 33f0e58b2c4193a9e0d0a02b3ee538b78ad4cca5..7739aa0aa43f36a2af095c04d0f411cf296a4e5f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -635,6 +635,12 @@ void BuildOpFuncList(const platform::Place& place, *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (op_with_kernel->CanCUDNNBeUsed(exec_ctx, + expected_kernel_key.data_type_)) { + expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN; + } +#endif VLOG(4) << "expected_kernel_key : " << expected_kernel_key; // change device by the device_guard() ApplyDeviceGuard(op, place, &expected_kernel_key); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d8312c698b70fe28e2311a7e467feea2b181b7c3..ec931e5a6394a53e82f0ea52b63b1f7f240e990e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -58,6 +58,10 @@ class DenseTensor; #include "paddle/fluid/platform/device/mlu/mlu_info.h" #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#endif + DECLARE_bool(benchmark); DECLARE_bool(check_nan_inf); DECLARE_bool(enable_unused_var_check); @@ -1358,6 +1362,39 @@ bool OperatorWithKernel::SupportsMKLDNN( } } +bool OperatorWithKernel::SupportsCUDNN( + const proto::VarType::Type data_type) const { + auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( + phi::TransToPhiKernelName(type_)); + paddle::experimental::DataType phi_data_type = + framework::TransToPhiDataType(data_type); + auto has_phi_kernel = std::any_of( + phi_kernels.begin(), + phi_kernels.end(), + [phi_data_type](phi::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == phi::Backend::GPUDNN && + kern_pair.first.dtype() == phi_data_type; + }); + if (has_phi_kernel) { + return true; + } else { + auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_); + if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) { + return false; + } else { + auto& op_kernels = op_kernel_iter->second; + return std::any_of( + op_kernels.begin(), + op_kernels.end(), + [data_type](OpKernelMap::const_reference kern_pair) { + return platform::is_gpu_place(kern_pair.first.place_) && + kern_pair.first.library_type_ == LibraryType::kCUDNN && + kern_pair.first.data_type_ == data_type; + }); + } + } +} + bool OperatorWithKernel::SupportsKernelType( const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const { auto& all_op_kernels = AllOpKernels(); @@ -1409,17 +1446,49 @@ bool OperatorWithKernel::SupportsKernelType( } #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (this->CanCUDNNBeUsed(exe_ctx, kernel_type.data_type_)) { + auto tmp_kernel_type = kernel_type; + tmp_kernel_type.library_type_ = framework::LibraryType::kCUDNN; + return kernels.find(tmp_kernel_type) != kernels.end(); + } +#endif + return kernel_iter != kernels.end(); } bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const { - const std::string use_mkldnn_attr = "use_mkldnn"; - return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr(use_mkldnn_attr) && + return ctx.HasAttr("use_mkldnn") && ctx.Attr("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace()) && this->SupportsMKLDNN(data_type); } +bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const { + bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn") && + paddle::platform::is_gpu_place(ctx.GetPlace()); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (use_cudnn) { + auto& dev_ctx = ctx.device_context(); + use_cudnn &= (dev_ctx.cudnn_handle() != nullptr); + } +#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP + +#if defined(PADDLE_WITH_CUDA) + if (use_cudnn && data_type == framework::proto::VarType::BF16) { + PADDLE_ENFORCE_GE( + platform::DnnVersion(), + 8100, + platform::errors::InvalidArgument( + "bfloat16 can only be used when CUDNN_VERSION >= 8100")); + } +#endif // PADDLE_WITH_CUDA + + return use_cudnn && this->SupportsCUDNN(data_type); +} + void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { PADDLE_THROW(platform::errors::PermissionDenied( "The default InferShape function of OperatorWithKernel is not allowed to " @@ -1583,6 +1652,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (this->CanCUDNNBeUsed(exe_ctx, kernel_type_->data_type_)) { + kernel_type_->library_type_ = framework::LibraryType::kCUDNN; + } +#endif + // NOTE(Liu-xiandong):In my ctest, this branch do not be executed, // I can't understand it, it's really confusing. // But we still need to keep this to avoid errors. @@ -1826,6 +1901,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( } #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (this->CanCUDNNBeUsed(ctx, expected_kernel_key.data_type_)) { + expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN; + } +#endif + if (HasAttr("op_device")) { if (Attr("op_device") == "cpu") { expected_kernel_key.place_ = platform::CPUPlace(); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 31c7797a6340ce33ebf1323bae33de61bb783e72..c602a97ab6f7777981bf8a25da968e516981bcbe 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -632,12 +632,17 @@ class OperatorWithKernel : public OperatorBase { bool SupportsMKLDNN(proto::VarType::Type data_type) const; + bool SupportsCUDNN(proto::VarType::Type data_type) const; + bool SupportsKernelType(const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const; bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const; + bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const; + virtual void InferShape(InferShapeContext* ctx) const; void RuntimeInferShape(const Scope& scope, diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index d76e06bd4143e29635610826be570cf83fbe02b5..5d9eff29e7180f7d4338dac1f2cffd8c8f282045 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -246,6 +246,12 @@ PreparedOp PrepareImpl( } #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) { + expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN; + } +#endif + #if defined(PADDLE_WITH_XPU) bool is_xpu_unsupport = paddle::platform::is_xpu_place(expected_kernel_key.place_) && diff --git a/paddle/fluid/operators/activation_cudnn.cu.cc b/paddle/fluid/operators/activation_cudnn.cu.cc deleted file mode 100644 index 3afe6b4608fc41bfce94fec166f05dedd063ea65..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/activation_cudnn.cu.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" - -namespace paddle { -namespace operators { - -using platform::ActivationDescriptor; -using platform::TensorDescriptor; - -template -class CudnnActivationKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - phi::DenseTensor *X, *Out; - ExtractActivationTensor(context, X, Out); - ActivationDescriptor act_desc; - TensorDescriptor x_desc, out_desc; - x_desc.set(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation")); - out_desc.set(GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation"); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc deleted file mode 100644 index c4e2685dd5958aca6165400871f462f7d668657b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/activation_cudnn_op.cu.cc +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" - -namespace paddle { -namespace operators { - -using phi::GPUContext; -using platform::ActivationDescriptor; -using platform::TensorDescriptor; - -#ifdef PADDLE_WITH_HIP -#define GPUDNN_ACTIVATION_RELU miopenActivationRELU -#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU -#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC -#define GPUDNN_ACTIVATION_TANH miopenActivationTANH -#else -#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU -#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU -#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID -#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH -#endif - -template -struct CudnnActivationFunctor { - using ELEMENT_TYPE = T; -#ifdef PADDLE_WITH_HIP - CudnnActivationFunctor(const phi::GPUContext& ctx, - const T& c, - const miopenActivationMode_t& m) - : ctx_(ctx), coef_(c), mode_(m) {} -#else - CudnnActivationFunctor(const phi::GPUContext& ctx, - const T& c, - const cudnnActivationMode_t& m) - : ctx_(ctx), coef_(c), mode_(m) {} -#endif - void operator()(const phi::DenseTensor& x, phi::DenseTensor* out) { - ActivationDescriptor act_desc; - act_desc.set(mode_, coef_); - TensorDescriptor x_desc, out_desc; - x_desc.set(x); - out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation")); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationForward( - ctx_.cudnn_handle(), - act_desc.desc(), - platform::CudnnDataType::kOne(), - x_desc.desc(), - x.data(), - platform::CudnnDataType::kZero(), - out_desc.desc(), - out->mutable_data(ctx_.GetPlace()))); -#else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationForward( - ctx_.cudnn_handle(), - act_desc.desc(), - platform::CudnnDataType::kOne(), - x_desc.desc(), - x.data(), - platform::CudnnDataType::kZero(), - out_desc.desc(), - out->mutable_data(ctx_.GetPlace()))); -#endif - } - const phi::GPUContext& ctx_; - const T coef_; -#ifdef PADDLE_WITH_HIP - const miopenActivationMode_t mode_; -#else - const cudnnActivationMode_t mode_; -#endif -}; - -template -struct CudnnActivationGradFunctor { - using ELEMENT_TYPE = T; -#ifdef PADDLE_WITH_HIP - CudnnActivationGradFunctor(const phi::GPUContext& ctx, - const T& c, - const miopenActivationMode_t& m) - : ctx_(ctx), coef_(c), mode_(m) {} -#else - CudnnActivationGradFunctor(const phi::GPUContext& ctx, - const T& c, - const cudnnActivationMode_t& m) - : ctx_(ctx), coef_(c), mode_(m) {} -#endif - void operator()(const phi::DenseTensor& x, - const phi::DenseTensor& out, - const phi::DenseTensor dout, - phi::DenseTensor* dx) { - ActivationDescriptor act_desc; - act_desc.set(mode_, coef_); - TensorDescriptor x_desc, out_desc, dout_desc, dx_desc; - x_desc.set(x); - out_desc.set(out); - dout_desc.set(dout); - dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad")); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationBackward( - ctx_.cudnn_handle(), - act_desc.desc(), - platform::CudnnDataType::kOne(), - out_desc.desc(), - out.data(), - dout_desc.desc(), - dout.data(), - x_desc.desc(), - x.data(), - platform::CudnnDataType::kZero(), - dx_desc.desc(), - dx->mutable_data(ctx_.GetPlace()))); -#else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationBackward( - ctx_.cudnn_handle(), - act_desc.desc(), - platform::CudnnDataType::kOne(), - out_desc.desc(), - out.data(), - dout_desc.desc(), - dout.data(), - x_desc.desc(), - x.data(), - platform::CudnnDataType::kZero(), - dx_desc.desc(), - dx->mutable_data(ctx_.GetPlace()))); -#endif - } - const phi::GPUContext& ctx_; - const T coef_; -#ifdef PADDLE_WITH_HIP - const miopenActivationMode_t mode_; -#else - const cudnnActivationMode_t mode_; -#endif -}; - -template -struct CudnnReluFunctor : public CudnnActivationFunctor { - explicit CudnnReluFunctor(const phi::GPUContext& ctx) - : CudnnActivationFunctor(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {} -}; -template -struct CudnnReluGradFunctor : public CudnnActivationGradFunctor { - explicit CudnnReluGradFunctor(const phi::GPUContext& ctx) - : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {} - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudnnRelu6Functor : public CudnnActivationFunctor { - explicit CudnnRelu6Functor(const phi::GPUContext& ctx) - : CudnnActivationFunctor(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {} -}; -template -struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor { - explicit CudnnRelu6GradFunctor(const phi::GPUContext& ctx) - : CudnnActivationGradFunctor( - ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {} - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudnnSigmoidFunctor : public CudnnActivationFunctor { - explicit CudnnSigmoidFunctor(const phi::GPUContext& ctx) - : CudnnActivationFunctor(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {} -}; -template -struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor { - explicit CudnnSigmoidGradFunctor(const phi::GPUContext& ctx) - : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {} - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudnnTanhFunctor : public CudnnActivationFunctor { - explicit CudnnTanhFunctor(const phi::GPUContext& ctx) - : CudnnActivationFunctor(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {} -}; -template -struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor { - explicit CudnnTanhGradFunctor(const phi::GPUContext& ctx) - : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {} - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -class CudnnActivationKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { - const phi::DenseTensor* X = nullptr; - phi::DenseTensor* Out = nullptr; - ExtractActivationTensor(context, &X, &Out); - Out->mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); - Functor functor(dev_ctx); - functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out); - } -}; - -template -class CudnnActivationGradKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { - static_assert(Functor::FwdDeps() == ActBwdOpFwdDeps::kDepOut, - "Forward deps must be Out."); - - const phi::DenseTensor *X, *Out, *dOut; - X = Out = dOut = nullptr; - phi::DenseTensor* dX = nullptr; - ExtractActivationGradTensor( - context, &X, &Out, &dOut, &dX); - dX->mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); - Functor functor(dev_ctx); - functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"), - GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"), - GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"), - dX); - } -}; - -} // namespace operators -} // namespace paddle - -namespace plat = paddle::platform; -namespace ops = paddle::operators; - -#define FOR_EACH_CUDNN_OP_FUNCTOR(__macro) \ - __macro(relu, CudnnReluFunctor, CudnnReluGradFunctor); \ - __macro(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor); \ - __macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \ - __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor) - -#ifdef PADDLE_WITH_HIP -#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \ - REGISTER_OP_KERNEL(act_type, \ - CUDNN, \ - plat::CUDAPlace, \ - ops::CudnnActivationKernel>); \ - REGISTER_OP_KERNEL( \ - act_type##_grad, \ - CUDNN, \ - plat::CUDAPlace, \ - ops::CudnnActivationGradKernel>); -#else -#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \ - REGISTER_OP_KERNEL(act_type, \ - CUDNN, \ - plat::CUDAPlace, \ - ops::CudnnActivationKernel>, \ - ops::CudnnActivationKernel>); \ - REGISTER_OP_KERNEL( \ - act_type##_grad, \ - CUDNN, \ - plat::CUDAPlace, \ - ops::CudnnActivationGradKernel>, \ - ops::CudnnActivationGradKernel>); -#endif - -FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL); diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 8d123710e750e4ad8e7edf222a76afd534b2390d..2d7eb04f1dba0d0a54253a4a416f59643213d40a 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -134,15 +134,8 @@ class AffineGridOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library{framework::LibraryType::kPlain}; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library = framework::LibraryType::kCUDNN; - } -#endif auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); - return framework::OpKernelType( - data_type, ctx.GetPlace(), phi::DataLayout::kAnyLayout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; @@ -252,17 +245,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Output")), - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - library_); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Output")); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 14ad48616c8087d1172014cac8c3959bf5fe83dd..7158685c3ec867398f159eda84ce15de79cde130 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -209,24 +209,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( paddle::framework::DataTypeToString(filter_data_type))); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { -#if PADDLE_WITH_CUDA - if (input_data_type == framework::proto::VarType::BF16) { - PADDLE_ENFORCE_GE( - platform::DnnVersion(), - 8100, - platform::errors::InvalidArgument( - "bfloat16 can only be used when CUDNN_VERSION >= 8100")); - } -#endif // PADDLE_WITH_CUDA - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP - return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -476,16 +458,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { // TODO(pzelazko-intel): enable MKLDNN layout when it's ready auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(data_type, - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif - return framework::OpKernelType(data_type, ctx.GetPlace()); } @@ -657,14 +629,6 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(data_type, - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index f5702f2179431d0c0bf15b2f302cf091cc4dc822..e9c4245bc4731680ee16586dccde9e7f5e4a53dc 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -28,9 +28,6 @@ limitations under the License. */ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#endif namespace paddle { namespace operators { @@ -40,14 +37,6 @@ using DataLayout = phi::DataLayout; framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(data_type, - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(data_type, ctx.GetPlace()); } @@ -268,14 +257,6 @@ Example: framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(data_type, - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(data_type, ctx.GetPlace()); } @@ -343,14 +324,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker { framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(data_type, - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 77865647c4c5bbfb3488f808bb69e5cd5a967c47..7f57d6e288f87a8a59f63bba70768d08882313f9 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -35,17 +35,8 @@ class GridSampleOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - library_); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; @@ -146,17 +137,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - phi::DataLayout::kAnyLayout, - library_); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 7842de9b17a3bfc4ea769cf2b09f04ff6b641719..48bfa3576ab6c880c32b100ce25a1e082a0a229f 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -44,21 +44,13 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) { framework::OpKernelType PoolOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; - phi::DataLayout layout_ = phi::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType PoolOp::GetKernelTypeForVar( @@ -86,22 +78,13 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar( framework::OpKernelType PoolOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; - phi::DataLayout layout_ = phi::DataLayout::kAnyLayout; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType PoolOpGrad::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 5b4b9aef88637326d892e222ef3ec7e2ccf5084c..80f13a51ab0b1249bbbca9e39ace4c219aee65a6 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -43,14 +43,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout_, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; @@ -135,14 +127,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout_, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 42e0e5250e0841794241785b4648466c4f32b359..bc11f53e009353092d353babcfa3725544b82276 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -48,14 +48,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout_, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; @@ -140,14 +132,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout_, - framework::LibraryType::kCUDNN); - } -#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 7cde56121b00cf78b238bf61bbfbba44c35358d9..97fe4d620cb9c3ed55a6ca280448b6729aad49f6 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -362,7 +362,6 @@ register_unity_group( lstm_op.cu.cc rnn_op.cu.cc split_op.cu.cc - activation_cudnn_op.cu.cc assign_value_op.cu.cc run_program_op.cu.cc warpctc_op.cu.cc) diff --git a/paddle/fluid/platform/device/gpu/cuda/cudnn_helper.h b/paddle/fluid/platform/device/gpu/cuda/cudnn_helper.h index 4fa25476336f67f52c0b1c4d490bbe744a54e9fe..7181fc34f88caee1d185d526d151840cd1c53328 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cudnn_helper.h +++ b/paddle/fluid/platform/device/gpu/cuda/cudnn_helper.h @@ -616,18 +616,6 @@ class ScopedActivationDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; -inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { - bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn"); - use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); -#ifdef PADDLE_WITH_CUDA - if (use_cudnn) { - auto& dev_ctx = ctx.device_context(); - use_cudnn &= dev_ctx.cudnn_handle() != nullptr; - } -#endif - return use_cudnn; -} - #if CUDNN_VERSION >= 7001 class ScopedCTCLossDescriptor { public: diff --git a/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h b/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h index 0c9d6d24cd1bfa7ae101f20c4bb19cf936dbf8be..fb2392e657523a22085c82d4d0d03164ac9e0241 100644 --- a/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h +++ b/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h @@ -553,18 +553,6 @@ class ScopedActivationDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; -inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { - bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn"); - use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); -#ifdef PADDLE_WITH_HIP - if (use_cudnn) { - auto& dev_ctx = ctx.device_context(); - use_cudnn &= dev_ctx.cudnn_handle() != nullptr; - } -#endif - return use_cudnn; -} - class ScopedCTCLossDescriptor { public: ScopedCTCLossDescriptor() {