Unverified commit abee05a8, authored by Sylwester Fraczek and committed by GitHub

added mkldnn swish activation (#23041)

Parent 31fc3ab7
@@ -60,6 +60,10 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
    if (activation_type() == "relu6") {
      desc->SetAttr("fuse_alpha",
                    boost::get<float>(activation->Op()->GetAttr("threshold")));
    } else if (activation_type() == "swish") {
      // paddle uses beta but mkldnn uses alpha for swish
      desc->SetAttr("fuse_alpha",
                    activation->Op()->GetAttrIfExists<float>("beta"));
    } else {
      desc->SetAttr("fuse_alpha",
                    activation->Op()->GetAttrIfExists<float>("alpha"));
@@ -95,3 +99,6 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DReLU6FusePass);
REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DSwishFusePass);
@@ -50,6 +50,13 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
 public:
  std::string activation_type() const { return "relu6"; }
};
/*
 * Fuse Conv and Swish class
 */
class Conv2DSwishFusePass : public ConvActivationFusePass {
 public:
  std::string activation_type() const { return "swish"; }
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -40,6 +40,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
      op->SetAttr("alpha", 0.02f);
    } else if (type == "relu6") {
      op->SetAttr("threshold", 6.0f);
    } else if (type == "swish") {
      op->SetAttr("beta", 1.0f);
    }
  }
  op->SetOutput("Out", outputs);
@@ -133,6 +135,7 @@ TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
  MainTest("leaky_relu");
}
TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
}  // namespace ir
}  // namespace framework
......
@@ -196,6 +196,7 @@ void CpuPassStrategy::EnableMKLDNN() {
             "conv_relu_mkldnn_fuse_pass",        //
             "conv_leaky_relu_mkldnn_fuse_pass",  //
             "conv_relu6_mkldnn_fuse_pass",       //
             "conv_swish_mkldnn_fuse_pass",       //
             // Disabled due to topology-dependent speed-up
             // "fc_mkldnn_pass"
         })) {
......
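For orientation, a rough sketch of how this pass list is exercised from Python inference code. It assumes the `AnalysisConfig` bindings of this Paddle generation (`paddle.fluid.core.AnalysisConfig`, `create_paddle_predictor`) and a placeholder model path; none of this comes from the commit itself:

```python
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

# "./model" is a hypothetical saved inference model directory.
config = AnalysisConfig("./model")
config.disable_gpu()
config.enable_mkldnn()        # selects the CpuPassStrategy MKL-DNN pass list above
config.switch_ir_optim(True)  # run the IR fuse passes, incl. conv_swish_mkldnn_fuse_pass
predictor = create_paddle_predictor(config)
```

With MKL-DNN enabled, a conv2d followed by swish in the model graph should be folded into a single fused conv op by the new pass.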
@@ -589,6 +589,13 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
Swish Activation Operator.
......
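For completeness, a minimal sketch of how the op is reached from user code; it assumes the fluid 1.x Python API (`fluid.data`, `fluid.layers.swish(x, beta=...)`). Note that the `use_mkldnn`/`is_test` attributes added above are normally filled in by the framework or by IR passes, not by user code:

```python
import numpy as np
import paddle.fluid as fluid

x = fluid.data(name="x", shape=[-1, 32], dtype="float32")
y = fluid.layers.swish(x, beta=2.0)  # y = x * sigmoid(2.0 * x)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(feed={"x": np.random.rand(4, 32).astype("float32")},
               fetch_list=[y])
```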
@@ -73,8 +73,13 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  }

  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
@@ -112,8 +117,13 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  }

  auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
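To make the swap above concrete: the one scalar the MKL-DNN eltwise primitive reads as `alpha` is Paddle's `beta` attribute. A small NumPy reference of the forward and backward computation (my own sketch, not library code), with a finite-difference check of the gradient:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def swish_forward(x, beta):
    # Paddle: out = x * sigmoid(beta * x); MKL-DNN receives beta as its "alpha".
    return x * sigmoid(beta * x)

def swish_grad(x, dout, beta):
    # d/dx [x * sigmoid(beta * x)] = s + beta * x * s * (1 - s), with s = sigmoid(beta * x)
    s = sigmoid(beta * x)
    return dout * (s + beta * x * s * (1.0 - s))

x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
beta = 2.3
eps = 1e-3
numeric = (swish_forward(x + eps, beta) - swish_forward(x - eps, beta)) / (2 * eps)
assert np.allclose(swish_grad(x, np.ones_like(x), beta), numeric, atol=1e-3)
```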
@@ -162,6 +172,10 @@ template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
@@ -178,6 +192,10 @@ template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

template <typename T>
using TanhMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
@@ -204,6 +222,7 @@ namespace ops = paddle::operators;
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                  \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);    \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
......
@@ -978,13 +978,15 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(scale,
                                     mkldnn::algorithm::eltwise_bounded_relu,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "swish") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
                                     fuse_alpha, fuse_beta);
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
......
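Put differently (my paraphrase, not code from the diff): with this post-op appended, the fused convolution primitive applies swish elementwise to the convolution result, using `fuse_alpha` as the scaling factor, with `scale = 1.0f` meaning the activation output is not rescaled:

$$z = \mathrm{conv}(x, W) + b, \qquad \text{out} = z\,\sigma(\texttt{fuse\_alpha}\cdot z).$$

As far as I understand MKL-DNN's eltwise definitions, `fuse_beta` is not consumed by the swish algorithm and is passed along only for interface uniformity.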
@@ -16,9 +16,10 @@ from __future__ import print_function
import unittest
import numpy as np
from scipy.special import expit
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
@@ -111,6 +112,29 @@ class TestMKLDNNAbsDim2(TestAbs):
            ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)

class TestMKLDNNSwishDim2(TestSwish):
    def setUp(self):
        super(TestMKLDNNSwishDim2, self).setUp()

        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        beta = 2.3
        out = x * expit(beta * x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}
        self.attrs = {"use_mkldnn": True, "beta": beta}

    def test_check_output(self):
        # TODO(wangzhongpu): support mkldnn op in dygraph mode
        self.check_output()

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        # TODO(wangzhongpu): support mkldnn op in dygraph mode
        self.check_grad(['X'], 'Out')

class TestMKLDNNReluDim4(TestRelu):
    def setUp(self):
        super(TestMKLDNNReluDim4, self).setUp()
@@ -228,6 +252,29 @@ class TestMKLDNNAbsDim4(TestAbs):
            ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)

class TestMKLDNNSwishDim4(TestSwish):
    def setUp(self):
        super(TestMKLDNNSwishDim4, self).setUp()

        x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
        beta = 2.3
        out = x * expit(beta * x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}
        self.attrs = {"use_mkldnn": True, "beta": beta}

    def test_check_output(self):
        # TODO(wangzhongpu): support mkldnn op in dygraph mode
        self.check_output()

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        # TODO(wangzhongpu): support mkldnn op in dygraph mode
        self.check_grad(['X'], 'Out')

# Check if primitives already exist in backward
class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
    def setUp(self):
......