Unverified commit 4f066e31, authored by Adam Osewski, committed by GitHub

Layer normalization fuse pass. (#30721)

Parent b1026f64
@@ -92,6 +92,7 @@ pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
+pass_library(layer_norm_fuse_pass inference)
if(WITH_GPU)
  pass_library(cudnn_placement_pass base DEPS placement_pass_base)
  pass_library(embedding_eltwise_layernorm_fuse_pass inference)
@@ -129,6 +130,7 @@ cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc D
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
+cc_library(pass_test_util SRCS pass_test_util.cc DEPS graph pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
@@ -150,6 +152,7 @@ cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.c
cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
+cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
if(WITH_GPU)
  cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
  cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
@@ -158,7 +161,6 @@ if(NOT WIN32)
  cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif()
if (WITH_MKLDNN)
-  cc_library(pass_test_util SRCS mkldnn/pass_test_util.cc DEPS graph pass)
  cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
  cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
  cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
...
@@ -2796,6 +2796,122 @@ PDNode *patterns::MultiGru::operator()() {
  return h;
}
PDNode *patterns::LayerNorm::operator()() {
auto *x = pattern->NewNode(x_repr())->AsInput()->assert_is_ops_input(
{"reduce_mean", "elementwise_sub"});
auto *x_mean = pattern->NewNode(x_mean_repr())->assert_is_op("reduce_mean");
auto *x_mean_out = pattern->NewNode(x_mean_out_repr())
->assert_is_op_output("reduce_mean", "Out")
->assert_is_op_input("elementwise_sub", "Y")
->AsIntermediate();
auto *x_sub_mean =
pattern->NewNode(x_sub_mean_repr())->assert_is_op("elementwise_sub");
auto *x_sub_mean_out =
pattern->NewNode(x_sub_mean_out_repr())
->assert_is_op_output("elementwise_sub")
->assert_is_ops_input({"elementwise_pow", "elementwise_div"}, "X")
->AsIntermediate();
auto *sqr_pow = pattern->NewNode(sqr_pow_repr())
->assert_is_op_input("elementwise_pow", "Y")
->assert_is_persistable_var()
->AsInput();
auto *x_sub_mean_sqr =
pattern->NewNode(x_sub_mean_sqr_repr())->assert_is_op("elementwise_pow");
auto *x_sub_mean_sqr_out = pattern->NewNode(x_sub_mean_sqr_out_repr())
->assert_is_op_output("elementwise_pow")
->assert_is_op_input("reduce_mean")
->AsIntermediate();
auto *std_dev = pattern->NewNode(std_dev_repr())->assert_is_op("reduce_mean");
auto *std_dev_out = pattern->NewNode(std_dev_out_repr())
->assert_is_op_output("reduce_mean")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto *eps = pattern->NewNode(eps_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var()
->AsInput();
auto *std_dev_eps =
pattern->NewNode(std_dev_eps_repr())->assert_is_op("elementwise_add");
auto *std_dev_eps_out = pattern->NewNode(std_dev_eps_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("sqrt")
->AsIntermediate();
auto *std_dev_eps_sqrt =
pattern->NewNode(std_dev_eps_sqrt_repr())->assert_is_op("sqrt");
auto *std_dev_eps_sqrt_out = pattern->NewNode(std_dev_eps_sqrt_out_repr())
->assert_is_op_output("sqrt")
->assert_is_op_input("elementwise_div", "Y")
->AsIntermediate();
auto *division =
pattern->NewNode(division_repr())->assert_is_op("elementwise_div");
auto *division_out = pattern->NewNode(division_out_repr())
->assert_is_op_output("elementwise_div")
->assert_is_op_input("elementwise_mul")
->AsIntermediate();
auto *gamma = pattern->NewNode(gamma_repr())
->assert_is_op_input("elementwise_mul", "Y")
->assert_is_persistable_var()
->AsInput();
auto *scale = pattern->NewNode(scale_repr())->assert_is_op("elementwise_mul");
auto *scale_out = pattern->NewNode(scale_out_repr())
->assert_is_op_output("elementwise_mul")
->assert_is_op_input("elementwise_add")
->AsIntermediate();
auto *beta = pattern->NewNode(beta_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var()
->AsInput();
auto *shift = pattern->NewNode(shift_repr())->assert_is_op("elementwise_add");
auto *shift_out = pattern->NewNode(shift_out_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
/*
* X
* / \
* / reduce_mean "u(x)"
* \ /
* elementwise_sub "x - u(x)"
* / \ 2
* | \ /
* | elementwise_pow "(x - u(x))^2"
* | |
* | reduce_mean "sigma^2 = 1/C*Sum{(x - u(x))^2}"
* | | eps
* | | /
* | elementwise_add "sigma^2 + epsilon"
* \ |
* \ sqrt "sqrt(sigma^2 + epsilon)"
* \ /
* \ /
* elementwise_div "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
* |
* gamma |
* \ |
* elementwise_mul "scale: gamma(C) * lnorm"
* |
* beta |
* \ |
* elementwise_add "shift: gamma(C) * lnorm + beta(C)"
*/
x_mean->LinksFrom({x}).LinksTo({x_mean_out});
x_sub_mean->LinksFrom({x, x_mean_out}).LinksTo({x_sub_mean_out});
x_sub_mean_sqr->LinksFrom({x_sub_mean_out, sqr_pow})
.LinksTo({x_sub_mean_sqr_out});
std_dev->LinksFrom({x_sub_mean_sqr_out}).LinksTo({std_dev_out});
std_dev_eps->LinksFrom({std_dev_out, eps}).LinksTo({std_dev_eps_out});
std_dev_eps_sqrt->LinksFrom({std_dev_eps_out})
.LinksTo({std_dev_eps_sqrt_out});
division->LinksFrom({x_sub_mean_out, std_dev_eps_sqrt_out})
.LinksTo({division_out});
scale->LinksFrom({division_out, gamma}).LinksTo({scale_out});
shift->LinksFrom({scale_out, beta}).LinksTo({shift_out});
return shift_out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
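For orientation, here is a minimal sketch (illustrative only, modelled on the handler in layer_norm_fuse_pass.cc further down in this commit) of how a pattern defined this way is driven over a graph:

    GraphPatternDetector gpd;
    patterns::LayerNorm layer_norm_pattern(gpd.mutable_pattern(), "layer_norm_fuse");
    layer_norm_pattern();  // instantiate the PDNodes and their links
    gpd(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
      // subgraph maps each PDNode to the concrete ir::Node it matched;
      // GET_IR_NODE_FROM_SUBGRAPH is a convenience wrapper around this lookup.
    });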
@@ -1598,6 +1598,41 @@ struct MultiGru : public PatternBase {
  PATTERN_DECL_NODE(h);
};
//
// \brief Pattern looking for a subgraph representing the layer normalization
// operation.
//
struct LayerNorm : public PatternBase {
LayerNorm(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "layer_norm") {}
PDNode* operator()();
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(x_mean);
PATTERN_DECL_NODE(x_mean_out);
PATTERN_DECL_NODE(x_sub_mean);
PATTERN_DECL_NODE(x_sub_mean_out);
PATTERN_DECL_NODE(sqr_pow);
PATTERN_DECL_NODE(x_sub_mean_sqr);
PATTERN_DECL_NODE(x_sub_mean_sqr_out);
PATTERN_DECL_NODE(std_dev);
PATTERN_DECL_NODE(std_dev_out);
PATTERN_DECL_NODE(eps);
PATTERN_DECL_NODE(std_dev_eps);
PATTERN_DECL_NODE(std_dev_eps_out);
PATTERN_DECL_NODE(std_dev_eps_sqrt);
PATTERN_DECL_NODE(std_dev_eps_sqrt_out);
PATTERN_DECL_NODE(division);
PATTERN_DECL_NODE(division_out);
PATTERN_DECL_NODE(gamma);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(beta);
PATTERN_DECL_NODE(shift);
PATTERN_DECL_NODE(shift_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
...
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
// cpplint complains (incorrectly) about a missing <string> include for the
// line below.
using string::PrettyLogDetail; // NOLINT
namespace {
void validateReduceOpAttrs(const Node* node, const std::string& name) {
const auto* op = node->Op();
if (op->HasAttr("dim")) {
auto dims = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
    PADDLE_ENFORCE_EQ(dims.size(), 1, platform::errors::PreconditionNotMet(
                                          "The LayerNorm fusion ", name,
                                          " reduction must happen only over a "
                                          "single dimension."));
    PADDLE_ENFORCE_EQ(dims.front(), -1, platform::errors::PreconditionNotMet(
                                            "The LayerNorm fusion ", name,
                                            " reduction must happen over the "
                                            "last dimension."));
}
if (op->HasAttr("reduce_all")) {
PADDLE_ENFORCE(!BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
platform::errors::PreconditionNotMet(
"The LayerNorm fusion ", name,
" reduction must have "
"\'reduce_all\' attribute set to false."));
}
if (op->HasAttr("keep_dim")) {
PADDLE_ENFORCE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
platform::errors::PreconditionNotMet(
"The LayerNorm fusion ", name,
" reduction must have "
"\'keep_dim\' attribute set to true."));
}
}
void setIntermediateOut(OpDesc* desc, const std::string& out_name,
const std::string& scope_name) {
std::string new_name = scope_name + "/at." + out_name + ".new";
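  // e.g. scope_name "layer_norm_fuse" and out_name "Mean" yield the
  // variable name "layer_norm_fuse/at.Mean.new".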
desc->SetOutput(out_name, {new_name});
}
void addIntermediateOut(Node* op_node, const std::string& out_name,
const std::string& scope_name, Graph* graph) {
std::string new_name = scope_name + "/at." + out_name + ".new";
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);
IR_NODE_LINK_TO(op_node, node_var);
}
} // namespace
void LayerNormFusePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"The input graph of "
"LayerNormFusePass should not be nullptr."));
FusePassBase::Init(scope_name_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd;
patterns::LayerNorm layer_norm_pattern(gpd.mutable_pattern(), scope_name_);
layer_norm_pattern();
int found_layer_norm_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Fuse LayerNorm from subgraph.";
GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_mean_out, x_mean_out, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean, x_sub_mean, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_out, x_sub_mean_out,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(sqr_pow, sqr_pow, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr, x_sub_mean_sqr,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr_out, x_sub_mean_sqr_out,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev, std_dev, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev_out, std_dev_out, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eps, eps, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps, std_dev_eps, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_out, std_dev_eps_out,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt, std_dev_eps_sqrt,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt_out, std_dev_eps_sqrt_out,
layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(division, division, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(division_out, division_out, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(gamma, gamma, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(beta, beta, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(shift, shift, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(shift_out, shift_out, layer_norm_pattern);
auto* eps_tensor = scope->FindVar(eps->Name())->GetMutable<LoDTensor>();
    // ------------------ subgraph nodes' validation ---------------------------
    PADDLE_ENFORCE_EQ(
        eps_tensor->numel(), 1,
        platform::errors::InvalidArgument(
            "The LayerNorm divisor "
            "epsilon value must be a one-element tensor, but has %s "
            "elements.",
            eps_tensor->numel()));
PADDLE_ENFORCE_EQ(eps_tensor->type(), proto::VarType::FP32,
platform::errors::InvalidArgument(
"The LayerNorm divisor "
"epsilon value must be of FP32 data type, but is %s.",
eps_tensor->type()));
const auto& gamma_shape = gamma->Var()->GetShape();
const auto& beta_shape = beta->Var()->GetShape();
const auto& x_shape = x->Var()->GetShape();
int64_t x_last_dim = x_shape.back();
    PADDLE_ENFORCE_EQ(gamma_shape.size(), 1,
                      platform::errors::InvalidArgument(
                          "The LayerNorm gamma "
                          "(scale) tensor must be one-dimensional, but its "
                          "rank is %s.",
                          gamma_shape.size()));
    PADDLE_ENFORCE_EQ(beta_shape.size(), 1,
                      platform::errors::InvalidArgument(
                          "The LayerNorm beta "
                          "(shift) tensor must be one-dimensional, but its "
                          "rank is %s.",
                          beta_shape.size()));
    PADDLE_ENFORCE_EQ(beta_shape, gamma_shape,
                      platform::errors::InvalidArgument(
                          "The LayerNorm beta and gamma tensor shapes must "
                          "be equal."));
    PADDLE_ENFORCE_EQ(gamma_shape.front(), x_last_dim,
                      platform::errors::InvalidArgument(
                          "The LayerNorm gamma and beta tensor sizes must "
                          "match the input's last dimension."));
validateReduceOpAttrs(x_mean, "input mean");
validateReduceOpAttrs(std_dev, "std_dev mean");
// ------------------ op creation and placement ---------------------------
OpDesc ln_op_desc;
ln_op_desc.SetType("layer_norm");
ln_op_desc.SetInput("X", {x->Name()});
ln_op_desc.SetInput("Scale", {gamma->Name()});
ln_op_desc.SetInput("Bias", {beta->Name()});
ln_op_desc.SetOutput("Y", {shift_out->Name()});
setIntermediateOut(&ln_op_desc, "Mean", scope_name_);
setIntermediateOut(&ln_op_desc, "Variance", scope_name_);
ln_op_desc.SetAttr("begin_norm_axis", static_cast<int>(x_shape.size() - 1));
ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
ln_op_desc.SetAttr("is_test", true);
Node* ln_op = g->CreateOpNode(&ln_op_desc);
addIntermediateOut(ln_op, "Mean", scope_name_, g);
addIntermediateOut(ln_op, "Variance", scope_name_, g);
IR_NODE_LINK_TO(x, ln_op);
IR_NODE_LINK_TO(gamma, ln_op);
IR_NODE_LINK_TO(beta, ln_op);
IR_OP_VAR_LINK(ln_op, shift_out);
GraphSafeRemoveNodes(
g,
{x_mean, x_mean_out, x_sub_mean, x_sub_mean_out, sqr_pow,
x_sub_mean_sqr, x_sub_mean_sqr_out, std_dev, std_dev_out, eps,
std_dev_eps, std_dev_eps_out, std_dev_eps_sqrt, std_dev_eps_sqrt_out,
division, division_out, scale, scale_out, shift});
found_layer_norm_count++;
};
gpd(graph, handler);
AddStatis(found_layer_norm_count);
PrettyLogDetail("--- Fused %d subgraphs into layer_norm op.",
found_layer_norm_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(layer_norm_fuse_pass, paddle::framework::ir::LayerNormFusePass);
REGISTER_PASS_CAPABILITY(layer_norm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("elementwise_add", 0)
.LE("elementwise_add", 1)
.GE("elementwise_div", 0)
.LE("elementwise_div", 1)
.GE("elementwise_mul", 0)
.LE("elementwise_mul", 1)
.GE("elementwise_pow", 0)
.LE("elementwise_pow", 1)
.GE("elementwise_sub", 0)
.LE("elementwise_sub", 1)
.EQ("reduce_mean", 0)
.EQ("sqrt", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * \brief Fuse the subgraph representing layer normalization into a
 * layer_norm op.
 *
 * \note The following graph represents this equation:
 *
 *                x - u(x)
 *   y(c) * --------------------- + b(c)
 *           sqrt(sigma^2 + eps)
 *
 * x       - input data
 * u(x)    - mean
 * sigma^2 - variance
 * eps     - epsilon
 * y(c)    - gamma (scale), channelwise
 * b(c)    - beta (shift), channelwise
*
*
* X
* / \
* / reduce_mean "u(x)"
* \ /
* elementwise_sub "x - u(x)"
* / \ 2
* | \ /
* | elementwise_pow "(x - u(x))^2"
* | |
* | reduce_mean "sigma^2 = 1/C*Sum{(x - u(x))^2}"
* | | eps
* | | /
* | elementwise_add "sigma^2 + epsilon"
* \ |
* \ sqrt "sqrt(sigma^2 + epsilon)"
* \ /
* \ /
* elementwise_div "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
* |
* gamma |
* \ |
* elementwise_mul "scale: gamma(C) * lnorm"
* |
* beta |
* \ |
* elementwise_add "shift: gamma(C) * lnorm + beta(C)"
*/
class LayerNormFusePass : public FusePassBase {
public:
virtual ~LayerNormFusePass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
private:
const std::string scope_name_{"layer_norm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
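To make the matched arithmetic concrete, here is a small self-contained reference (an illustrative sketch only, not part of the commit; LayerNormRow is a hypothetical helper) computing the expression from the diagram above over one row of C elements:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference layer normalization over a single row of C elements:
    // y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
    std::vector<float> LayerNormRow(const std::vector<float>& x,
                                    const std::vector<float>& gamma,
                                    const std::vector<float>& beta, float eps) {
      const std::size_t C = x.size();
      float mean = 0.0f;  // u(x): reduce_mean over the last dimension
      for (float v : x) mean += v;
      mean /= static_cast<float>(C);
      float var = 0.0f;  // sigma^2: reduce_mean of (x - u(x))^2
      for (float v : x) var += (v - mean) * (v - mean);
      var /= static_cast<float>(C);
      std::vector<float> y(C);
      for (std::size_t i = 0; i < C; ++i) {
        y[i] = gamma[i] * (x[i] - mean) / std::sqrt(var + eps) + beta[i];
      }
      return y;
    }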
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
ProgramDesc BuildGraphProgram() {
auto prog = test::BuildProgramDesc(
{"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out", "std_dev_out",
"std_dev_eps_out", "std_dev_eps_sqrt_out", "division_out", "scale_out",
"shift_out"},
{"sqr_pow", "eps", "gamma", "beta"});
const auto& block_desc = prog.Block(0);
auto* x_var_desc = block_desc.FindVar("x");
x_var_desc->SetDataType(proto::VarType::FP32);
x_var_desc->SetShape({3, 32, 48});
auto* eps_var_desc = block_desc.FindVar("eps");
eps_var_desc->SetDataType(proto::VarType::FP32);
eps_var_desc->SetShape({1});
auto* gamma_var_desc = block_desc.FindVar("gamma");
gamma_var_desc->SetDataType(proto::VarType::FP32);
gamma_var_desc->SetShape({48});
auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({48});
auto* x_mean = test::CreateOp(&prog, "reduce_mean", {{"X", "x"}},
{{"Out", "x_mean_out"}}, false);
x_mean->SetAttr("dim", std::vector<int>{-1});
x_mean->SetAttr("keep_dim", true);
x_mean->SetAttr("reduce_all", false);
test::CreateOp(&prog, "elementwise_sub", {{"X", "x"}, {"Y", "x_mean_out"}},
{{"Out", "x_sub_mean_out"}}, false);
test::CreateOp(&prog, "elementwise_pow",
{{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
{{"Out", "x_sub_mean_sqr_out"}}, false);
auto* std_dev =
test::CreateOp(&prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
{{"Out", "std_dev_out"}}, false);
std_dev->SetAttr("dim", std::vector<int>{-1});
std_dev->SetAttr("keep_dim", true);
std_dev->SetAttr("reduce_all", false);
test::CreateOp(&prog, "elementwise_add", {{"X", "std_dev_out"}, {"Y", "eps"}},
{{"Out", "std_dev_eps_out"}}, false);
test::CreateOp(&prog, "sqrt", {{"X", "std_dev_eps_out"}},
{{"Out", "std_dev_eps_sqrt_out"}}, false);
test::CreateOp(&prog, "elementwise_div",
{{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
{{"Out", "division_out"}}, false);
test::CreateOp(&prog, "elementwise_mul",
{{"X", "division_out"}, {"Y", "gamma"}},
{{"Out", "scale_out"}}, false);
test::CreateOp(&prog, "elementwise_add", {{"X", "scale_out"}, {"Y", "beta"}},
{{"Out", "shift_out"}}, false);
return prog;
}
bool CheckFusedSubgraphOpsCount(const Graph& graph) {
return test::AssertOpsCount(graph, {{"reduce_mean", 0},
{"elementwise_sub", 0},
{"elementwise_pow", 0},
{"elementwise_add", 0},
{"sqrt", 0},
{"elementwise_div", 0},
{"elementwise_mul", 0},
{"layer_norm", 1}});
}
} // namespace
// ------------------------------ Test cases -----------------------------------
TEST(FuseLayerNormPass, TestFuse) {
ProgramDesc prog = BuildGraphProgram();
Graph graph(prog);
constexpr int removed_nodes = 19;
// LayerNorm + outputs: {Mean, Variance}
constexpr int added_nodes = 3;
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
float eps_value = 1e-5f;
  // Init the scope, as it is used by the pass
exe.CreateVariables(prog, 0, true, &scope);
test::InitLoDTensorHolder<float>(&scope, place, "eps", {1}, &eps_value);
graph.SetNotOwned(kParamScopeAttr, &scope);
EXPECT_TRUE(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
"shift_out", removed_nodes, added_nodes));
EXPECT_TRUE(CheckFusedSubgraphOpsCount(graph));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "layer_norm") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("is_test"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("is_test")));
ASSERT_TRUE(op->HasAttr("begin_norm_axis"));
ASSERT_TRUE(op->HasAttr("epsilon"));
}
}
}
TEST(FuseLayerNormPass, TestInvalidEpsNumel) {
ProgramDesc prog = BuildGraphProgram();
auto* eps_var_desc = prog.Block(0).FindVar("eps");
eps_var_desc->SetDataType(proto::VarType::FP32);
eps_var_desc->SetShape({2});
Graph graph(prog);
constexpr int removed_nodes = 19;
constexpr int added_nodes = 3;
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
auto eps_values = std::vector<float>{1e-5f, 1e-5f};
  // Init the scope, as it is used by the pass
exe.CreateVariables(prog, 0, true, &scope);
test::InitLoDTensorHolder<float>(&scope, place, "eps", {2},
eps_values.data());
graph.SetNotOwned(kParamScopeAttr, &scope);
EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
"shift_out", removed_nodes, added_nodes),
paddle::platform::EnforceNotMet);
}
TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
ProgramDesc prog = BuildGraphProgram();
auto* eps_var_desc = prog.Block(0).FindVar("eps");
eps_var_desc->SetDataType(proto::VarType::FP64);
eps_var_desc->SetShape({1});
Graph graph(prog);
constexpr int removed_nodes = 19;
constexpr int added_nodes = 3;
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
double eps_value = 1e-5;
  // Init the scope, as it is used by the pass
exe.CreateVariables(prog, 0, true, &scope);
test::InitLoDTensorHolder<double>(&scope, place, "eps", {1}, &eps_value);
graph.SetNotOwned(kParamScopeAttr, &scope);
EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
"shift_out", removed_nodes, added_nodes),
paddle::platform::EnforceNotMet);
}
TEST(FuseLayerNormPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("layer_norm_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(layer_norm_fuse_pass);
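// Note: USE_PASS pulls the pass registration symbol into the test binary so
// that the pass registry can find layer_norm_fuse_pass at runtime.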
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
...
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -13,15 +13,19 @@
// limitations under the License.
#include <algorithm>
+#include <cstring>
#include <exception>
#include <functional>
#include <iterator>
#include <list>
#include <map>
+#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
-#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_test_util.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
@@ -32,7 +36,7 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
                 const std::vector<InOutVarNamePair>& inputs,
                 const std::vector<InOutVarNamePair>& outputs,
                 bool use_mkldnn) {
-  auto op = prog->MutableBlock(0)->AppendOp();
+  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(op_type_name);
  op->SetAttr("use_mkldnn", use_mkldnn);
@@ -43,6 +47,8 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
    op->SetOutput(output.first, {output.second});
  }
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
  return op;
}
@@ -168,6 +174,49 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
  return expected_nodes_num == current_nodes_num;
}
template <typename T>
void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
const std::string& var_name,
const std::vector<int64_t>& dims, const T* data) {
auto var = scope->Var(var_name);
auto tensor = var->GetMutable<LoDTensor>();
auto* tensor_mem_ptr = tensor->mutable_data<T>(make_ddim(dims), place);
if (data != nullptr) {
std::memcpy(tensor_mem_ptr, data, tensor->memory_size());
} else {
std::memset(tensor_mem_ptr, 0, tensor->memory_size());
}
}
// Instantiate for the data types below.
template void InitLoDTensorHolder<float>(Scope*, const paddle::platform::Place&,
const std::string&,
const std::vector<int64_t>&,
const float*);
template void InitLoDTensorHolder<int>(Scope*, const paddle::platform::Place&,
const std::string&,
const std::vector<int64_t>&, const int*);
template void InitLoDTensorHolder<double>(Scope*,
const paddle::platform::Place&,
const std::string&,
const std::vector<int64_t>&,
const double*);
OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
const std::string& output_name,
const std::string& output_arg_name) {
auto all_ops = prog.Block(0).AllOps();
for (auto* op_desc : all_ops) {
if (op_desc->Type() == op_type && op_desc->HasOutput(output_name)) {
const auto& arg_names = op_desc->Outputs().at(output_name);
for (const auto& name : arg_names) {
if (name == output_arg_name) return op_desc;
}
}
}
return nullptr;
}
} // namespace test
} // namespace ir
} // namespace framework
...
@@ -18,9 +18,13 @@
#include <utility>
#include <vector>
+#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
@@ -113,6 +117,37 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
                      const std::string& from, const std::string& to,
                      int removed_nodes_count, int added_nodes_count = 0);
///
/// @brief Initializes the tensor memory holder.
///
/// @param[in] scope The scope that manages the variable.
/// @param[in] place The place where memory will be allocated.
/// @param[in] var_name The variable name.
/// @param[in] dims The dimensions of the allocated tensor.
/// @param[in] data Optional data to copy into the tensor; when nullptr,
///                 the tensor is zero-initialized.
///
/// @tparam T Tensor data type.
///
template <typename T>
void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
const std::string& var_name,
const std::vector<int64_t>& dims,
const T* data = nullptr);
///
/// @brief Retrieve operator descriptor from program.
///
/// @param[in] prog The program descriptor containing the op we
/// search for.
/// @param[in] op_type The wanted operator type name.
/// @param[in] output_name The wanted operator output name.
/// @param[in] output_arg_name The wanted operator output argument name.
///
/// @return The operator descriptor.
///
OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
const std::string& output_name,
const std::string& output_arg_name);
} // namespace test
} // namespace ir
} // namespace framework
...
@@ -161,7 +161,8 @@ void GpuPassStrategy::EnableMkldnnBfloat16() {
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
  // NOTE the large fusions should be located in the front, so that they will
  // not be damaged by smaller ones.
  passes_.assign({"simplify_with_basic_ops_pass",  //
+                 "layer_norm_fuse_pass",
                  "attention_lstm_fuse_pass",  //
                  "seqconv_eltadd_relu_fuse_pass",  //
                  // "seqpool_concat_fuse_pass",  //
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of subgraph expressing layer normalization."""
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker
class LayerNormFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[3, 64, 120], dtype="float32")
sqr_pow = fluid.layers.fill_constant(
shape=[1], value=2, dtype="float32")
eps = fluid.layers.fill_constant(
shape=[1], value=1e-5, dtype="float32")
gamma = fluid.layers.create_parameter(
shape=[120], dtype="float32", is_bias=True)
beta = fluid.layers.create_parameter(
shape=[120], dtype="float32", is_bias=True)
x_mean_out = fluid.layers.reduce_mean(data, dim=-1, keep_dim=True)
x_sub_mean_out = fluid.layers.elementwise_sub(data, x_mean_out)
x_sub_mean_sqr_out = fluid.layers.elementwise_pow(x_sub_mean_out,
sqr_pow)
std_dev_out = fluid.layers.reduce_mean(
x_sub_mean_sqr_out, dim=-1, keep_dim=True)
std_dev_eps_out = fluid.layers.elementwise_add(std_dev_out, eps)
std_dev_eps_sqrt_out = fluid.layers.sqrt(std_dev_eps_out)
division_out = fluid.layers.elementwise_div(x_sub_mean_out,
std_dev_eps_sqrt_out)
scale_out = fluid.layers.elementwise_mul(division_out, gamma)
shift_out = fluid.layers.elementwise_add(scale_out, beta)
self.feeds = {
"data": np.random.random((3, 64, 120)).astype("float32"),
}
self.fetch_list = [shift_out]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(PassVersionChecker.IsCompatible("layer_norm_fuse_pass"))
if __name__ == "__main__":
enable_static()
unittest.main()
@@ -296,6 +296,7 @@ STATIC_MODE_TESTING_LIST = [
    'test_layer_norm_mkldnn_op',
    'test_layer_norm_bf16_mkldnn_op',
    'test_layer_norm_op_v2',
+    'test_layer_norm_fuse_pass',
    'test_learning_rate_scheduler',
    'test_linear_interp_op',
    'test_linear_interp_v2_op',
...