diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc
index 8c221ec5421147dc33a68489bedd8c76a4c233e3..42e5eb2a43820391daa8cc296ec92b7ea87b14b4 100644
--- a/paddle/fluid/operators/conv_transpose_op.cc
+++ b/paddle/fluid/operators/conv_transpose_op.cc
@@ -36,30 +36,29 @@ using DataLayout = framework::DataLayout;
 
 framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_{framework::LibraryType::kPlain};
-  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-  bool use_cudnn =
-      ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::is_gpu_place(ctx.GetPlace())) {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-    if (use_cudnn) {
-      library_ = framework::LibraryType::kCUDNN;
+    if (ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
+        dev_ctx.cudnn_handle() != nullptr) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kAnyLayout,
+                                     framework::LibraryType::kCUDNN);
     }
   }
 #endif
 #ifdef PADDLE_WITH_MKLDNN
-  if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx, data_type)) {
-    library_ = framework::LibraryType::kMKLDNN;
-    layout_ = framework::DataLayout::kMKLDNN;
+  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kMKLDNN,
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
 
-  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 
 framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
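
Note (commentary, not part of the diff): the change above replaces the accumulate-then-return pattern (mutable library_/layout_ locals, a single return at the end) with early returns that construct the final OpKernelType inside each backend branch, so the plain-kernel fallback at the bottom no longer needs explicit layout/library arguments. Below is a minimal standalone C++ sketch of that control-flow pattern; every name in it (Library, Layout, KernelType, PickKernel, and the boolean parameters) is an illustrative stand-in, not a Paddle identifier.

// Sketch of the early-return kernel-dispatch pattern introduced by this diff.
// Each candidate backend is tested in priority order (cuDNN, then MKL-DNN),
// and the fully-formed result is returned as soon as a branch matches,
// instead of mutating shared locals and returning once at the end.
#include <iostream>

enum class Library { kPlain, kCUDNN, kMKLDNN };
enum class Layout { kAnyLayout, kMKLDNN };

struct KernelType {
  Library library;
  Layout layout;
};

KernelType PickKernel(bool on_gpu, bool use_cudnn_attr, bool cudnn_available,
                      bool mkldnn_usable) {
  // Mirrors the "+" side of the diff: the cuDNN branch checks the attribute
  // and handle availability together, then returns immediately.
  if (on_gpu && use_cudnn_attr && cudnn_available) {
    return {Library::kCUDNN, Layout::kAnyLayout};
  }
  if (mkldnn_usable) {
    return {Library::kMKLDNN, Layout::kMKLDNN};
  }
  // Fall through to the plain kernel with the default layout, matching the
  // simplified two-argument return at the end of the rewritten function.
  return {Library::kPlain, Layout::kAnyLayout};
}

int main() {
  KernelType k = PickKernel(/*on_gpu=*/true, /*use_cudnn_attr=*/true,
                            /*cudnn_available=*/true, /*mkldnn_usable=*/false);
  std::cout << (k.library == Library::kCUDNN ? "cuDNN" : "other") << "\n";
}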