未验证 提交 eb411613 编写于 作者: J Jacek Czaja 提交者: GitHub

[DNNL] refine activations Inplace support (#24145)

上级 9ec9fc0f
......@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -52,13 +53,9 @@ class MKLDNNInplacePassTest {
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "gelu") {
op->SetInput("X", inputs);
} else if (type == "leaky_relu") {
op->SetInput("X", inputs);
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "tanh") {
} else if (std::unordered_set<std::string>{"gelu", "leaky_relu", "relu",
"tanh"}
.count(type)) {
op->SetInput("X", inputs);
} else if (type == "softmax") {
op->SetAttr("axis", -1);
......@@ -100,11 +97,11 @@ class MKLDNNInplacePassTest {
mkldnn_enabled_op.compare("elementwise_add") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
std::vector<std::string>({"k"}),
mkldnn_enabled_op.compare("softmax") == 0);
mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}),
mkldnn_enabled_op.compare("tanh") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"l"}),
SetOp(&prog, "relu", "relu3", std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0);
SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}),
......@@ -112,7 +109,7 @@ class MKLDNNInplacePassTest {
mkldnn_enabled_op.compare("leaky_relu") == 0);
SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}),
std::vector<std::string>({"m"}),
mkldnn_enabled_op.compare("relu") == 0);
mkldnn_enabled_op.compare("gelu") == 0);
if (branched == true) {
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
std::vector<std::string>({"z"}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册