Unverified commit 461e6a01 authored by J Jacek Czaja, committed by GitHub

[DNNL] activations Inplace support (#24123)

Parent commit: c2bc92de
...@@ -86,7 +86,7 @@ endif() ...@@ -86,7 +86,7 @@ endif()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn) pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op activation_op softmax_op softmax DIR mkldnn) pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op gelu_op activation_op softmax_op softmax DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn) pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
......
...@@ -1892,17 +1892,18 @@ PDNode *patterns::MultipleQuantize::operator()() { ...@@ -1892,17 +1892,18 @@ PDNode *patterns::MultipleQuantize::operator()() {
} }
PDNode *patterns::MKLDNNInPlace::operator()() { PDNode *patterns::MKLDNNInPlace::operator()() {
auto possible_inplace_op = const std::unordered_set<std::string> &supported_op_types = {
pattern->NewNode(inplace_to_be_op_repr()) "abs", "elementwise_add", "gelu", "leaky_relu", "relu", "softmax",
->assert_is_ops({"elementwise_add", "softmax"}); "sqrt", "swish", "tanh"};
auto possible_inplace_op = pattern->NewNode(inplace_to_be_op_repr())
->assert_is_ops(supported_op_types);
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
auto input = pattern->NewNode(inplace_to_be_op_in_repr()) auto input = pattern->NewNode(inplace_to_be_op_in_repr())
->assert_is_ops_input({"elementwise_add", "softmax"}) ->assert_is_ops_input(supported_op_types)
->AsInput(); ->AsInput();
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
auto output = pattern->NewNode(inplace_to_be_op_out_repr()) auto output = pattern->NewNode(inplace_to_be_op_out_repr())
->assert_is_ops_output({"elementwise_add", "softmax"}) ->assert_is_ops_output(supported_op_types)
->AsOutput(); ->AsOutput();
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op(); auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
......
...@@ -109,7 +109,6 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -109,7 +109,6 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
// It may be that next op is reusing some of vars, we need to // It may be that next op is reusing some of vars, we need to
// make sure that unwanted inplace is not created // make sure that unwanted inplace is not created
// TODO(jczaja): Make UT for that one
for (auto& n : current_op_out->outputs) { for (auto& n : current_op_out->outputs) {
auto& n_op_infer_inplace = auto& n_op_infer_inplace =
OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_; OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_;
......
...@@ -23,7 +23,12 @@ USE_OP(softmax); ...@@ -23,7 +23,12 @@ USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(elementwise_add); USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
USE_OP(gelu);
USE_OP(relu); USE_OP(relu);
USE_OP(tanh);
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -47,8 +52,14 @@ class MKLDNNInplacePassTest { ...@@ -47,8 +52,14 @@ class MKLDNNInplacePassTest {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
} else if (type == "gelu") {
op->SetInput("X", inputs);
} else if (type == "leaky_relu") {
op->SetInput("X", inputs);
} else if (type == "relu") { } else if (type == "relu") {
op->SetInput("X", inputs); op->SetInput("X", inputs);
} else if (type == "tanh") {
op->SetInput("X", inputs);
} else if (type == "softmax") { } else if (type == "softmax") {
op->SetAttr("axis", -1); op->SetAttr("axis", -1);
op->SetInput("X", inputs); op->SetInput("X", inputs);
...@@ -67,7 +78,7 @@ class MKLDNNInplacePassTest { ...@@ -67,7 +78,7 @@ class MKLDNNInplacePassTest {
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "weights", "bias", "f", "g", "h", "i", std::vector<std::string>({"a", "weights", "bias", "f", "g", "h", "i",
"j", "k", "l", "m", "z"})) { "j", "k", "l", "m", "n", "z"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
...@@ -90,6 +101,18 @@ class MKLDNNInplacePassTest { ...@@ -90,6 +101,18 @@ class MKLDNNInplacePassTest {
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}), SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
std::vector<std::string>({"k"}), std::vector<std::string>({"k"}),
mkldnn_enabled_op.compare("softmax") == 0); mkldnn_enabled_op.compare("softmax") == 0);
SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}),
mkldnn_enabled_op.compare("tanh") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}),
std::vector<std::string>({"n"}),
mkldnn_enabled_op.compare("leaky_relu") == 0);
SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}),
std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0);
if (branched == true) { if (branched == true) {
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}), SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
std::vector<std::string>({"z"}), std::vector<std::string>({"z"}),
...@@ -113,11 +136,6 @@ class MKLDNNInplacePassTest { ...@@ -113,11 +136,6 @@ class MKLDNNInplacePassTest {
std::unordered_map<std::string, std::string> input_names; std::unordered_map<std::string, std::string> input_names;
std::unordered_map<std::string, std::string> output_names; std::unordered_map<std::string, std::string> output_names;
input_names["softmax"] = "X";
output_names["softmax"] = "Out";
input_names["elementwise_add"] = "X";
output_names["elementwise_add"] = "Out";
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
...@@ -127,8 +145,9 @@ class MKLDNNInplacePassTest { ...@@ -127,8 +145,9 @@ class MKLDNNInplacePassTest {
auto ins = op->Inputs(); auto ins = op->Inputs();
auto outs = op->Outputs(); auto outs = op->Outputs();
// Input and output are the same var // Input and output are the same var
if (ins[input_names[mkldnn_enabled_op]] == // All inplace ops are inplacing input named: X
outs[output_names[mkldnn_enabled_op]]) { // and output : Out
if (ins["X"] == outs["Out"]) {
++use_mkldnn_true_count; ++use_mkldnn_true_count;
} }
} }
...@@ -153,6 +172,15 @@ TEST(MKLDNNInplacePass, inplace_elementwise_add) { ...@@ -153,6 +172,15 @@ TEST(MKLDNNInplacePass, inplace_elementwise_add) {
// Two elementwise_add mkl-dnn enabled op instances to be made inplace // Two elementwise_add mkl-dnn enabled op instances to be made inplace
MKLDNNInplacePassTest().MainTest("elementwise_add", false, 1); MKLDNNInplacePassTest().MainTest("elementwise_add", false, 1);
} }
TEST(MKLDNNInplacePass, inplace_tanh) {
MKLDNNInplacePassTest().MainTest("tanh", false, 1);
}
TEST(MKLDNNInplacePass, inplace_leaky_relu) {
  // Input of leaky_relu is used as input of the subsequent gelu, so inplace
  // cannot be done
MKLDNNInplacePassTest().MainTest("leaky_relu", false, 0);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -90,7 +90,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -90,7 +90,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
ctx.InputName("X")); ctx.InputName("X"));
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(y); auto dst_memory_p =
x->IsSharedBufferWith(*y) ? src_memory_p : handler.AcquireDstMemory(y);
auto activation_p = handler.AcquireForwardPrimitive(); auto activation_p = handler.AcquireForwardPrimitive();
mkldnn::stream astream(dev_ctx.GetEngine()); mkldnn::stream astream(dev_ctx.GetEngine());
......
cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry elementwise_add_op softmax_op softmax scope device_context enforce executor) cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry elementwise_add_op activation_op softmax_op softmax scope device_context enforce executor)
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(elementwise_add); USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -132,5 +134,11 @@ TEST(test_elementwise_add_inplace, cpu_place) { ...@@ -132,5 +134,11 @@ TEST(test_elementwise_add_inplace, cpu_place) {
ASSERT_TRUE(TestMain<float>(p, "elementwise_add", dims, 2)); ASSERT_TRUE(TestMain<float>(p, "elementwise_add", dims, 2));
} }
TEST(test_relu_inplace, cpu_place) {
framework::DDim dims({1, 12, 20, 20});
platform::CPUPlace p;
ASSERT_TRUE(TestMain<float>(p, "relu", dims, 1));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register