Unverified · Commit a346c4dc authored by jakpiase, committed by GitHub

Added softplus + activation oneDNN fuse pass (#36657)

* added softplus + activation fuse pass

* minor change

* implemented reviewer suggestion

* minor fix

* minor fix

* added scale_out parameter

* minor fix

* fix for iScan CI

* conditionally disabled logs

* refactored pass builder
Parent 6af531b7
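The pass rewrites a softplus op followed by one of the supported activations into a single softplus op that carries the activation as a oneDNN eltwise post-op. As a rough illustration only (not code from this commit), here is a scalar sketch of what the fused kernel computes for a ReLU post-op, with beta being the softplus attribute; the real kernel further down uses a dnnl::binary primitive with an eltwise post-op.

#include <algorithm>
#include <cmath>

// Illustrative scalar model of the fused computation: softplus with parameter
// beta, followed by a ReLU applied as a post-op inside the same kernel call.
float fused_softplus_relu(float x, float beta) {
  float softplus = std::log1p(std::exp(beta * x)) / beta;  // softplus(x)
  return std::max(softplus, 0.0f);                         // ReLU post-op
}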
@@ -117,6 +117,7 @@ if(WITH_MKLDNN)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
......
@@ -1061,6 +1061,27 @@ PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
return act_out;
}
PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
// Create Operators
auto *softplus_op =
pattern->NewNode(softplus_repr())->assert_is_op("softplus");
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// intermediate variable, will be removed in the IR after fuse.
auto *softplus_out = pattern->NewNode(softplus_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("softplus")
->assert_is_op_input(activation_type);
// output
auto *activation_out = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
softplus_op->LinksTo({softplus_out});
activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
return activation_out;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
......
@@ -577,6 +577,24 @@ struct FCActOneDNN : public PatternBase {
PATTERN_DECL_NODE(act_out);
};
// Fuse softplus with activation
// ops: softplus + activation
// nodes:
// softplus, softplus_out,
// activation, activation_out
struct SoftplusActivation : public PatternBase {
SoftplusActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "softplus_activation") {}
PDNode* operator()(std::string activation_type);
// declare operator node's name
PATTERN_DECL_NODE(softplus);
PATTERN_DECL_NODE(activation);
PATTERN_DECL_NODE(softplus_out);
PATTERN_DECL_NODE(activation_out);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {
"relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
"abs", "clip", "gelu", "relu6", "sigmoid"};
for (const auto &act_type : act_types) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish")
attr_map.emplace("beta", "fuse_activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "fuse_activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "fuse_activation_alpha");
attr_map.emplace("max", "fuse_activation_beta");
} else {
attr_map.emplace("alpha", "fuse_activation_alpha");
attr_map.emplace("beta", "fuse_activation_beta");
}
FuseSoftplusActivation(graph, act_type, attr_map);
}
}
void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph, const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd;
patterns::SoftplusActivation softplus_activation_pattern(
gpd.mutable_pattern(), "softplus_activation");
softplus_activation_pattern(fuse_activation_type);
int found_softplus_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse softplus with activation op.";
GET_IR_NODE_FROM_SUBGRAPH(softplus_out, softplus_out,
softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
softplus_activation_pattern);
auto *softplus_op = softplus->Op();
if (softplus_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE_EQ(
BOOST_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")), true,
platform::errors::PreconditionNotMet("The softplus + activation "
"fusion may happen only when "
"oneDNN library is used."));
}
auto *activation_op = activation->Op();
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
}
}
if (fuse_activation_type == "gelu" &&
activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh"));
else
softplus_op->SetAttr("fuse_activation_type", fuse_activation_type);
softplus_op->SetAttr("use_mkldnn", true);
softplus_op->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(softplus, activation_out);
GraphSafeRemoveNodes(g, {activation, softplus_out});
found_softplus_activation_count++;
};
gpd(graph, handler);
AddStatis(found_softplus_activation_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs"))
PrettyLogDetail("--- fused %d softplus with %s activation",
found_softplus_activation_count, fuse_activation_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(softplus_activation_mkldnn_fuse_pass,
paddle::framework::ir::SoftplusActivationOneDNNPass);
REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("softplus", 1)
.EQ("relu", 0)
.EQ("tanh", 0)
.LE("leaky_relu", 1)
.EQ("swish", 0)
.EQ("hard_swish", 0)
.EQ("sqrt", 0)
.EQ("abs", 0)
.LE("relu6", 1)
.LE("clip", 1)
.EQ("gelu", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* \brief Fuse the softplus and another activation operators into
* softplus with another activation as post-op.
*/
class SoftplusActivationOneDNNPass : public FusePassBase {
public:
virtual ~SoftplusActivationOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
void FuseSoftplusActivation(
ir::Graph *graph, const std::string &fuse_activation_type,
const std::unordered_map<std::string, std::string> &attr_map) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
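For context, a hedged usage sketch (not part of this commit): with the C++ inference API, EnableMKLDNN() activates the MKLDNN CPU pass strategy, which after this change already contains softplus_activation_mkldnn_fuse_pass; the commented AppendPass call below is only an assumption for configurations where the pass is not yet in the strategy.

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// Hypothetical helper: enable the oneDNN pass strategy so the softplus +
// activation fusion can run during inference graph analysis.
void ConfigureForSoftplusFusion(paddle::AnalysisConfig* config) {
  config->EnableMKLDNN();
  // Only needed if the pass is not already part of the strategy:
  // config->pass_builder()->AppendPass("softplus_activation_mkldnn_fuse_pass");
}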
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void MainTest(const std::string& activation_type) {
auto prog =
test::BuildProgramDesc({"softplus_x", "softplus_out", "activation_out"});
test::CreateOp(&prog, "softplus", {{"X", "softplus_x"}},
{{"Out", "softplus_out"}});
test::CreateOp(&prog, activation_type, {{"X", "softplus_out"}},
{{"Out", "activation_out"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
EXPECT_TRUE(test::RunPassAndAssert(
&graph, "softplus_activation_mkldnn_fuse_pass", "softplus_x",
"activation_out", removed_nodes_count));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"softplus", 1}, {activation_type, 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "softplus") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_activation_type"));
auto fused_activation_type =
BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation_type"));
EXPECT_EQ(fused_activation_type.compare(activation_type), 0);
}
}
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithTanh) {
MainTest("tanh");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithRelu) {
MainTest("relu");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithLeakyRelu) {
MainTest("leaky_relu");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSwish) {
MainTest("swish");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithHardswish) {
MainTest("hardswish");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSqrt) {
MainTest("sqrt");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithAbs) { MainTest("abs"); }
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithClip) {
MainTest("clip");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithGelu) {
MainTest("gelu");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithRelu6) {
MainTest("relu6");
}
TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSigmoid) {
MainTest("sigmoid");
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(softplus_activation_mkldnn_fuse_pass);
@@ -257,7 +257,8 @@ void CpuPassStrategy::EnableMKLDNN() {
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass",             //
"softplus_activation_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass",  // This pass should be activated after
......
@@ -447,6 +447,26 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Only used in cudnn kernel, need install cudnn.")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"fuse_activation_type",
"Fused activation type used in softplus OneDNN kernel.")
.SetDefault("")
.AsExtra();
AddAttr<float>(
"fuse_activation_alpha",
"Fused activation alpha parameter used in the softplus OneDNN kernel.")
.SetDefault(0.0f)
.AsExtra();
AddAttr<float>(
"fuse_activation_beta",
"Fused activation beta parameter used in the softplus OneDNN kernel.")
.SetDefault(0.0f)
.AsExtra();
AddAttr<float>(
"fuse_activation_scale",
"Fused activation scale parameter used in the softplus OneDNN kernel.")
.SetDefault(1.0f)
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
:strong:`Softplus Activation Operator` :strong:`Softplus Activation Operator`
......
@@ -23,9 +23,10 @@ template <typename T>
class SoftplusMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
SoftplusMKLDNNHandler(const framework::ExecutionContext& ctx, const Tensor* x,
const float beta, const mkldnn::engine engine)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine,
ctx.GetPlace()) {
auto x_tz = framework::vectorize(x->dims());
auto x_md =
dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
@@ -42,6 +43,8 @@ class SoftplusMKLDNNHandler
1.0f / beta, 0.0f);
}
AppendFusedActivationIfExists(ctx, post_ops);
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);
@@ -53,8 +56,41 @@ class SoftplusMKLDNNHandler
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
}
private:
void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops) {
const auto& fused_activation_type =
algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
if (fused_activation_type != algo_map.end()) {
auto scale_out =
ctx.Attr<float>("fuse_activation_scale"); // for future int8 support
post_ops.append_eltwise(scale_out, fused_activation_type->second,
ctx.Attr<float>("fuse_activation_alpha"),
ctx.Attr<float>("fuse_activation_beta"));
}
}
static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
};
template <typename T>
const std::unordered_map<std::string, dnnl::algorithm>
SoftplusMKLDNNHandler<T>::algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"tanh", dnnl::algorithm::eltwise_tanh},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"swish", dnnl::algorithm::eltwise_swish},
{"hardswish", dnnl::algorithm::eltwise_hardswish},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic}};
template <typename T>
void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
const auto& dev_ctx =
@@ -68,7 +104,7 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
const float beta = ctx.Attr<float>("beta");
SoftplusMKLDNNHandler<T> handler(ctx, x, beta, mkldnn_engine);
auto src_memory_p = handler.AcquireSrcMemory(x);
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
from paddle.fluid.core import PassVersionChecker
class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
fuse_activation_alpha = None
fuse_activation_beta = None
pass_name = 'softplus_activation_mkldnn_fuse_pass'
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
softplus_out = fluid.layers.softplus(data)
if self.fuse_activation_beta is not None:
activation_out = self.fuse_activation(
softplus_out, self.fuse_activation_alpha,
self.fuse_activation_beta)
elif self.fuse_activation_alpha is not None:
activation_out = self.fuse_activation(
softplus_out, self.fuse_activation_alpha)
else:
activation_out = self.fuse_activation(softplus_out)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32"),
}
self.fetch_list = [activation_out]
self.enable_mkldnn = True
def set_params(self):
self.fuse_activation = fluid.layers.relu
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class SoftplusActivationTanhOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.tanh
class SoftplusActivationLeakyReluOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.leaky_relu
self.fuse_activation_alpha = 0.3
class SoftplusActivationSwishOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.swish
self.fuse_activation_alpha = 3
class SoftplusActivationHardSwishOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.hard_swish
class SoftplusActivationSqrtOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.sqrt
class SoftplusActivationAbsOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.abs
class SoftplusActivationClipOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.clip
self.fuse_activation_alpha = 1.1
self.fuse_activation_beta = 5.2
class SoftplusActivationGeluErfOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.gelu
class SoftplusActivationGeluTanhOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.gelu
self.fuse_activation_alpha = True # simulated "Approximate" attr
class SoftplusActivationRelu6OneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.relu6
class SoftplusActivationSigmoidOneDNNFusePassTest(
SoftplusActivationReluOneDNNFusePassTest):
def set_params(self):
self.fuse_activation = fluid.layers.sigmoid
if __name__ == "__main__":
paddle.enable_static()
unittest.main()