未验证 提交 a9c20660 编写于 作者: H HongyuJia 提交者: GitHub

delete GetExpectedKernelType mkldnn of conv_op (#47044)

上级 7c92177c
......@@ -227,15 +227,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -494,14 +485,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
framework::LibraryType::kCUDNN);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......@@ -673,24 +656,16 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = "AnyLayout";
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_,
customized_type_value);
return type;
return framework::OpKernelType(data_type, ctx.GetPlace());
}
} // namespace operators
......
......@@ -69,12 +69,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
"reduce_sum_grad",
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analysis and solve them one-by-one.
"conv2d",
"conv2d_grad",
"depthwise_conv2d",
"depthwise_conv2d_grad",
"conv3d",
"conv3d_grad",
"prior_box",
"fc",
"mul",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册