From ff3ddbb502872bca411a26a6428c1dc3b8e998c7 Mon Sep 17 00:00:00 2001 From: Wilber Date: Wed, 11 Mar 2020 17:17:28 +0800 Subject: [PATCH] add skip_layernorm pass. test=develop (#22895) * add skip_layernorm pass. test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/skip_layernorm_fuse_pass.cc | 182 ++++++++++++++++++ .../framework/ir/skip_layernorm_fuse_pass.h | 42 ++++ .../ir/skip_layernorm_fuse_pass_tester.cc | 61 ++++++ .../ir/test_ir_skip_layernorm_pass.py | 49 +++++ 5 files changed, 336 insertions(+) create mode 100644 paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6d9dda19e0e..8b5f5aa888a 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -75,6 +75,7 @@ pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) +pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) if(WITH_GPU) pass_library(cudnn_placement_pass base DEPS placement_pass_base) @@ -125,6 +126,7 @@ cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.c cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass) cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass) +cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DEPS skip_layernorm_fuse_pass) cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass) cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass) if(WITH_GPU) diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc new file mode 100644 index 00000000000..9dddc9154f8 --- /dev/null +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h"
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct SkipLayerNorm : public PatternBase {
+  SkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "skip_layernorm") {}
+
+  PDNode *operator()(PDNode *x, PDNode *y);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fused_skip_layernorm);
+  PATTERN_DECL_NODE(elementwise);
+  PATTERN_DECL_NODE(layer_norm);
+  // declare variable node's name
+  PATTERN_DECL_NODE(
+      elementwise_out);  // (elementwise_input_x, elementwise_input_y) ->
+                         // elementwise_out
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_out);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+};
+
+PDNode *SkipLayerNorm::operator()(PDNode *x, PDNode *y) {
+  // Create nodes for elementwise_add op.
+  x->assert_is_op_input("elementwise_add", "X");
+  y->assert_is_op_input("elementwise_add", "Y");
+  auto *elementwise =
+      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
+  auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("elementwise_add");
+
+  // Add links for elementwise_add op.
+  elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
+
+  // Create nodes for layer_norm op.
+  elementwise_out_var->AsIntermediate()->assert_is_op_input("layer_norm");
+  auto *layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("layer_norm", "Bias");
+  auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("layer_norm", "Scale");
+
+  auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
+                                 ->AsOutput()
+                                 ->assert_is_op_output("layer_norm", "Y");
+  auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("layer_norm", "Mean");
+  auto *layer_norm_variance_var =
+      pattern->NewNode(layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+
+  // Add links for layer_norm op.
+  layer_norm
+      ->LinksFrom(
+          {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+  return layer_norm_out_var;
+}
+
+}  // namespace patterns
+
+void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init("skip_layernorm_fuse", graph);
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  auto *x = gpd.mutable_pattern()
+                ->NewNode("skip_layernorm_fuse/x")
+                ->AsInput()
+                ->assert_is_op_input("elementwise_add", "X")
+                ->assert_var_not_persistable();
+  auto *y = gpd.mutable_pattern()
+                ->NewNode("skip_layernorm_fuse/y")
+                ->AsInput()
+                ->assert_is_op_input("elementwise_add", "Y")
+                ->assert_var_not_persistable();
+  patterns::SkipLayerNorm fused_pattern(gpd.mutable_pattern(),
+                                        "skip_layernorm_fuse");
+  fused_pattern(x, y);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
+      LOG(WARNING) << "The subgraph is empty.";
+      return;
+    }
+
+    VLOG(4) << "handle SkipLayerNorm fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
+                              fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
+                              fused_pattern);
+
+    std::unordered_set<const Node *> del_node_set;
+
+    // Create a skip_layernorm op node.
+    OpDesc new_desc;
+    new_desc.SetType("skip_layernorm");
+
+    // inputs
+    new_desc.SetInput("X", {subgraph.at(x)->Name()});
+    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
+    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
+    new_desc.SetInput("Bias", {layer_norm_bias->Name()});
+
+    // outputs
+    new_desc.SetOutput("Out", {layer_norm_out->Name()});
+
+    // attrs
+    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    new_desc.SetAttr("begin_norm_axis",
+                     layer_norm->Op()->GetAttr("begin_norm_axis"));
+
+    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
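+    // The matched elementwise_add and layer_norm ops (and their intermediate
+    // outputs) are now subsumed by fused_node; remove them below, then relink
+    // the surviving variables to the fused op.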
+
+    del_node_set.insert(elementwise);
+    del_node_set.insert(layer_norm);
+    del_node_set.insert(elementwise_out);
+    del_node_set.insert(layer_norm_mean);
+    del_node_set.insert(layer_norm_variance);
+    GraphSafeRemoveNodes(graph, del_node_set);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
+    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
+    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
+    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
+    IR_NODE_LINK_TO(fused_node, layer_norm_out);
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(skip_layernorm_fuse_pass,
+              paddle::framework::ir::SkipLayerNormFusePass);
diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h
new file mode 100644
index 00000000000..2de8d376221
--- /dev/null
+++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+//     |           |                      |           |
+// other_op1   other_op2             other_op1   other_op2
+//     |           |        fuse         \         /
+//     |------elementwise_add   ->     skip_layernorm
+//                 |                        |
+//            layer_norm                other_op3
+//                 |                        |
+//            other_op3
+//                 |
+class SkipLayerNormFusePass : public FusePassBase {
+ public:
+  virtual ~SkipLayerNormFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
new file mode 100644
index 00000000000..d2d74698728
--- /dev/null
+++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
@@ -0,0 +1,61 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(SkipLayerNormFusePass, basic) {
+  // inputs                          operator            output
+  // --------------------------------------------------------------------
+  // (x, y)                          elementwise_add ->  elementwise_out
+  // (elementwise_out, scale, bias)  layer_norm      ->  layer_norm_out...
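+  // Build that graph with the Layers helper; the trailing `true` marks
+  // scale and bias as persistable, as the pattern requires.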
+  Layers layers;
+  auto* x = layers.data("x", {128, 768});
+  auto* y = layers.data("y", {128, 768});
+  auto* elementwise_out = layers.elementwise_add(x, y);
+  auto* scale = layers.data("scale", {768}, true);
+  auto* bias = layers.data("bias", {768}, true);
+  layers.layer_norm(elementwise_out, scale, bias);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("skip_layernorm_fuse_pass");
+  int num_nodes_before = graph->Nodes().size();
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fused_nodes_after = GetNumOpNodes(graph, "skip_layernorm");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 4,
+                    platform::errors::PreconditionNotMet(
+                        "The number of nodes before and after the fuse does "
+                        "not meet expectations"));
+  PADDLE_ENFORCE_EQ(
+      num_fused_nodes_after, 1,
+      platform::errors::PreconditionNotMet(
+          "The number of fusion nodes does not meet expectations after fuse"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(skip_layernorm_fuse_pass);
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py
new file mode 100644
index 00000000000..888857e5a72
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+class SkipLayerNormFusePassTest(PassTest):
+    def setUp(self):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            x = fluid.data(
+                name="x", shape=[128, 768], dtype="float32", lod_level=0)
+            y = fluid.data(
+                name="y", shape=[128, 768], dtype="float32", lod_level=0)
+            elementwise_out = fluid.layers.elementwise_add(x=x, y=y)
+            out = fluid.layers.layer_norm(input=elementwise_out)
+
+        self.fetch_list = [out]
+        self.pass_names = "skip_layernorm_fuse_pass"
+        self.fused_op_type = "skip_layernorm"
+        self.num_fused_ops = 1
+
+    def test_check_program(self):
+        use_gpu_set = [False]
+        if core.is_compiled_with_cuda():
+            use_gpu_set.append(True)
+        for use_gpu in use_gpu_set:
+            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+            opt_program = self._apply_ir_passes()
+            self.check_program(opt_program)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
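
Note for reviewers: below is a minimal, illustrative sketch (not part of the patch) of driving the new pass by hand, mirroring what skip_layernorm_fuse_pass_tester.cc does. The RunSkipLayerNormFuse wrapper name and the way `program` is obtained are assumptions; any ProgramDesc with an elementwise_add feeding a layer_norm should trigger the fusion.

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

// Build an IR graph from `program`, apply skip_layernorm_fuse_pass, and
// return the rewritten graph; matched elementwise_add + layer_norm pairs
// come back as single skip_layernorm op nodes.
std::unique_ptr<paddle::framework::ir::Graph> RunSkipLayerNormFuse(
    const paddle::framework::ProgramDesc &program) {
  namespace fw_ir = paddle::framework::ir;
  std::unique_ptr<fw_ir::Graph> graph(new fw_ir::Graph(program));
  auto pass = fw_ir::PassRegistry::Instance().Get("skip_layernorm_fuse_pass");
  graph.reset(pass->Apply(graph.release()));
  return graph;
}

As in the tester, the translation unit also needs USE_PASS(skip_layernorm_fuse_pass); so that the pass registration is linked in.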