Unverified commit 1d18bc2c, authored by Sławomir Siwek, committed by GitHub

Mish FP32/BF16 kernel, conv and fc fuse passes (#38623)

* Mish

* Change exp() library

* mish fuse pass

* mish attrs

* fixes

* mishop maker

* remove attrs

* mish kernel for bf16

* fc+mish fuse

* fix code format error

* Resolve merge conflicts

* Update mish operator version

* update mish variable to new naming convention
Parent: 4c46eed0
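For reference, Mish is the activation mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)); the unit tests added in this diff check the new kernels against exactly this expression. A minimal NumPy sketch of that reference computation (illustrative only, not part of the change):

import numpy as np

def mish_reference(x):
    # same formula as used in the new MKLDNN tests: x * tanh(log(1 + exp(x)))
    return x * np.tanh(np.log(1 + np.exp(x)))

print(mish_reference(np.linspace(-3.0, 3.0, 7, dtype=np.float32)))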
@@ -230,7 +230,15 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
.IsType<float>()
.End();
}
Conv2DMishFusePass::Conv2DMishFusePass() {
AddOpCompat(OpCompat("mish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
AddOpCompat(OpCompat("hard_sigmoid"))
.AddInput("X")
@@ -311,6 +319,14 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
.LE("conv2d", 1)
.EQ("hard_swish", 0));
REGISTER_PASS(conv_mish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DMishFusePass);
REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("mish", 1));
REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DHardSigmoidFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
...
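With the conv_mish_mkldnn_fuse_pass registered above, one quick sanity check is the PassVersionChecker helper already used by the Python tests later in this diff; a minimal sketch (import path assumed from those tests, not part of the change):

from paddle.fluid.core import PassVersionChecker

# expected to return True when conv2d <= 1 and mish == 1, matching the declared capability
print(PassVersionChecker.IsCompatible('conv_mish_mkldnn_fuse_pass'))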
@@ -72,6 +72,14 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; }
};
/*
* Fuse Conv and Mish class
*/
class Conv2DMishFusePass : public ConvActivationFusePass {
public:
Conv2DMishFusePass();
std::string activation_type() const { return "mish"; }
};
/*
* Fuse Conv and HardSigmoid class
*/
...
@@ -148,6 +148,7 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
MainTest("hard_swish");
}
TEST(ConvActivationFusePass, conv_mish_fuse_pass) { MainTest("mish"); }
TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
MainTest("hard_sigmoid");
}
...
@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid", "mish",
"hard_swish"};
for (std::string act_type : act_types) FuseFCAct(graph, act_type);
@@ -99,5 +99,6 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
.LE("fc", 0)
.LE("gelu", 0)
.LE("sigmoid", 0)
.LE("mish", 1)
.LE("hard_swish", 0) .LE("hard_swish", 0)
.LE("tanh", 0)); .LE("tanh", 0));
...@@ -27,8 +27,8 @@ namespace ir { ...@@ -27,8 +27,8 @@ namespace ir {
* \brief Fuse the FC and activation operators into single OneDNN's * \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op. * FC with post-op.
* *
* \note Currently only GeLU, hardswish, sigmoid and tanh are supported as an * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
* activation function. * as an activation function.
*/ */
class FuseFCActOneDNNPass : public FusePassBase { class FuseFCActOneDNNPass : public FusePassBase {
public: public:
......
...@@ -201,6 +201,36 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { ...@@ -201,6 +201,36 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
} }
} }
TEST(FuseFCActOneDNNPass, FuseWithMish) {
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"mish", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_EQ(act_type.compare("mish"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
...
@@ -252,6 +252,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", //
"conv_mish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", // "conv_hard_sigmoid_mkldnn_fuse_pass", //
// TODO(baoachun) fix int8 accuracy // TODO(baoachun) fix int8 accuracy
"conv_gelu_mkldnn_fuse_pass", "conv_gelu_mkldnn_fuse_pass",
......
@@ -237,6 +237,10 @@ template <typename T>
using HardSwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
@@ -274,6 +278,10 @@ template <typename T>
using HardSwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
@@ -341,6 +349,8 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
MishMKLDNNGradFunctor);
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
...
@@ -516,6 +516,10 @@ class ConvMKLDNNHandlerT
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_hardswish,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "mish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_mish, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
...
@@ -496,6 +496,11 @@ class FCPrimitiveFactory {
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_logistic,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "mish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_mish,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
...
@@ -102,6 +102,15 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
        self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
    def set_params(self):
        self.conv_num_filters = 5
        self.conv_filter_size = 5
        self.conv_bias_attr = True
        self.act = "mish"
        self.pass_name = 'conv_mish_mkldnn_fuse_pass'
class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
    def set_params(self):
        self.conv_num_filters = 5
...
@@ -134,5 +134,27 @@ class FCHardSwishOneDnnFusePassTest(InferencePassTest):
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCMishOneDnnFusePassTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 128, 768], dtype="float32")
            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
            mish_out = fluid.layers.mish(fc_out)

        self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
        self.fetch_list = [mish_out]
        self.enable_mkldnn = True

    def set_params(self):
        self.pass_name = "fc_act_mkldnn_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -148,5 +148,19 @@ class TestMKLDNNReluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
        return dout
class TestMKLDNNMishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
    def config(self):
        self.op_type = "mish"

    def op_forward(self, x):
        return x * np.tanh(np.log(1 + np.exp(x)))

    def op_grad(self, dout, x):
        omega = np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6) + 4 * (x + 1)
        delta = np.exp(2 * x) + 2 * np.exp(x) + 2
        return dout * ((np.exp(x) * omega) / delta**2)
if __name__ == '__main__':
    unittest.main()
@@ -315,6 +315,19 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
        self.dtype = np.float32
class TestMKLDNNMish(TestActivation):
    def setUp(self):
        self.op_type = "mish"
        self.dtype = np.float32

        x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype(self.dtype)
        out = x * np.tanh(np.log(1 + np.exp(x)))

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}
        self.attrs = {"use_mkldnn": True}
class TestMKLDNNSigmoidDim4(TestSigmoid):
    def setUp(self):
        super(TestMKLDNNSigmoidDim4, self).setUp()
...