未验证 提交 2bb1b0e8 编写于 作者: J Jacek Czaja 提交者: GitHub

[DNNL] Added MKL-DNN inplace pass for C-API inference (#23315)

上级 487f43bb
...@@ -86,6 +86,7 @@ endif() ...@@ -86,6 +86,7 @@ endif()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn) pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry softmax_op softmax DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn) pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
...@@ -145,6 +146,7 @@ if (WITH_MKLDNN) ...@@ -145,6 +146,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc DEPS mkldnn_inplace_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
......
...@@ -1834,6 +1834,35 @@ PDNode *patterns::MultipleQuantize::operator()() { ...@@ -1834,6 +1834,35 @@ PDNode *patterns::MultipleQuantize::operator()() {
return prev_out; return prev_out;
} }
PDNode *patterns::MKLDNNInPlace::operator()() {
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// batch_norm....
auto possible_inplace_op =
pattern->NewNode(inplace_to_be_op_repr())->assert_is_ops({"softmax"});
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// batch_norm....
auto input = pattern->NewNode(inplace_to_be_op_in_repr())
->assert_is_ops_input({"softmax"})
->AsInput();
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// batch_norm....
auto output = pattern->NewNode(inplace_to_be_op_out_repr())
->assert_is_ops_output({"softmax"})
->AsIntermediate();
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
// Check if op is MKL-DNN enabled
possible_inplace_op->assert_op_attr("use_mkldnn", true);
possible_inplace_op->LinksTo({output});
possible_inplace_op->LinksFrom({input});
next_op->LinksFrom({output});
return possible_inplace_op;
}
// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a // a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a
// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b // b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b
// ... // ...
......
...@@ -1092,6 +1092,20 @@ struct MultipleQuantize : public PatternBase { ...@@ -1092,6 +1092,20 @@ struct MultipleQuantize : public PatternBase {
PATTERN_DECL_NODE(prev_out); PATTERN_DECL_NODE(prev_out);
}; };
// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
MKLDNNInPlace(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "mkldnn_inplace") {}
PDNode* operator()();
// MKL-DNN's in-place ops: BatchNorm, Softmax, Layer Norm
PATTERN_DECL_NODE(inplace_to_be_op);
PATTERN_DECL_NODE(inplace_to_be_op_in);
PATTERN_DECL_NODE(inplace_to_be_op_out);
PATTERN_DECL_NODE(next_op);
};
struct TransposeFlattenConcat : public PatternBase { struct TransposeFlattenConcat : public PatternBase {
TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope) TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose_flatten_concat") {} : PatternBase(pattern, name_scope, "transpose_flatten_concat") {}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
GraphPatternDetector gpd;
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
"mkldnn_inplace"};
mkldnn_inplace();
int found_inplace_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "Start to handle MKL-DNN In-Place pass";
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op, inplace_to_be_op,
mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_in, inplace_to_be_op_in,
mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_out, inplace_to_be_op_out,
mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace);
if ((inplace_to_be_op->Op()->HasAttr("use_mkldnn") == false) ||
(boost::get<bool>(inplace_to_be_op->Op()->GetAttr("use_mkldnn")) ==
false)) {
VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to "
"false";
return;
}
auto& infer_inplace = OpInfoMap::Instance()
.Get(inplace_to_be_op->Op()->Type())
.infer_inplace_;
if (!infer_inplace) {
VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer";
return;
}
// TODO(jczaja): Enable more ops
if (inplace_to_be_op->Op()->Type() != "softmax") {
VLOG(3)
<< "Curently works for softmax only. TODO(jczaja): support other ops";
return;
}
// Iterate over all nodes that are ops
// and check if in-place to be var is part of inputs
// if positive then do not perform inplace
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
// Avoid searchin in op that is to be inplace
if ((n->id() != inplace_to_be_op->id())) {
auto* op = n->Op();
auto inputs = op->Inputs();
auto in_place_input = inplace_to_be_op_in->Name();
for (auto& it : inputs) {
for (auto& var_name : it.second) {
if (var_name == in_place_input) {
VLOG(3) << "MKL-DNN in-place pass: in-place var cannot be an "
"input to more than one operator";
return;
}
}
}
}
}
}
auto original_name = inplace_to_be_op_out->Name();
inplace_to_be_op_out->RenameVar(inplace_to_be_op_in->Name());
// Get mapping of input to output
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
// TODO(jczaja): Support more complex situations
auto out_name = in_to_outs.begin()->second;
inplace_to_be_op->Op()->SetOutput(
out_name, std::vector<std::string>({inplace_to_be_op_out->Name()}));
next_op->Op()->RenameInput(original_name, inplace_to_be_op_out->Name());
found_inplace_count++;
VLOG(3) << "MKL-DNN InPlace applied!";
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_inplace_pass, paddle::framework::ir::MKLDNNInPlacePass);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Transpose weights of FC to comply with MKL-DNN interface
*/
class MKLDNNInPlacePass : public Pass {
public:
virtual ~MKLDNNInPlacePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const;
private:
#if PADDLE_WITH_TESTING
friend class MKLDNNInPlacePassTest;
#endif
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
namespace paddle {
namespace framework {
namespace ir {
class MKLDNNInplacePassTest {
private:
void SetOp(ProgramDesc* prog, const std::string& type,
const std::string& name, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
boost::tribool use_mkldnn) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (!boost::indeterminate(use_mkldnn))
op->SetAttr("use_mkldnn", use_mkldnn);
if (type == "conv2d") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "softmax") {
op->SetAttr("axis", -1);
op->SetInput("X", inputs);
} else if (type == "elementwise_add") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
} else {
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
ProgramDesc BuildProgramDesc(const std::string& mkldnn_enabled_op,
bool branched) {
ProgramDesc prog;
for (auto& v : std::vector<std::string>(
{"a", "weights", "bias", "f", "g", "h", "i", "j", "k"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"a", "weights", "bias"}),
std::vector<std::string>({"f"}), boost::indeterminate);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}),
mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "softmax", "softmax1", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}),
mkldnn_enabled_op.compare("softmax") == 0);
SetOp(&prog, "elementwise_add", "elementwise_add1",
std::vector<std::string>({"h", "i"}), std::vector<std::string>({"j"}),
mkldnn_enabled_op.compare("elementwise_add") == 0);
if (branched == true) {
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
std::vector<std::string>({"k"}),
mkldnn_enabled_op.compare("softmax") == 0);
}
return prog;
}
public:
void MainTest(const std::string& mkldnn_enabled_op, bool branched,
unsigned expected_use_mkldnn_true_count) {
auto prog = BuildProgramDesc(mkldnn_enabled_op, branched);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("mkldnn_inplace_pass");
graph.reset(pass->Apply(graph.release()));
unsigned use_mkldnn_true_count = 0;
std::unordered_map<std::string, std::string> input_names;
std::unordered_map<std::string, std::string> output_names;
input_names["softmax"] = "X";
output_names["softmax"] = "Out";
input_names["batch_norm"] = "X";
output_names["batch_norm"] = "Y";
input_names["layer_norm"] = "X";
output_names["layer_norm"] = "Y";
VLOG(3) << DebugString(graph);
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == mkldnn_enabled_op) {
auto ins = op->Inputs();
auto outs = op->Outputs();
// Input and output are the same var
if (ins[input_names[mkldnn_enabled_op]] ==
outs[output_names[mkldnn_enabled_op]]) {
++use_mkldnn_true_count;
}
}
}
}
EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
}
};
TEST(MKLDNNInplacePass, inplace_softmax) {
// softmax to be mkl-dnn enabled and made in-place
MKLDNNInplacePassTest().MainTest("softmax", false, 1);
}
TEST(MKLDNNInplacePass, inplace_softmax_branched) {
// softmax to be mkl-dnn enabled and made in-place
MKLDNNInplacePassTest().MainTest("softmax", true, 0);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(mkldnn_inplace_pass);
...@@ -200,7 +200,9 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -200,7 +200,9 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_relu6_mkldnn_fuse_pass", // "conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", // "conv_swish_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up // Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass" // "fc_mkldnn_pass",
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
})) { })) {
passes_.push_back(pass); passes_.push_back(pass);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册