diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ee25f16fde5d312de36fde04181daaf8d73ebba1..089737bb7c4ea61b80f872ef594d8473c7bac061 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -92,6 +92,7 @@ pass_library(skip_layernorm_fuse_pass base)
 pass_library(multihead_matmul_fuse_pass inference)
 pass_library(adaptive_pool2d_convert_global_pass inference)
 pass_library(unsqueeze2_eltwise_fuse_pass inference)
+pass_library(layer_norm_fuse_pass inference)
 if(WITH_GPU)
   pass_library(cudnn_placement_pass base DEPS placement_pass_base)
   pass_library(embedding_eltwise_layernorm_fuse_pass inference)
@@ -129,6 +130,7 @@ cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc D
 set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
 
 cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
+cc_library(pass_test_util SRCS pass_test_util.cc DEPS graph pass)
 
 cc_test(node_test SRCS node_test.cc DEPS node)
 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
@@ -150,6 +152,7 @@ cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.c
 cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
 cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
 cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
+cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
 if(WITH_GPU)
   cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
   cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
@@ -158,7 +161,6 @@ if(NOT WIN32)
   cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
 endif()
 if (WITH_MKLDNN)
-  cc_library(pass_test_util SRCS mkldnn/pass_test_util.cc DEPS graph pass)
   cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
   cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
   cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 173734cb0da3bf6fb681cd3a2db90071aaed2f0f..43ee501aeee62fb398543717c1cc1f99ed061dbe 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2796,6 +2796,122 @@ PDNode *patterns::MultiGru::operator()() {
   return h;
 }
 
+PDNode *patterns::LayerNorm::operator()() {
+  auto *x = pattern->NewNode(x_repr())->AsInput()->assert_is_ops_input(
+      {"reduce_mean", "elementwise_sub"});
+  auto *x_mean = pattern->NewNode(x_mean_repr())->assert_is_op("reduce_mean");
+  auto *x_mean_out = pattern->NewNode(x_mean_out_repr())
+                         ->assert_is_op_output("reduce_mean", "Out")
+                         ->assert_is_op_input("elementwise_sub", "Y")
+                         ->AsIntermediate();
+  auto *x_sub_mean =
+      pattern->NewNode(x_sub_mean_repr())->assert_is_op("elementwise_sub");
+  auto *x_sub_mean_out =
+      pattern->NewNode(x_sub_mean_out_repr())
+          ->assert_is_op_output("elementwise_sub")
+          ->assert_is_ops_input({"elementwise_pow", "elementwise_div"}, "X")
+          ->AsIntermediate();
+  auto *sqr_pow = pattern->NewNode(sqr_pow_repr())
+                      ->assert_is_op_input("elementwise_pow", "Y")
+                      ->assert_is_persistable_var()
+                      ->AsInput();
+  auto *x_sub_mean_sqr =
+      pattern->NewNode(x_sub_mean_sqr_repr())->assert_is_op("elementwise_pow");
+  auto *x_sub_mean_sqr_out = pattern->NewNode(x_sub_mean_sqr_out_repr())
+                                 ->assert_is_op_output("elementwise_pow")
+                                 ->assert_is_op_input("reduce_mean")
+                                 ->AsIntermediate();
+  auto *std_dev = pattern->NewNode(std_dev_repr())->assert_is_op("reduce_mean");
+  auto *std_dev_out = pattern->NewNode(std_dev_out_repr())
+                          ->assert_is_op_output("reduce_mean")
+                          ->assert_is_op_input("elementwise_add")
+                          ->AsIntermediate();
+  auto *eps = pattern->NewNode(eps_repr())
+                  ->assert_is_op_input("elementwise_add", "Y")
+                  ->assert_is_persistable_var()
+                  ->AsInput();
+  auto *std_dev_eps =
+      pattern->NewNode(std_dev_eps_repr())->assert_is_op("elementwise_add");
+  auto *std_dev_eps_out = pattern->NewNode(std_dev_eps_out_repr())
+                              ->assert_is_op_output("elementwise_add")
+                              ->assert_is_op_input("sqrt")
+                              ->AsIntermediate();
+  auto *std_dev_eps_sqrt =
+      pattern->NewNode(std_dev_eps_sqrt_repr())->assert_is_op("sqrt");
+  auto *std_dev_eps_sqrt_out = pattern->NewNode(std_dev_eps_sqrt_out_repr())
+                                   ->assert_is_op_output("sqrt")
+                                   ->assert_is_op_input("elementwise_div", "Y")
+                                   ->AsIntermediate();
+  auto *division =
+      pattern->NewNode(division_repr())->assert_is_op("elementwise_div");
+  auto *division_out = pattern->NewNode(division_out_repr())
+                           ->assert_is_op_output("elementwise_div")
+                           ->assert_is_op_input("elementwise_mul")
+                           ->AsIntermediate();
+  auto *gamma = pattern->NewNode(gamma_repr())
+                    ->assert_is_op_input("elementwise_mul", "Y")
+                    ->assert_is_persistable_var()
+                    ->AsInput();
+  auto *scale = pattern->NewNode(scale_repr())->assert_is_op("elementwise_mul");
+  auto *scale_out = pattern->NewNode(scale_out_repr())
+                        ->assert_is_op_output("elementwise_mul")
+                        ->assert_is_op_input("elementwise_add")
+                        ->AsIntermediate();
+  auto *beta = pattern->NewNode(beta_repr())
+                   ->assert_is_op_input("elementwise_add", "Y")
+                   ->assert_is_persistable_var()
+                   ->AsInput();
+  auto *shift = pattern->NewNode(shift_repr())->assert_is_op("elementwise_add");
+  auto *shift_out = pattern->NewNode(shift_out_repr())
+                        ->assert_is_op_output("elementwise_add")
+                        ->AsOutput();
+
+  /*
+   *            X
+   *           / \
+   *          /   reduce_mean "u(x)"
+   *          \   /
+   *      elementwise_sub     "x - u(x)"
+   *           /  \    2
+   *          |    \  /
+   *          |  elementwise_pow  "(x - u(x))^2"
+   *          |         |
+   *          |    reduce_mean    "sigma^2 = 1/C*Sum{(x - u(x))^2}"
+   *          |         |    eps
+   *          |         |    /
+   *          |  elementwise_add  "sigma^2 + epsilon"
+   *           \        |
+   *            \      sqrt       "sqrt(sigma^2 + epsilon)"
+   *             \      /
+   *              \    /
+   *          elementwise_div     "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
+   *                  |
+   *         gamma    |
+   *            \     |
+   *          elementwise_mul     "scale: gamma(C) * lnorm"
+   *                  |
+   *          beta    |
+   *            \     |
+   *          elementwise_add     "shift: gamma(C) * lnorm + beta(C)"
+   */
+
+  x_mean->LinksFrom({x}).LinksTo({x_mean_out});
+  x_sub_mean->LinksFrom({x, x_mean_out}).LinksTo({x_sub_mean_out});
+  x_sub_mean_sqr->LinksFrom({x_sub_mean_out, sqr_pow})
+      .LinksTo({x_sub_mean_sqr_out});
+  std_dev->LinksFrom({x_sub_mean_sqr_out}).LinksTo({std_dev_out});
+  std_dev_eps->LinksFrom({std_dev_out, eps}).LinksTo({std_dev_eps_out});
+
+  std_dev_eps_sqrt->LinksFrom({std_dev_eps_out})
+      .LinksTo({std_dev_eps_sqrt_out});
+  division->LinksFrom({x_sub_mean_out, std_dev_eps_sqrt_out})
+      .LinksTo({division_out});
+  scale->LinksFrom({division_out, gamma}).LinksTo({scale_out});
+  shift->LinksFrom({scale_out, beta}).LinksTo({shift_out});
+
+  return shift_out;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 79b69a8c180e31d5371d7c6c5a2cc30562f803d2..f9b6e0ef9c9eae34dce631c173a7ec26d7a4a4a8 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1598,6 +1598,41 @@ struct MultiGru : public PatternBase {
   PATTERN_DECL_NODE(h);
 };
 
+//
+// \brief Pattern looking for a subgraph representing the layer normalization
+//        operation.
+//
+struct LayerNorm : public PatternBase {
+  LayerNorm(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "layer_norm") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(x);
+  PATTERN_DECL_NODE(x_mean);
+  PATTERN_DECL_NODE(x_mean_out);
+  PATTERN_DECL_NODE(x_sub_mean);
+  PATTERN_DECL_NODE(x_sub_mean_out);
+  PATTERN_DECL_NODE(sqr_pow);
+  PATTERN_DECL_NODE(x_sub_mean_sqr);
+  PATTERN_DECL_NODE(x_sub_mean_sqr_out);
+  PATTERN_DECL_NODE(std_dev);
+  PATTERN_DECL_NODE(std_dev_out);
+  PATTERN_DECL_NODE(eps);
+  PATTERN_DECL_NODE(std_dev_eps);
+  PATTERN_DECL_NODE(std_dev_eps_out);
+  PATTERN_DECL_NODE(std_dev_eps_sqrt);
+  PATTERN_DECL_NODE(std_dev_eps_sqrt_out);
+  PATTERN_DECL_NODE(division);
+  PATTERN_DECL_NODE(division_out);
+  PATTERN_DECL_NODE(gamma);
+  PATTERN_DECL_NODE(scale);
+  PATTERN_DECL_NODE(scale_out);
+  PATTERN_DECL_NODE(beta);
+  PATTERN_DECL_NODE(shift);
+  PATTERN_DECL_NODE(shift_out);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6734c74222ff82b2168537c57ad73cbc3a0075f0
--- /dev/null
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
@@ -0,0 +1,231 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// cpplint complains (incorrectly) about a missing include for the line below.
+using string::PrettyLogDetail;  // NOLINT
+
+namespace {
+
+void validateReduceOpAttrs(const Node* node, const std::string& name) {
+  const auto* op = node->Op();
+  if (op->HasAttr("dim")) {
+    auto dims = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
+    PADDLE_ENFORCE_EQ(dims.size(), 1,
+                      platform::errors::PreconditionNotMet(
+                          "The LayerNorm fusion %s reduction must happen "
+                          "only over a single dimension.",
+                          name));
+    PADDLE_ENFORCE_EQ(dims.front(), -1,
+                      platform::errors::PreconditionNotMet(
+                          "The LayerNorm fusion %s reduction must happen "
+                          "over the last dimension.",
+                          name));
+  }
+  if (op->HasAttr("reduce_all")) {
+    PADDLE_ENFORCE(!BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
+                   platform::errors::PreconditionNotMet(
+                       "The LayerNorm fusion %s reduction must have the "
+                       "'reduce_all' attribute set to false.",
+                       name));
+  }
+  if (op->HasAttr("keep_dim")) {
+    PADDLE_ENFORCE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
+                   platform::errors::PreconditionNotMet(
+                       "The LayerNorm fusion %s reduction must have the "
+                       "'keep_dim' attribute set to true.",
+                       name));
+  }
+}
+
+void setIntermediateOut(OpDesc* desc, const std::string& out_name,
+                        const std::string& scope_name) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  desc->SetOutput(out_name, {new_name});
+}
+
+void addIntermediateOut(Node* op_node, const std::string& out_name,
+                        const std::string& scope_name, Graph* graph) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  VarDesc out_var(new_name);
+  out_var.SetPersistable(false);
+  auto* node_var = graph->CreateVarNode(&out_var);
+  IR_NODE_LINK_TO(op_node, node_var);
+}
+
+}  // namespace
+
+void LayerNormFusePass::ApplyImpl(Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "The input graph of "
+                              "LayerNormFusePass should not be nullptr."));
+  FusePassBase::Init(scope_name_, graph);
+
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+
+  GraphPatternDetector gpd;
+  patterns::LayerNorm layer_norm_pattern(gpd.mutable_pattern(), scope_name_);
+  layer_norm_pattern();
+
+  int found_layer_norm_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "Fuse LayerNorm from subgraph.";
+    GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_mean_out, x_mean_out, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean, x_sub_mean, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_out, x_sub_mean_out,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(sqr_pow, sqr_pow, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr, x_sub_mean_sqr,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr_out, x_sub_mean_sqr_out,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev, std_dev, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev_out, std_dev_out, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eps, eps, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps, std_dev_eps, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_out, std_dev_eps_out,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt, std_dev_eps_sqrt,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt_out, std_dev_eps_sqrt_out,
+                              layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(division, division, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(division_out, division_out, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gamma, gamma, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(beta, beta, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(shift, shift, layer_norm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(shift_out, shift_out, layer_norm_pattern);
+
+    auto* eps_tensor = scope->FindVar(eps->Name())->GetMutable<LoDTensor>();
+
+    // ------------------ subgraph nodes' validation --------------------------
+    PADDLE_ENFORCE_EQ(
+        eps_tensor->numel(), 1,
+        platform::errors::InvalidArgument(
+            "The LayerNorm divisor epsilon value must be a one-element "
+            "tensor, but has %s elements.",
+            eps_tensor->numel()));
+    PADDLE_ENFORCE_EQ(eps_tensor->type(), proto::VarType::FP32,
+                      platform::errors::InvalidArgument(
+                          "The LayerNorm divisor epsilon value must be of "
+                          "FP32 data type, but is %s.",
+                          eps_tensor->type()));
+
+    const auto& gamma_shape = gamma->Var()->GetShape();
+    const auto& beta_shape = beta->Var()->GetShape();
+    const auto& x_shape = x->Var()->GetShape();
+    int64_t x_last_dim = x_shape.back();
+
+    PADDLE_ENFORCE_EQ(gamma_shape.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "The LayerNorm gamma (scale) tensor shape must be "
+                          "one-dimensional, but is %s.",
+                          gamma_shape.size()));
+    PADDLE_ENFORCE_EQ(beta_shape.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "The LayerNorm beta (shift) tensor shape must be "
+                          "one-dimensional, but is %s.",
+                          beta_shape.size()));
+    PADDLE_ENFORCE_EQ(beta_shape, gamma_shape,
+                      platform::errors::InvalidArgument(
+                          "The LayerNorm beta and gamma tensors' shapes must "
+                          "be equal."));
+    PADDLE_ENFORCE_EQ(gamma_shape.front(), x_last_dim,
+                      platform::errors::InvalidArgument(
+                          "The LayerNorm beta and gamma tensors' lengths must "
+                          "be equal to the size of the input's last "
+                          "dimension."));
+
+    validateReduceOpAttrs(x_mean, "input mean");
+    validateReduceOpAttrs(std_dev, "std_dev mean");
+
+    // ------------------ op creation and placement ---------------------------
+
+    OpDesc ln_op_desc;
+    ln_op_desc.SetType("layer_norm");
+    ln_op_desc.SetInput("X", {x->Name()});
+    ln_op_desc.SetInput("Scale", {gamma->Name()});
+    ln_op_desc.SetInput("Bias", {beta->Name()});
+    ln_op_desc.SetOutput("Y", {shift_out->Name()});
+    setIntermediateOut(&ln_op_desc, "Mean", scope_name_);
+    setIntermediateOut(&ln_op_desc, "Variance", scope_name_);
+    ln_op_desc.SetAttr("begin_norm_axis",
+                       static_cast<int>(x_shape.size() - 1));
+    ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
+    ln_op_desc.SetAttr("is_test", true);
+    Node* ln_op = g->CreateOpNode(&ln_op_desc);
+
+    addIntermediateOut(ln_op, "Mean", scope_name_, g);
+    addIntermediateOut(ln_op, "Variance", scope_name_, g);
+
+    IR_NODE_LINK_TO(x, ln_op);
+    IR_NODE_LINK_TO(gamma, ln_op);
+    IR_NODE_LINK_TO(beta, ln_op);
+    IR_OP_VAR_LINK(ln_op, shift_out);
+    GraphSafeRemoveNodes(
+        g,
+        {x_mean, x_mean_out, x_sub_mean, x_sub_mean_out, sqr_pow,
+         x_sub_mean_sqr, x_sub_mean_sqr_out, std_dev, std_dev_out, eps,
+         std_dev_eps, std_dev_eps_out, std_dev_eps_sqrt, std_dev_eps_sqrt_out,
+         division, division_out, scale, scale_out, shift});
+    found_layer_norm_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_layer_norm_count);
+  PrettyLogDetail("---    Fused %d subgraphs into layer_norm op.",
+                  found_layer_norm_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(layer_norm_fuse_pass,
+              paddle::framework::ir::LayerNormFusePass);
+REGISTER_PASS_CAPABILITY(layer_norm_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("elementwise_add", 0)
+            .LE("elementwise_add", 1)
+            .GE("elementwise_div", 0)
+            .LE("elementwise_div", 1)
+            .GE("elementwise_mul", 0)
+            .LE("elementwise_mul", 1)
+            .GE("elementwise_pow", 0)
+            .LE("elementwise_pow", 1)
+            .GE("elementwise_sub", 0)
+            .LE("elementwise_sub", 1)
+            .EQ("reduce_mean", 0)
+            .EQ("sqrt", 0));
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.h b/paddle/fluid/framework/ir/layer_norm_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..29a6f127065f6c2bfa3f885e44baa0f8df616a69
--- /dev/null
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+ * \brief Fuse the subgraph representing layer normalization into
+ *        a layer_norm op.
+ *
+ * \note  The following graph represents this equation:
+ *
+ *                   x - u(x)
+ *        y(c) * ------------------- + b(c)
+ *               sqrt(sigma^2 + eps)
+ *
+ *        x       - input data
+ *        u(x)    - mean
+ *        sigma^2 - variance
+ *        eps     - epsilon
+ *        y(c)    - gamma (scale), channelwise
+ *        b(c)    - beta (shift), channelwise
+ *
+ *
+ *            X
+ *           / \
+ *          /   reduce_mean "u(x)"
+ *          \   /
+ *      elementwise_sub     "x - u(x)"
+ *           /  \    2
+ *          |    \  /
+ *          |  elementwise_pow  "(x - u(x))^2"
+ *          |         |
+ *          |    reduce_mean    "sigma^2 = 1/C*Sum{(x - u(x))^2}"
+ *          |         |    eps
+ *          |         |    /
+ *          |  elementwise_add  "sigma^2 + epsilon"
+ *           \        |
+ *            \      sqrt       "sqrt(sigma^2 + epsilon)"
+ *             \      /
+ *              \    /
+ *          elementwise_div     "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
+ *                  |
+ *         gamma    |
+ *            \     |
+ *          elementwise_mul     "scale: gamma(C) * lnorm"
+ *                  |
+ *          beta    |
+ *            \     |
+ *          elementwise_add     "shift: gamma(C) * lnorm + beta(C)"
+ */
+class LayerNormFusePass : public FusePassBase {
+ public:
+  virtual ~LayerNormFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override;
+
+ private:
+  const std::string scope_name_{"layer_norm_fuse"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c79c9dda8e54f66f3840f0f1f715d04690cd3f5d
--- /dev/null
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
@@ -0,0 +1,199 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
+#include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+namespace {
+
+ProgramDesc BuildGraphProgram() {
+  auto prog = test::BuildProgramDesc(
+      {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out",
+       "std_dev_out", "std_dev_eps_out", "std_dev_eps_sqrt_out",
+       "division_out", "scale_out", "shift_out"},
+      {"sqr_pow", "eps", "gamma", "beta"});
+
+  const auto& block_desc = prog.Block(0);
+  auto* x_var_desc = block_desc.FindVar("x");
+  x_var_desc->SetDataType(proto::VarType::FP32);
+  x_var_desc->SetShape({3, 32, 48});
+
+  auto* eps_var_desc = block_desc.FindVar("eps");
+  eps_var_desc->SetDataType(proto::VarType::FP32);
+  eps_var_desc->SetShape({1});
+
+  auto* gamma_var_desc = block_desc.FindVar("gamma");
+  gamma_var_desc->SetDataType(proto::VarType::FP32);
+  gamma_var_desc->SetShape({48});
+
+  auto* beta_var_desc = block_desc.FindVar("beta");
+  beta_var_desc->SetDataType(proto::VarType::FP32);
+  beta_var_desc->SetShape({48});
+
+  auto* x_mean = test::CreateOp(&prog, "reduce_mean", {{"X", "x"}},
+                                {{"Out", "x_mean_out"}}, false);
+  x_mean->SetAttr("dim", std::vector<int>{-1});
+  x_mean->SetAttr("keep_dim", true);
+  x_mean->SetAttr("reduce_all", false);
+
+  test::CreateOp(&prog, "elementwise_sub", {{"X", "x"}, {"Y", "x_mean_out"}},
+                 {{"Out", "x_sub_mean_out"}}, false);
+  test::CreateOp(&prog, "elementwise_pow",
+                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
+                 {{"Out", "x_sub_mean_sqr_out"}}, false);
+  auto* std_dev =
+      test::CreateOp(&prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
+                     {{"Out", "std_dev_out"}}, false);
+  std_dev->SetAttr("dim", std::vector<int>{-1});
+  std_dev->SetAttr("keep_dim", true);
+  std_dev->SetAttr("reduce_all", false);
+
+  test::CreateOp(&prog, "elementwise_add",
+                 {{"X", "std_dev_out"}, {"Y", "eps"}},
+                 {{"Out", "std_dev_eps_out"}}, false);
+  test::CreateOp(&prog, "sqrt", {{"X", "std_dev_eps_out"}},
+                 {{"Out", "std_dev_eps_sqrt_out"}}, false);
+  test::CreateOp(&prog, "elementwise_div",
+                 {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
+                 {{"Out", "division_out"}}, false);
+  test::CreateOp(&prog, "elementwise_mul",
+                 {{"X", "division_out"}, {"Y", "gamma"}},
+                 {{"Out", "scale_out"}}, false);
+  test::CreateOp(&prog, "elementwise_add",
+                 {{"X", "scale_out"}, {"Y", "beta"}},
+                 {{"Out", "shift_out"}}, false);
+  return prog;
+}
+
+bool CheckFusedSubgraphOpsCount(const Graph& graph) {
+  return test::AssertOpsCount(graph, {{"reduce_mean", 0},
+                                      {"elementwise_sub", 0},
+                                      {"elementwise_pow", 0},
+                                      {"elementwise_add", 0},
+                                      {"sqrt", 0},
+                                      {"elementwise_div", 0},
+                                      {"elementwise_mul", 0},
+                                      {"layer_norm", 1}});
+}
+
+}  // namespace
+
+// ------------------------------ Test cases -----------------------------------
+
+TEST(FuseLayerNormPass, TestFuse) {
+  ProgramDesc prog = BuildGraphProgram();
+
+  Graph graph(prog);
+  constexpr int removed_nodes = 19;
+  // LayerNorm + outputs: {Mean, Variance}
+  constexpr int added_nodes = 3;
+
+  auto place = paddle::platform::CPUPlace();
+  NaiveExecutor exe{place};
+  Scope scope;
+  float eps_value = 1e-5f;
+  // Init the scope, as it is used in the pass.
+  exe.CreateVariables(prog, 0, true, &scope);
+  test::InitLoDTensorHolder(&scope, place, "eps", {1}, &eps_value);
+
+  graph.SetNotOwned(kParamScopeAttr, &scope);
+  EXPECT_TRUE(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
+                                     "shift_out", removed_nodes, added_nodes));
+  EXPECT_TRUE(CheckFusedSubgraphOpsCount(graph));
+
+  for (const auto* node : graph.Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "layer_norm") {
+      const auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("is_test"));
+      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("is_test")));
+      ASSERT_TRUE(op->HasAttr("begin_norm_axis"));
+      ASSERT_TRUE(op->HasAttr("epsilon"));
+    }
+  }
+}
+
+TEST(FuseLayerNormPass, TestInvalidEpsNumel) {
+  ProgramDesc prog = BuildGraphProgram();
+
+  auto* eps_var_desc = prog.Block(0).FindVar("eps");
+  eps_var_desc->SetDataType(proto::VarType::FP32);
+  eps_var_desc->SetShape({2});
+
+  Graph graph(prog);
+  constexpr int removed_nodes = 19;
+  constexpr int added_nodes = 3;
+
+  auto place = paddle::platform::CPUPlace();
+  NaiveExecutor exe{place};
+  Scope scope;
+  auto eps_values = std::vector<float>{1e-5f, 1e-5f};
+  // Init the scope, as it is used in the pass.
+  exe.CreateVariables(prog, 0, true, &scope);
+  test::InitLoDTensorHolder(&scope, place, "eps", {2}, eps_values.data());
+
+  graph.SetNotOwned(kParamScopeAttr, &scope);
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
+                                      "shift_out", removed_nodes, added_nodes),
+               paddle::platform::EnforceNotMet);
+}
+
+TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
+  ProgramDesc prog = BuildGraphProgram();
+
+  auto* eps_var_desc = prog.Block(0).FindVar("eps");
+  eps_var_desc->SetDataType(proto::VarType::FP64);
+  eps_var_desc->SetShape({1});
+
+  Graph graph(prog);
+  constexpr int removed_nodes = 19;
+  constexpr int added_nodes = 3;
+
+  auto place = paddle::platform::CPUPlace();
+  NaiveExecutor exe{place};
+  Scope scope;
+  double eps_value = 1e-5;
+  // Init the scope, as it is used in the pass.
+  exe.CreateVariables(prog, 0, true, &scope);
+  test::InitLoDTensorHolder(&scope, place, "eps", {1}, &eps_value);
+
+  graph.SetNotOwned(kParamScopeAttr, &scope);
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
+                                      "shift_out", removed_nodes, added_nodes),
+               paddle::platform::EnforceNotMet);
+}
+
+TEST(FuseLayerNormPass, pass_op_version_check) {
+  ASSERT_TRUE(
+      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
+          .IsPassCompatible("layer_norm_fuse_pass"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(layer_norm_fuse_pass);
diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
index c8a4d94fe2d5a1ccec2b82eb30f878a3e78b2ef7..38364721f651527da1da8839d574c1bee136fa4f 100644
--- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
@@ -15,7 +15,7 @@
 #include <gtest/gtest.h>
 
 #include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
"paddle/fluid/framework/ir/mkldnn/pass_test_util.h" +#include "paddle/fluid/framework/ir/pass_test_util.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 35b40ec471568ea1a5c1a2425890b3873d4bfe4f..eafc81cc81d440a976e0176a93ff563972a1d5c9 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -15,7 +15,7 @@ #include #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" -#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" +#include "paddle/fluid/framework/ir/pass_test_util.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc index e7d332864c3ead294f97316d0ee0a83c8c7400f5..2cc79856a41a621515ea69c2cd97cd242f1c672f 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc @@ -15,7 +15,7 @@ #include #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h" -#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" +#include "paddle/fluid/framework/ir/pass_test_util.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/framework/ir/mkldnn/pass_test_util.cc b/paddle/fluid/framework/ir/pass_test_util.cc similarity index 67% rename from paddle/fluid/framework/ir/mkldnn/pass_test_util.cc rename to paddle/fluid/framework/ir/pass_test_util.cc index a6c8a6662c92cc81a1acd47c755ad318a79a3f4c..c37331dec05b4e67dd5a0aaea8050fe5b7d11278 100644 --- a/paddle/fluid/framework/ir/mkldnn/pass_test_util.cc +++ b/paddle/fluid/framework/ir/pass_test_util.cc @@ -13,15 +13,19 @@ // limitations under the License. 
 
 #include <algorithm>
+#include <cstring>
 #include <iterator>
 #include <memory>
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
 #include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 
 namespace paddle {
 namespace framework {
@@ -32,7 +36,7 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
                  const std::vector<InOutVarNamePair>& inputs,
                  const std::vector<InOutVarNamePair>& outputs,
                  bool use_mkldnn) {
-  auto op = prog->MutableBlock(0)->AppendOp();
+  auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(op_type_name);
   op->SetAttr("use_mkldnn", use_mkldnn);
 
@@ -43,6 +47,8 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
     op->SetOutput(output.first, {output.second});
   }
 
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
   return op;
 }
 
@@ -168,6 +174,49 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
   return expected_nodes_num == current_nodes_num;
 }
 
+template <typename T>
+void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
+                         const std::string& var_name,
+                         const std::vector<int64_t>& dims, const T* data) {
+  auto var = scope->Var(var_name);
+  auto tensor = var->GetMutable<LoDTensor>();
+  auto* tensor_mem_ptr = tensor->mutable_data<T>(make_ddim(dims), place);
+  if (data != nullptr) {
+    std::memcpy(tensor_mem_ptr, data, tensor->memory_size());
+  } else {
+    std::memset(tensor_mem_ptr, 0, tensor->memory_size());
+  }
+}
+
+// Instantiate for the data types below.
+template void InitLoDTensorHolder<float>(Scope*,
+                                         const paddle::platform::Place&,
+                                         const std::string&,
+                                         const std::vector<int64_t>&,
+                                         const float*);
+template void InitLoDTensorHolder<int>(Scope*, const paddle::platform::Place&,
+                                       const std::string&,
+                                       const std::vector<int64_t>&,
+                                       const int*);
+template void InitLoDTensorHolder<double>(Scope*,
+                                          const paddle::platform::Place&,
+                                          const std::string&,
+                                          const std::vector<int64_t>&,
+                                          const double*);
+
+OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
+              const std::string& output_name,
+              const std::string& output_arg_name) {
+  auto all_ops = prog.Block(0).AllOps();
+  for (auto* op_desc : all_ops) {
+    if (op_desc->Type() == op_type && op_desc->HasOutput(output_name)) {
+      const auto& arg_names = op_desc->Outputs().at(output_name);
+      for (const auto& name : arg_names) {
+        if (name == output_arg_name) return op_desc;
+      }
+    }
+  }
+  return nullptr;
+}
+
 }  // namespace test
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/framework/ir/mkldnn/pass_test_util.h b/paddle/fluid/framework/ir/pass_test_util.h
similarity index 77%
rename from paddle/fluid/framework/ir/mkldnn/pass_test_util.h
rename to paddle/fluid/framework/ir/pass_test_util.h
index 08ee50e0f177994ebf67ae5f5e0a8b3388c4ebe4..519522a932ceb791f80d3e280fc274b469973054 100644
--- a/paddle/fluid/framework/ir/mkldnn/pass_test_util.h
+++ b/paddle/fluid/framework/ir/pass_test_util.h
@@ -18,9 +18,13 @@
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace framework {
@@ -113,6 +117,37 @@
 bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
                       const std::string& from, const std::string& to,
                       int removed_nodes_count, int added_nodes_count = 0);
 
+///
+/// @brief      Initializes the tensor memory holder.
+///
+/// @param[in]  scope     The scope that manages the variable.
+/// @param[in]  place     The place where memory will be allocated.
+/// @param[in]  var_name  The variable name.
+/// @param[in]  dims      The dimensions of the allocated tensor.
+/// @param[in]  data      Optional data pointer; the tensor is zero-filled
+///                       when it is null.
+///
+/// @tparam     T         Tensor data type.
+///
+template <typename T>
+void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
+                         const std::string& var_name,
+                         const std::vector<int64_t>& dims,
+                         const T* data = nullptr);
+
+///
+/// @brief      Retrieve an operator descriptor from the program.
+///
+/// @param[in]  prog             The program descriptor containing the op we
+///                              search for.
+/// @param[in]  op_type          The wanted operator type name.
+/// @param[in]  output_name      The wanted operator output name.
+/// @param[in]  output_arg_name  The wanted operator output argument name.
+///
+/// @return     The operator descriptor.
+///
+OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
+              const std::string& output_name,
+              const std::string& output_arg_name);
+
 }  // namespace test
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index bb4a87af74d4a7bd6f3c8a09e19c1cc25c9a009c..7dc73bb609032e33069319b20b33d37d4b8e525a 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -161,7 +161,8 @@ void GpuPassStrategy::EnableMkldnnBfloat16() {
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
-  passes_.assign({"simplify_with_basic_ops_pass",  //
+  passes_.assign({"simplify_with_basic_ops_pass",   //
+                  "layer_norm_fuse_pass",           //
                   "attention_lstm_fuse_pass",       //
                   "seqconv_eltadd_relu_fuse_pass",  //
                   // "seqpool_concat_fuse_pass",    //
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a84848a0ff340020f9fa7c6d08702681b5d8c9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test for fusion of subgraph expressing layer normalization.""" + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from inference_pass_test import InferencePassTest +from paddle import enable_static +from paddle.fluid.core import PassVersionChecker + + +class LayerNormFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[3, 64, 120], dtype="float32") + sqr_pow = fluid.layers.fill_constant( + shape=[1], value=2, dtype="float32") + eps = fluid.layers.fill_constant( + shape=[1], value=1e-5, dtype="float32") + gamma = fluid.layers.create_parameter( + shape=[120], dtype="float32", is_bias=True) + beta = fluid.layers.create_parameter( + shape=[120], dtype="float32", is_bias=True) + + x_mean_out = fluid.layers.reduce_mean(data, dim=-1, keep_dim=True) + x_sub_mean_out = fluid.layers.elementwise_sub(data, x_mean_out) + x_sub_mean_sqr_out = fluid.layers.elementwise_pow(x_sub_mean_out, + sqr_pow) + std_dev_out = fluid.layers.reduce_mean( + x_sub_mean_sqr_out, dim=-1, keep_dim=True) + std_dev_eps_out = fluid.layers.elementwise_add(std_dev_out, eps) + std_dev_eps_sqrt_out = fluid.layers.sqrt(std_dev_eps_out) + division_out = fluid.layers.elementwise_div(x_sub_mean_out, + std_dev_eps_sqrt_out) + scale_out = fluid.layers.elementwise_mul(division_out, gamma) + shift_out = fluid.layers.elementwise_add(scale_out, beta) + + self.feeds = { + "data": np.random.random((3, 64, 120)).astype("float32"), + } + self.fetch_list = [shift_out] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + self.assertTrue(PassVersionChecker.IsCompatible("layer_norm_fuse_pass")) + + +if __name__ == "__main__": + enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 958aad3cfbaa1b086a2ec6e24ee692ffe89d08e0..0c36d0cda3f00e1e21a9b30591d3875edd524383 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -296,6 +296,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_layer_norm_mkldnn_op', 'test_layer_norm_bf16_mkldnn_op', 'test_layer_norm_op_v2', + 'test_layer_norm_fuse_pass', 'test_learning_rate_scheduler', 'test_linear_interp_op', 'test_linear_interp_v2_op',