Unverified commit ebdf3ef9, authored by: J joanna.wozna.intel, committed by: GitHub

Adjust mkldnn_placement_pass to check library type and data type (#49899)

* Adjust mkldnn_placement_pass to check library type and data type

* Check if var has inputs

* Remove unrelated test

* Refactor
Parent: 9dd1f4bf
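In short, the pass now consults the kernel registry instead of trusting the bare use_mkldnn attribute: an op is placed on oneDNN only when a oneDNN (MKLDNN) kernel is actually registered for it and that kernel's data type matches the op's first input. A minimal standalone sketch of that check, using made-up registry types rather than Paddle's real OpKernelType/KernelFactory:

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-ins for a kernel registry keyed by library type and
// data type; none of these names are Paddle's, they only mirror the shape
// of the check introduced in this commit.
enum class Library { kPlain, kOneDNN };
enum class DataType { kFP32, kINT8 };

struct KernelKey {
  Library library;
  DataType dtype;
};

// op type -> registered kernel keys
using KernelRegistry = std::map<std::string, std::vector<KernelKey>>;

// True only if op_type has a oneDNN kernel whose data type matches the
// op's first input -- the gist of the new IsSupport logic.
bool HasOneDNNKernelWithDtype(const KernelRegistry& registry,
                              const std::string& op_type,
                              DataType input_dtype) {
  auto it = registry.find(op_type);
  if (it == registry.end()) return false;  // e.g. control-flow ops
  for (const auto& key : it->second)
    if (key.library == Library::kOneDNN && key.dtype == input_dtype)
      return true;
  return false;
}

int main() {
  KernelRegistry registry;
  registry["conv2d"] = {{Library::kPlain, DataType::kFP32},
                        {Library::kOneDNN, DataType::kFP32}};
  registry["pool2d"] = {{Library::kPlain, DataType::kFP32}};

  assert(HasOneDNNKernelWithDtype(registry, "conv2d", DataType::kFP32));
  assert(!HasOneDNNKernelWithDtype(registry, "pool2d", DataType::kFP32));
  assert(!HasOneDNNKernelWithDtype(registry, "conv2d", DataType::kINT8));
  return 0;
}

The real implementation in the diff below does the same walk over OperatorWithKernel::AllOpKernels() and the PHI kernel factory, using Paddle's own kernel-key types.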
......@@ -671,11 +671,10 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
   for (size_t bid = 0; bid < program.Size(); ++bid) {
     auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
     for (auto* op : block->AllOps()) {
-      if (op->HasAttr("use_mkldnn")) {
+      if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op))
         op->SetAttr("use_mkldnn", true);
-      }
     }
   }
 #else
   LOG(WARNING)
       << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
......
......@@ -13,6 +13,36 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
bool CUDNNPlacementPass::IsSupport(const Node* op) const {
std::string attr_name = GetAttrName();
if (!(op->Op()->HasAttr(attr_name) || op->Op()->HasProtoAttr(attr_name)))
return false;
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op->Op()->Type());
if (it == all_kernels.end()) {
// All control operators don't have kernel.
return false;
}
for (auto& kernel_pair : it->second) {
if (platform::is_gpu_place(kernel_pair.first.place_) &&
(kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
return true;
}
}
return false;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass)
.RequirePassAttr("cudnn_enabled_op_types");
......@@ -27,6 +27,9 @@ namespace ir {
* Specifies which operators should use cuDNN.
*/
class CUDNNPlacementPass : public PlacementPassBase {
 protected:
  bool IsSupport(const Node* op) const override;

 private:
  const std::string GetPlacementName() const override { return "cuDNN"; }
......
......@@ -40,6 +40,7 @@ void TestFcRNNFusePass(const std::string& pass_name,
       "__param_scope__",
       (pass_name == "fc_gru_fuse_pass" ? fc_gru_test::CreateParamScope()
                                        : fc_lstm_test::CreateParamScope()));
+  RegisterOpKernel({"mul", "elementwise_add"});
   graph.reset(mkldnn_placement_pass_->Apply(graph.release()));

   auto check_num_mkldnn_nodes = [&](const std::unique_ptr<ir::Graph>& graph) {
......
......@@ -13,6 +13,69 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
inline bool FoundOneDNNKernelWithCorrectDataType(
const framework::ir::Node* op) {
const auto op_type = op->Op()->Type();
auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kernel_pair : it->second) {
if (platform::is_cpu_place(kernel_pair.first.place_) &&
(kernel_pair.first.library_type_ ==
framework::LibraryType::kMKLDNN)) {
if (op->inputs.size() > 0) {
if (op->inputs[0]->IsVar() &&
op->inputs[0]->Var()->Name() != "feed" &&
kernel_pair.first.data_type_ ==
op->inputs[0]->Var()->GetDataType())
return true;
} else {
return true;
}
}
}
}
return false;
}
inline bool FoundPhiOneDNNKernelWithCorrectDataType(
const framework::ir::Node* op) {
auto op_type = op->Op()->Type();
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(op_type));
for (auto& kernel_pair : phi_kernels) {
if (kernel_pair.first.backend() == phi::Backend::ONEDNN) {
if (op->inputs.size() > 0) {
if (op->inputs[0]->IsVar() && op->inputs[0]->Var()->Name() != "feed" &&
kernel_pair.first.dtype() ==
framework::TransToPhiDataType(
op->inputs[0]->Var()->GetDataType()))
return true;
} else {
return true;
}
}
}
return false;
}
bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
if (FoundOneDNNKernelWithCorrectDataType(op) ||
FoundPhiOneDNNKernelWithCorrectDataType(op))
return true;
return false;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
.RequirePassAttr("mkldnn_enabled_op_types");
......@@ -27,6 +27,9 @@ namespace ir {
* Specifies which operators should use MKLDNN.
*/
class MKLDNNPlacementPass : public PlacementPassBase {
 protected:
  bool IsSupport(const Node* op) const override;

 private:
  const std::string GetPlacementName() const override { return "MKLDNN"; }
......
......@@ -14,9 +14,9 @@
 #include <gtest/gtest.h>
-#include "paddle/utils/tribool.h"

 #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/utils/tribool.h"
namespace paddle {
namespace framework {
......@@ -80,6 +80,7 @@ class PlacementPassTest {
"l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
var->SetDataType(framework::proto::VarType::FP32);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
......@@ -129,7 +130,7 @@ class PlacementPassTest {
   void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
                 unsigned expected_use_mkldnn_true_count) {
     auto prog = BuildProgramDesc();
+    RegisterOpKernel({"conv2d", "pool2d", "concat", "relu"});
     std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

     auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
......@@ -162,8 +163,8 @@ class PlacementPassTest {
};
 TEST(MKLDNNPlacementPass, enable_conv_relu) {
-  // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
-  PlacementPassTest().MainTest({"conv2d", "relu"}, 3);
+  // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
+  PlacementPassTest().MainTest({"conv2d", "relu"}, 4);
 }
TEST(MKLDNNPlacementPass, enable_relu_pool) {
......@@ -172,8 +173,9 @@ TEST(MKLDNNPlacementPass, enable_relu_pool) {
}
 TEST(MKLDNNPlacementPass, enable_all) {
-  // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
-  PlacementPassTest().MainTest({}, 4);
+  // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
+  // 1 concat
+  PlacementPassTest().MainTest({}, 6);
 }
TEST(MKLDNNPlacementPass, placement_name) {
......
......@@ -933,6 +933,21 @@ static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
  return num_nodes;
}

static void RegisterOpKernel(std::vector<std::string>&& op_types) {
  auto& all_kernels = OperatorWithKernel::AllOpKernels();

  platform::CPUPlace place = platform::CPUPlace();
  OpKernelType mkldnn_kernel_type = OpKernelType(proto::VarType::FP32,
                                                 place,
                                                 DataLayout::kAnyLayout,
                                                 LibraryType::kMKLDNN);

  // A no-op kernel body is enough: the placement pass only inspects the
  // registry keys (place, data type, library type), never runs the kernel.
  auto fake_kernel_func = [](const ExecutionContext&) -> void {};

  for (auto& op_name : op_types)
    all_kernels[op_name][mkldnn_kernel_type] = fake_kernel_func;
}
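Because real MKLDNN kernels are typically not registered in these lightweight test binaries, tests that exercise the placement pass call this helper first (e.g. RegisterOpKernel({"conv2d", "pool2d", "concat", "relu"}) in MainTest above) so that IsSupport has a matching FP32 MKLDNN kernel to find.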
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -32,65 +32,15 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       auto* op = n->Op();
-      if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) &&
-          IsSupport(op->Type())) {
-        if (op_types_list.empty() && IsDefaultOpTypes(op->Type())) {
+      if (IsSupport(n)) {
+        if (op_types_list.empty() ||
+            std::find(op_types_list.begin(), op_types_list.end(), n->Name()) !=
+                op_types_list.end()) {
           op->SetAttr(attr_name, true);
-        } else if (std::find(op_types_list.begin(),
-                             op_types_list.end(),
-                             n->Name()) != op_types_list.end()) {
-          op->SetAttr(attr_name, true);
         }
       }
     }
   }
 }

-bool PlacementPassBase::IsSupport(const std::string& op_type) const {
-  if (GetAttrName() == "use_cudnn") {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
-    auto it = all_kernels.find(op_type);
-    if (it == all_kernels.end()) {
-      // All control operators don't have kernel.
-      return false;
-    }
-    for (auto& kernel_pair : it->second) {
-      if (platform::is_gpu_place(kernel_pair.first.place_) &&
-          (kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
-        return true;
-      }
-    }
-  } else if (GetAttrName() == "use_mkldnn") {
-    // This ops have use_mkldnn attr, but not support for now.
-    const std::vector<std::string> op_types = {
-        "trilinear_interp", "bicubic_interp", "linear_interp"};
-    return std::find(op_types.begin(), op_types.end(), op_type) ==
-           op_types.end();
-  }
-  return false;
-}

-bool PlacementPassBase::IsDefaultOpTypes(const std::string& op_type) const {
-  if (GetAttrName() == "use_cudnn") {
-    return true;
-  } else if (GetAttrName() == "use_mkldnn") {
-    // For interpolate ops, there's a little difference between Paddle and
-    // MKLDNN.
-    // If run MKLDNN interpolate ops, manual set AnalysisConfig and apply
-    // the corresponding pass.
-    const std::vector<std::string> not_default_op_types = {"bilinear_interp",
-                                                           "nearest_interp",
-                                                           "trilinear_interp",
-                                                           "bicubic_interp",
-                                                           "linear_interp",
-                                                           "bilinear_interp_v2",
-                                                           "linear_interp_v2"};
-    bool is_interpolate_op = std::find(not_default_op_types.begin(),
-                                       not_default_op_types.end(),
-                                       op_type) != not_default_op_types.end();
-    return !is_interpolate_op;
-  }
-  return false;
-}
} // namespace ir
......
......@@ -35,10 +35,7 @@ class PlacementPassBase : public Pass {
   virtual const std::string GetPlacementName() const = 0;
   virtual const std::string GetAttrName() const = 0;
   virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;

- private:
-  bool IsSupport(const std::string& op_type) const;
-  bool IsDefaultOpTypes(const std::string& op_type) const;
+  virtual bool IsSupport(const Node* op) const = 0;

 #if PADDLE_WITH_TESTING
   friend class PlacementPassTest;
......
......@@ -115,7 +115,6 @@ endif()
 if(WITH_TESTING)
   if(NOT APPLE AND NOT WIN32)
-    if(WITH_GPU)
     inference_base_test(
       test_api_impl
       SRCS
......@@ -125,7 +124,6 @@ if(WITH_TESTING)
       ARGS
       --word2vec_dirname=${WORD2VEC_MODEL_DIR}
       --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
-    endif()
   elseif(WIN32)
     inference_base_test(
       test_api_impl
......@@ -137,7 +135,6 @@ if(WITH_TESTING)
       --word2vec_dirname=${WORD2VEC_MODEL_DIR}
       --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
   endif()
 endif()

 if(NOT APPLE AND NOT WIN32)
......
......@@ -331,6 +331,18 @@ TEST(inference_api_native, image_classification_gpu) {
// }
#endif
#ifdef PADDLE_WITH_MKLDNN
TEST(inference_api_native, image_classification_cpu_onednn) {
  FLAGS_use_mkldnn = true;
  MainImageClassification(paddle::PaddlePlace::kCPU);
}

TEST(inference_api_native, word2vec_cpu_onednn) {
  FLAGS_use_mkldnn = true;
  MainWord2Vec(paddle::PaddlePlace::kCPU);
}
#endif

TEST(PassBuilder, Delete) {
  AnalysisConfig config;
  config.DisableGpu();
......
......@@ -140,4 +140,32 @@ inline std::string FindOutputNameByVarName(framework::OpDesc* op,
      if (output_name == searched_name) ret = name;
  return ret;
}

inline bool FoundOneDNNKernel(const framework::OpDesc* op) {
  auto op_type = op->Type();
  auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it != all_kernels.end()) {
    for (auto& kernel_pair : it->second) {
      if (platform::is_cpu_place(kernel_pair.first.place_) &&
          (kernel_pair.first.library_type_ ==
           framework::LibraryType::kMKLDNN)) {
        return true;
      }
    }
  }
  return false;
}

inline bool FoundPhiOneDNNKernel(const framework::OpDesc* op) {
  auto op_type = op->Type();
  auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
      phi::TransToPhiKernelName(op_type));

  for (auto& kernel_pair : phi_kernels)
    if (kernel_pair.first.backend() == phi::Backend::ONEDNN) return true;

  return false;
}

}  // namespace paddle