From eb411613e9afb210bc7de1d1fe76d3e00f1f0858 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Sat, 25 Apr 2020 06:54:31 +0200 Subject: [PATCH] [DNNL] refine activations Inplace support (#24145) --- .../ir/mkldnn/mkldnn_inplace_pass_tester.cc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index 88c4db8198f..01abe5a8d28 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -16,6 +16,7 @@ #include #include +#include #include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/op_registry.h" @@ -52,13 +53,9 @@ class MKLDNNInplacePassTest { op->SetInput("Input", {inputs[0]}); op->SetInput("Filter", {inputs[1]}); op->SetInput("Bias", {inputs[2]}); - } else if (type == "gelu") { - op->SetInput("X", inputs); - } else if (type == "leaky_relu") { - op->SetInput("X", inputs); - } else if (type == "relu") { - op->SetInput("X", inputs); - } else if (type == "tanh") { + } else if (std::unordered_set{"gelu", "leaky_relu", "relu", + "tanh"} + .count(type)) { op->SetInput("X", inputs); } else if (type == "softmax") { op->SetAttr("axis", -1); @@ -100,11 +97,11 @@ class MKLDNNInplacePassTest { mkldnn_enabled_op.compare("elementwise_add") == 0); SetOp(&prog, "relu", "relu2", std::vector({"j"}), std::vector({"k"}), - mkldnn_enabled_op.compare("softmax") == 0); + mkldnn_enabled_op.compare("relu") == 0); SetOp(&prog, "tanh", "tanh1", std::vector({"k"}), std::vector({"l"}), mkldnn_enabled_op.compare("tanh") == 0); - SetOp(&prog, "relu", "relu2", std::vector({"l"}), + SetOp(&prog, "relu", "relu3", std::vector({"l"}), std::vector({"m"}), mkldnn_enabled_op.compare("relu") == 0); SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector({"m"}), @@ -112,7 +109,7 @@ class MKLDNNInplacePassTest { mkldnn_enabled_op.compare("leaky_relu") == 0); SetOp(&prog, "gelu", "gelu1", std::vector({"n"}), std::vector({"m"}), - mkldnn_enabled_op.compare("relu") == 0); + mkldnn_enabled_op.compare("gelu") == 0); if (branched == true) { SetOp(&prog, "softmax", "softmax2", std::vector({"g"}), std::vector({"z"}), -- GitLab