未验证 提交 8c067ec1 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Code] Opt GetExpectedKernelType code of conv_transpose_op (#46666)

* opt GetExpectedKernelType code of conv_transpose_op

* fix if error
上级 7a1e1f99
...@@ -36,30 +36,29 @@ using DataLayout = framework::DataLayout; ...@@ -36,30 +36,29 @@ using DataLayout = framework::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
bool use_cudnn =
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>(); auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr; if (ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
if (use_cudnn) { dev_ctx.cudnn_handle() != nullptr) {
library_ = framework::LibraryType::kCUDNN; return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
} }
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library_ = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout_ = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册