Unverified commit a9c20660, authored by HongyuJia, committed by GitHub

delete GetExpectedKernelType mkldnn of conv_op (#47044)

Parent 7c92177c
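
Summary of the change: the hand-written MKLDNN branch is removed from the GetExpectedKernelType overrides of ConvOp, ConvOpGrad and ConvOpDoubleGrad, and the conv kernels are dropped from the framework-side mkldnn_white_list, so oneDNN kernel selection for them goes through the common dispatch path instead of per-op code. As an "after" view, the ConvOpDoubleGrad override reduces to the shape below. This is a sketch assembled from the added lines of the third hunk, not the verbatim file; includes and the other member functions are omitted.

framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Data type is inferred from the "Input" variable, as in the diff below.
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::CanCUDNNBeUsed(ctx)) {
    // cuDNN keeps an explicit fast path.
    return framework::OpKernelType(data_type,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kAnyLayout,
                                   framework::LibraryType::kCUDNN);
  }
#endif
  // No per-op MKLDNN branch any more; the plain kernel type is returned and
  // oneDNN selection, where applicable, is left to the framework dispatch.
  return framework::OpKernelType(data_type, ctx.GetPlace());
}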
@@ -227,15 +227,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
-#ifdef PADDLE_WITH_MKLDNN
-  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    return framework::OpKernelType(input_data_type,
-                                   ctx.GetPlace(),
-                                   phi::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN);
-  }
-#endif
   return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }
@@ -494,14 +485,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
                                    framework::LibraryType::kCUDNN);
   }
 #endif
-#ifdef PADDLE_WITH_MKLDNN
-  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
-    return framework::OpKernelType(data_type,
-                                   ctx.GetPlace(),
-                                   phi::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN);
-  }
-#endif
   return framework::OpKernelType(data_type, ctx.GetPlace());
 }
@@ -673,24 +656,16 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  int customized_type_value =
-      framework::OpKernelType::kDefaultCustomizedTypeValue;
-  framework::LibraryType library_{framework::LibraryType::kPlain};
-  std::string data_format = "AnyLayout";
-  phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::CanCUDNNBeUsed(ctx)) {
-    library_ = framework::LibraryType::kCUDNN;
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 #endif
-  auto type = framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
-      layout_,
-      library_,
-      customized_type_value);
-  return type;
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 }  // namespace operators
...
@@ -69,12 +69,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
     "reduce_sum_grad",
     // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
     // need to analysis and solve them one-by-one.
-    "conv2d",
-    "conv2d_grad",
-    "depthwise_conv2d",
-    "depthwise_conv2d_grad",
-    "conv3d",
-    "conv3d_grad",
     "prior_box",
     "fc",
     "mul",
...
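
The last hunk removes the conv kernels from the framework-side white list of ops that still carry customized_type_value / per-op MKLDNN handling. A hypothetical sketch of how membership in such a list is typically checked follows; the helper name is illustrative and not taken from this diff, and the list is truncated to the entries visible above.

#include <string>
#include <unordered_set>

// Entries copied from the visible context above; conv2d/conv3d and their
// grad variants no longer appear, so they fall through to the generic path.
static const std::unordered_set<std::string> mkldnn_white_list = {
    "reduce_sum_grad",
    "prior_box",
    "fc",
    "mul",
};

// Illustrative helper (hypothetical name): true when an op is still excluded
// from the unified MKLDNN kernel-type selection.
bool StillNeedsPerOpMkldnnHandling(const std::string& op_type) {
  return mkldnn_white_list.count(op_type) > 0;
}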