From ee1aec62174b2be66ec6925fa285ff1914571752 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 11 Oct 2022 14:33:04 +0800 Subject: [PATCH] refine mkldnn code (#46677) --- paddle/fluid/operators/softmax_op.cc | 67 +++++++++++++++------------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 3966b850c7..6c63b2719f 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -39,24 +39,9 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - } -#endif - if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ( platform::is_gpu_place(ctx.GetPlace()) || @@ -69,8 +54,24 @@ class SoftmaxOp : public framework::OperatorWithKernel { "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout_, library_); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::CanCUDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + layout_, + framework::LibraryType::kCUDNN); + } +#endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; @@ -136,23 +137,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - } -#endif if (input_data_type == framework::proto::VarType::FP16) { if (!(platform::is_gpu_place(ctx.GetPlace()) || platform::is_npu_place(ctx.GetPlace()) || @@ -162,9 +150,24 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::CanCUDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + layout_, + framework::LibraryType::kCUDNN); + } +#endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); } }; -- GitLab