未验证 提交 8c067ec1 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Code] Opt GetExpectedKernelType code of conv_transpose_op (#46666)

* opt GetExpectedKernelType code of conv_transpose_op

* fix if error
上级 7a1e1f99
......@@ -36,30 +36,29 @@ using DataLayout = framework::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
bool use_cudnn =
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
if (ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
dev_ctx.cudnn_handle() != nullptr) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册