diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5b272e30ab3420b31b1dac15af630f153ed37be5..14773c04337c0c08045e5582d17f947ff6aeffbd 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1388,14 +1388,12 @@ bool OperatorWithKernel::SupportsKernelType( #endif // NOTE(jiahongyu): If MKLDNN can be used, the function SupportsKernelType needs -// to check whether current op supports MKLDNN kernel. There are three -// statements in if condition: The first statement checks whether library_type_ -// are changed by other high priority backends; the second checks whether this -// op has specific implementation; the third checks whether mkldnn kernel can be -// used. +// to check whether current op supports MKLDNN kernel. There are two statements +// in if condition: +// 1. Whether this op has specific implementation; +// 2. Whether mkldnn kernel can be used. #ifdef PADDLE_WITH_MKLDNN - if (kernel_type.library_type_ == framework::LibraryType::kPlain && - !paddle::platform::in_mkldnn_white_list(type_) && + if (!paddle::platform::in_mkldnn_white_list(type_) && this->CanMKLDNNBeUsed(exe_ctx, kernel_type.data_type_)) { auto tmp_kernel_type = kernel_type; tmp_kernel_type.library_type_ = framework::LibraryType::kMKLDNN; @@ -1571,13 +1569,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // NOTE(jiahongyu): The registered MKLDNN kernel have library_type = // LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default // values are kPlain, so we need to modify the library_type and data_layout_ -// here. There are three statements in if condition: The first statement checks -// whether library_type_ are changed by other high priority backends; the second -// checks whether this op has specific implementation; the third checks whether -// mkldnn kernel can be used. +// here. There are two statements in if condition: +// 1. Whether this op has specific implementation; +// 2. Whether mkldnn kernel can be used. #ifdef PADDLE_WITH_MKLDNN - if (kernel_type_->library_type_ == framework::LibraryType::kPlain && - !paddle::platform::in_mkldnn_white_list(type_) && + if (!paddle::platform::in_mkldnn_white_list(type_) && this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) { kernel_type_->library_type_ = framework::LibraryType::kMKLDNN; kernel_type_->data_layout_ = framework::DataLayout::kMKLDNN; @@ -1814,14 +1810,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( // NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function // GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and -// data_layout_ of expected_kernel_key need to be adjusted. There are three -// statements in if condition: The first statement checks whether library_type_ -// are changed by other high priority backends; the second checks whether this -// op has specific implementation; the third checks whether mkldnn kernel can be -// used. +// data_layout_ of expected_kernel_key need to be adjusted. There are two +// statements in if condition: +// 1. Whether this op has specific implementation; +// 2. Whether mkldnn kernel can be used. #ifdef PADDLE_WITH_MKLDNN - if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain && - !paddle::platform::in_mkldnn_white_list(type_) && + if (!paddle::platform::in_mkldnn_white_list(type_) && this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) { expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN; expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 1f70bcf4f428a5afe3b49d78e64ad3eb6454096a..28276ddbf8a83467664c93e22f417beb5b72822d 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -192,13 +192,11 @@ PreparedOp PrepareImpl( // NOTE(jiahongyu): The registered MKLDNN kernel have library_type = // LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default // values are kPlain, so we need to modify the library_type and data_layout_ -// here. There are three statements in if condition: The first statement checks -// whether library_type_ are changed by other high priority backends; the second -// checks whether this op has specific implementation; the third checks whether -// mkldnn kernel can be used. +// here. There are two statements in if condition: +// 1. Whether this op has specific implementation; +// 2. Whether mkldnn kernel can be used. #ifdef PADDLE_WITH_MKLDNN - if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain && - !paddle::platform::in_mkldnn_white_list(op.Type()) && + if (!paddle::platform::in_mkldnn_white_list(op.Type()) && op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) { expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN; expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;