未验证 提交 ebdf3ef9 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Adjust mkldnn_placement_pass to check library type and data type (#49899)

* Adjust mkldnn_placement_pass to check library type and data type

* Check if var has inputs

* Remove unrelated test

* Refactor
上级 9dd1f4bf
...@@ -671,9 +671,8 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { ...@@ -671,9 +671,8 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
for (size_t bid = 0; bid < program.Size(); ++bid) { for (size_t bid = 0; bid < program.Size(); ++bid) {
auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid); auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
for (auto* op : block->AllOps()) { for (auto* op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) { if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op))
op->SetAttr("use_mkldnn", true); op->SetAttr("use_mkldnn", true);
}
} }
} }
#else #else
......
...@@ -13,6 +13,36 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,36 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h" #include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
bool CUDNNPlacementPass::IsSupport(const Node* op) const {
std::string attr_name = GetAttrName();
if (!(op->Op()->HasAttr(attr_name) || op->Op()->HasProtoAttr(attr_name)))
return false;
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op->Op()->Type());
if (it == all_kernels.end()) {
// All control operators don't have kernel.
return false;
}
for (auto& kernel_pair : it->second) {
if (platform::is_gpu_place(kernel_pair.first.place_) &&
(kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
return true;
}
}
return false;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass) REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass)
.RequirePassAttr("cudnn_enabled_op_types"); .RequirePassAttr("cudnn_enabled_op_types");
...@@ -27,6 +27,9 @@ namespace ir { ...@@ -27,6 +27,9 @@ namespace ir {
* Specifies which operators should use cuDNN. * Specifies which operators should use cuDNN.
*/ */
class CUDNNPlacementPass : public PlacementPassBase { class CUDNNPlacementPass : public PlacementPassBase {
protected:
bool IsSupport(const Node* op) const override;
private: private:
const std::string GetPlacementName() const override { return "cuDNN"; } const std::string GetPlacementName() const override { return "cuDNN"; }
......
...@@ -40,6 +40,7 @@ void TestFcRNNFusePass(const std::string& pass_name, ...@@ -40,6 +40,7 @@ void TestFcRNNFusePass(const std::string& pass_name,
"__param_scope__", "__param_scope__",
(pass_name == "fc_gru_fuse_pass" ? fc_gru_test::CreateParamScope() (pass_name == "fc_gru_fuse_pass" ? fc_gru_test::CreateParamScope()
: fc_lstm_test::CreateParamScope())); : fc_lstm_test::CreateParamScope()));
RegisterOpKernel({"mul", "elementwise_add"});
graph.reset(mkldnn_placement_pass_->Apply(graph.release())); graph.reset(mkldnn_placement_pass_->Apply(graph.release()));
auto check_num_mkldnn_nodes = [&](const std::unique_ptr<ir::Graph>& graph) { auto check_num_mkldnn_nodes = [&](const std::unique_ptr<ir::Graph>& graph) {
......
...@@ -13,6 +13,69 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,69 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
// Returns true when a fluid oneDNN (MKLDNN) CPU kernel is registered for
// this op AND its data type matches the op's first input variable.
// Ops without inputs are accepted on the kernel match alone; the pseudo
// variable "feed" is never considered a valid type source.
inline bool FoundOneDNNKernelWithCorrectDataType(
    const framework::ir::Node* op) {
  const auto& registry = framework::OperatorWithKernel::AllOpKernels();
  const auto found = registry.find(op->Op()->Type());
  if (found == registry.end()) return false;

  for (const auto& type_and_fn : found->second) {
    const auto& kernel_type = type_and_fn.first;
    const bool is_onednn_cpu =
        platform::is_cpu_place(kernel_type.place_) &&
        kernel_type.library_type_ == framework::LibraryType::kMKLDNN;
    if (!is_onednn_cpu) continue;
    // No inputs -> nothing to type-check; the kernel match suffices.
    if (op->inputs.empty()) return true;
    const auto* first_in = op->inputs[0];
    if (first_in->IsVar() && first_in->Var()->Name() != "feed" &&
        kernel_type.data_type_ == first_in->Var()->GetDataType())
      return true;
  }
  return false;
}
// Returns true when a phi kernel targeting the ONEDNN backend is registered
// for this op AND its dtype matches the op's first input variable. Ops
// without inputs are accepted on the backend match alone; the pseudo
// variable "feed" is never considered a valid type source.
inline bool FoundPhiOneDNNKernelWithCorrectDataType(
    const framework::ir::Node* op) {
  const auto kernel_map = phi::KernelFactory::Instance().SelectKernelMap(
      phi::TransToPhiKernelName(op->Op()->Type()));
  for (const auto& key_and_kernel : kernel_map) {
    if (key_and_kernel.first.backend() != phi::Backend::ONEDNN) continue;
    // No inputs -> nothing to type-check; the backend match suffices.
    if (op->inputs.empty()) return true;
    const auto* first_in = op->inputs[0];
    if (first_in->IsVar() && first_in->Var()->Name() != "feed" &&
        key_and_kernel.first.dtype() ==
            framework::TransToPhiDataType(first_in->Var()->GetDataType()))
      return true;
  }
  return false;
}
// Decides whether `op` may be placed on oneDNN: true when either a fluid
// or a phi oneDNN kernel with a matching data type is registered for it.
bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
  // Return the boolean expression directly instead of the redundant
  // `if (cond) return true; return false;` form.
  return FoundOneDNNKernelWithCorrectDataType(op) ||
         FoundPhiOneDNNKernelWithCorrectDataType(op);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass) REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
.RequirePassAttr("mkldnn_enabled_op_types"); .RequirePassAttr("mkldnn_enabled_op_types");
...@@ -27,6 +27,9 @@ namespace ir { ...@@ -27,6 +27,9 @@ namespace ir {
* Specifies which operators should use MKLDNN. * Specifies which operators should use MKLDNN.
*/ */
class MKLDNNPlacementPass : public PlacementPassBase { class MKLDNNPlacementPass : public PlacementPassBase {
protected:
bool IsSupport(const Node* op) const override;
private: private:
const std::string GetPlacementName() const override { return "MKLDNN"; } const std::string GetPlacementName() const override { return "MKLDNN"; }
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/utils/tribool.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/utils/tribool.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -80,6 +80,7 @@ class PlacementPassTest { ...@@ -80,6 +80,7 @@ class PlacementPassTest {
"l"})) { "l"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
var->SetDataType(framework::proto::VarType::FP32);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
var->SetPersistable(true); var->SetPersistable(true);
} }
...@@ -129,7 +130,7 @@ class PlacementPassTest { ...@@ -129,7 +130,7 @@ class PlacementPassTest {
void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types, void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
unsigned expected_use_mkldnn_true_count) { unsigned expected_use_mkldnn_true_count) {
auto prog = BuildProgramDesc(); auto prog = BuildProgramDesc();
RegisterOpKernel({"conv2d", "pool2d", "concat", "relu"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass"); auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
...@@ -162,8 +163,8 @@ class PlacementPassTest { ...@@ -162,8 +163,8 @@ class PlacementPassTest {
}; };
TEST(MKLDNNPlacementPass, enable_conv_relu) { TEST(MKLDNNPlacementPass, enable_conv_relu) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
PlacementPassTest().MainTest({"conv2d", "relu"}, 3); PlacementPassTest().MainTest({"conv2d", "relu"}, 4);
} }
TEST(MKLDNNPlacementPass, enable_relu_pool) { TEST(MKLDNNPlacementPass, enable_relu_pool) {
...@@ -172,8 +173,9 @@ TEST(MKLDNNPlacementPass, enable_relu_pool) { ...@@ -172,8 +173,9 @@ TEST(MKLDNNPlacementPass, enable_relu_pool) {
} }
TEST(MKLDNNPlacementPass, enable_all) { TEST(MKLDNNPlacementPass, enable_all) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
PlacementPassTest().MainTest({}, 4); // 1 concat
PlacementPassTest().MainTest({}, 6);
} }
TEST(MKLDNNPlacementPass, placement_name) { TEST(MKLDNNPlacementPass, placement_name) {
......
...@@ -933,6 +933,21 @@ static int GetNumOpNodes(const std::unique_ptr<Graph>& graph, ...@@ -933,6 +933,21 @@ static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
return num_nodes; return num_nodes;
} }
static void RegisterOpKernel(std::vector<std::string>&& op_types) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
platform::CPUPlace place = platform::CPUPlace();
OpKernelType mkldnn_kernel_type = OpKernelType(proto::VarType::FP32,
place,
DataLayout::kAnyLayout,
LibraryType::kMKLDNN);
auto fake_kernel_func = [](const ExecutionContext&) -> void {};
for (auto& op_name : op_types)
all_kernels[op_name][mkldnn_kernel_type] = fake_kernel_func;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -32,13 +32,10 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const { ...@@ -32,13 +32,10 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) && if (IsSupport(n)) {
IsSupport(op->Type())) { if (op_types_list.empty() ||
if (op_types_list.empty() && IsDefaultOpTypes(op->Type())) { std::find(op_types_list.begin(), op_types_list.end(), n->Name()) !=
op->SetAttr(attr_name, true); op_types_list.end()) {
} else if (std::find(op_types_list.begin(),
op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr(attr_name, true); op->SetAttr(attr_name, true);
} }
} }
...@@ -46,53 +43,6 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const { ...@@ -46,53 +43,6 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
} }
} }
// Legacy centralized support check keyed on the placement attribute name.
// NOTE(review): this commit removes it in favor of the virtual
// IsSupport(const Node*) overridden by CUDNNPlacementPass /
// MKLDNNPlacementPass.
bool PlacementPassBase::IsSupport(const std::string& op_type) const {
  if (GetAttrName() == "use_cudnn") {
    auto& all_kernels = OperatorWithKernel::AllOpKernels();
    auto it = all_kernels.find(op_type);
    if (it == all_kernels.end()) {
      // All control operators don't have kernel.
      return false;
    }
    // Supported iff any registered kernel is a GPU kernel built with cuDNN.
    for (auto& kernel_pair : it->second) {
      if (platform::is_gpu_place(kernel_pair.first.place_) &&
          (kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
        return true;
      }
    }
  } else if (GetAttrName() == "use_mkldnn") {
    // This ops have use_mkldnn attr, but not support for now.
    const std::vector<std::string> op_types = {
        "trilinear_interp", "bicubic_interp", "linear_interp"};
    return std::find(op_types.begin(), op_types.end(), op_type) ==
           op_types.end();
  }
  // Unknown attribute name (neither cuDNN nor oneDNN placement).
  return false;
}
// Legacy filter applied when no explicit op-type list was configured.
// NOTE(review): removed by this commit together with IsSupport(string);
// the new kernel-registry-based IsSupport(Node*) subsumes it.
bool PlacementPassBase::IsDefaultOpTypes(const std::string& op_type) const {
  if (GetAttrName() == "use_cudnn") {
    // For cuDNN every supported op is placed by default.
    return true;
  } else if (GetAttrName() == "use_mkldnn") {
    // For interpolate ops, there's a little difference between Paddle and
    // MKLDNN.
    // If run MKLDNN interpolate ops, manual set AnalysisConfig and apply
    // the corresponding pass.
    const std::vector<std::string> not_default_op_types = {"bilinear_interp",
                                                           "nearest_interp",
                                                           "trilinear_interp",
                                                           "bicubic_interp",
                                                           "linear_interp",
                                                           "bilinear_interp_v2",
                                                           "linear_interp_v2"};
    bool is_interpolate_op = std::find(not_default_op_types.begin(),
                                       not_default_op_types.end(),
                                       op_type) != not_default_op_types.end();
    // Everything except the interpolate family is placed by default.
    return !is_interpolate_op;
  }
  // Unknown attribute name (neither cuDNN nor oneDNN placement).
  return false;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,10 +35,7 @@ class PlacementPassBase : public Pass { ...@@ -35,10 +35,7 @@ class PlacementPassBase : public Pass {
virtual const std::string GetPlacementName() const = 0; virtual const std::string GetPlacementName() const = 0;
virtual const std::string GetAttrName() const = 0; virtual const std::string GetAttrName() const = 0;
virtual const std::unordered_set<std::string> GetOpTypesList() const = 0; virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;
virtual bool IsSupport(const Node* op) const = 0;
private:
bool IsSupport(const std::string& op_type) const;
bool IsDefaultOpTypes(const std::string& op_type) const;
#if PADDLE_WITH_TESTING #if PADDLE_WITH_TESTING
friend class PlacementPassTest; friend class PlacementPassTest;
......
...@@ -115,17 +115,15 @@ endif() ...@@ -115,17 +115,15 @@ endif()
if(WITH_TESTING) if(WITH_TESTING)
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
if(WITH_GPU) inference_base_test(
inference_base_test( test_api_impl
test_api_impl SRCS
SRCS api_impl_tester.cc
api_impl_tester.cc DEPS
DEPS paddle_inference_shared
paddle_inference_shared ARGS
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR}
--word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
--book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
endif()
elseif(WIN32) elseif(WIN32)
inference_base_test( inference_base_test(
test_api_impl test_api_impl
...@@ -137,7 +135,6 @@ if(WITH_TESTING) ...@@ -137,7 +135,6 @@ if(WITH_TESTING)
--word2vec_dirname=${WORD2VEC_MODEL_DIR} --word2vec_dirname=${WORD2VEC_MODEL_DIR}
--book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
endif() endif()
endif() endif()
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
......
...@@ -331,6 +331,18 @@ TEST(inference_api_native, image_classification_gpu) { ...@@ -331,6 +331,18 @@ TEST(inference_api_native, image_classification_gpu) {
// } // }
#endif #endif
// oneDNN-enabled variants of the native CPU inference tests; compiled only
// when Paddle is built with MKLDNN/oneDNN support.
#ifdef PADDLE_WITH_MKLDNN
// Runs the image-classification model on CPU with the oneDNN flag enabled.
TEST(inference_api_native, image_classification_cpu_onednn) {
  FLAGS_use_mkldnn = true;
  MainImageClassification(paddle::PaddlePlace::kCPU);
}
// Runs the word2vec model on CPU with the oneDNN flag enabled.
TEST(inference_api_native, word2vec_cpu_onednn) {
  FLAGS_use_mkldnn = true;
  MainWord2Vec(paddle::PaddlePlace::kCPU);
}
#endif
TEST(PassBuilder, Delete) { TEST(PassBuilder, Delete) {
AnalysisConfig config; AnalysisConfig config;
config.DisableGpu(); config.DisableGpu();
......
...@@ -140,4 +140,32 @@ inline std::string FindOutputNameByVarName(framework::OpDesc* op, ...@@ -140,4 +140,32 @@ inline std::string FindOutputNameByVarName(framework::OpDesc* op,
if (output_name == searched_name) ret = name; if (output_name == searched_name) ret = name;
return ret; return ret;
} }
// True when a fluid oneDNN (MKLDNN) CPU kernel is registered for this op,
// regardless of the kernel's data type.
inline bool FoundOneDNNKernel(const framework::OpDesc* op) {
  const auto& registry = framework::OperatorWithKernel::AllOpKernels();
  const auto found = registry.find(op->Type());
  if (found == registry.end()) return false;
  for (const auto& type_and_fn : found->second) {
    const auto& kernel_type = type_and_fn.first;
    if (platform::is_cpu_place(kernel_type.place_) &&
        kernel_type.library_type_ == framework::LibraryType::kMKLDNN)
      return true;
  }
  return false;
}
// True when at least one phi kernel registered for this op targets the
// ONEDNN backend.
inline bool FoundPhiOneDNNKernel(const framework::OpDesc* op) {
  const auto kernel_map = phi::KernelFactory::Instance().SelectKernelMap(
      phi::TransToPhiKernelName(op->Type()));
  for (const auto& key_and_kernel : kernel_map) {
    if (key_and_kernel.first.backend() == phi::Backend::ONEDNN) return true;
  }
  return false;
}
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册