未验证 提交 eb411613 编写于 作者: J Jacek Czaja 提交者: GitHub

[DNNL] refine activations Inplace support (#24145)

上级 9ec9fc0f
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <boost/logic/tribool.hpp> #include <boost/logic/tribool.hpp>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -52,13 +53,9 @@ class MKLDNNInplacePassTest { ...@@ -52,13 +53,9 @@ class MKLDNNInplacePassTest {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
} else if (type == "gelu") { } else if (std::unordered_set<std::string>{"gelu", "leaky_relu", "relu",
op->SetInput("X", inputs); "tanh"}
} else if (type == "leaky_relu") { .count(type)) {
op->SetInput("X", inputs);
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "tanh") {
op->SetInput("X", inputs); op->SetInput("X", inputs);
} else if (type == "softmax") { } else if (type == "softmax") {
op->SetAttr("axis", -1); op->SetAttr("axis", -1);
...@@ -100,11 +97,11 @@ class MKLDNNInplacePassTest { ...@@ -100,11 +97,11 @@ class MKLDNNInplacePassTest {
mkldnn_enabled_op.compare("elementwise_add") == 0); mkldnn_enabled_op.compare("elementwise_add") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}), SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
std::vector<std::string>({"k"}), std::vector<std::string>({"k"}),
mkldnn_enabled_op.compare("softmax") == 0); mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}), SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}), std::vector<std::string>({"l"}),
mkldnn_enabled_op.compare("tanh") == 0); mkldnn_enabled_op.compare("tanh") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"l"}), SetOp(&prog, "relu", "relu3", std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}), std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0); mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}), SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}),
...@@ -112,7 +109,7 @@ class MKLDNNInplacePassTest { ...@@ -112,7 +109,7 @@ class MKLDNNInplacePassTest {
mkldnn_enabled_op.compare("leaky_relu") == 0); mkldnn_enabled_op.compare("leaky_relu") == 0);
SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}), SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}),
std::vector<std::string>({"m"}), std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0); mkldnn_enabled_op.compare("gelu") == 0);
if (branched == true) { if (branched == true) {
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}), SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
std::vector<std::string>({"z"}), std::vector<std::string>({"z"}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册