Unverified commit 8d6dc102, authored by Zhaolong Xing, committed by GitHub

[Ernie GPU Optimize]: Embedding_eltwise_layernorm Fuse (#22494)

* 1. add embedding eltwise layernorm fuse
2. add embedding eltwise layernorm op
3. refine inplace_add_relu
4. refine fc_eltwise_layernorm
test=develop

* 1. refine fc
test=develop

* fix comments
test=develop

* fix comments

test=develop
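For reference, the fused op computes, for every token, the sum of three embedding lookups followed by layer normalization. Below is a minimal host-side C++ sketch of these semantics, inferred from the CUDA kernel added in this PR; the function name, flat-array layout, and parameter names are illustrative only and not part of the patch:

#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics of fused_embedding_eltwise_layernorm (a sketch, not the
// shipped kernel): for each token, sum the word/pos/sent embedding rows, then
// layer-normalize the sum over the hidden dimension.
void EmbEltwiseLayerNormRef(
    const std::vector<int64_t>& word_id, const std::vector<int64_t>& pos_id,
    const std::vector<int64_t>& sent_id,
    const std::vector<float>& word_emb,  // [word_num, hidden], row-major
    const std::vector<float>& pos_emb,   // [pos_num, hidden], row-major
    const std::vector<float>& sent_emb,  // [sent_num, hidden], row-major
    const std::vector<float>& scale,     // [hidden]
    const std::vector<float>& bias,      // [hidden]
    int hidden, float eps, std::vector<float>* out) {
  const size_t tokens = word_id.size();  // batch * seq_len
  out->assign(tokens * hidden, 0.f);
  for (size_t t = 0; t < tokens; ++t) {
    float mean = 0.f, sq_mean = 0.f;
    for (int h = 0; h < hidden; ++h) {
      const float v = word_emb[word_id[t] * hidden + h] +
                      pos_emb[pos_id[t] * hidden + h] +
                      sent_emb[sent_id[t] * hidden + h];
      (*out)[t * hidden + h] = v;
      mean += v / hidden;
      sq_mean += v * v / hidden;
    }
    const float rsigma = 1.f / std::sqrt(sq_mean - mean * mean + eps);
    for (int h = 0; h < hidden; ++h) {
      const float v = (*out)[t * hidden + h];
      (*out)[t * hidden + h] = scale[h] * (v - mean) * rsigma + bias[h];
    }
  }
}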
Parent 4ff2915d
@@ -118,7 +118,7 @@ function(op_library TARGET)
         "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
         "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
         "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
-        "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op")
+        "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()
......
@@ -78,6 +78,7 @@ pass_library(fc_elementwise_layernorm_fuse_pass base)
 pass_library(multihead_matmul_fuse_pass inference)
 if(WITH_GPU)
   pass_library(cudnn_placement_pass base DEPS placement_pass_base)
+  pass_library(embedding_eltwise_layernorm_fuse_pass inference)
 endif()
 if(WITH_MKLDNN)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
static int BuildFusion(Graph* graph, const std::string& name_scope,
const Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
EmbeddingEltwiseLayerNormPattern emb_eltwise_layernorm_pattern(pattern,
name_scope);
emb_eltwise_layernorm_pattern();
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_x, lookup_table3_x,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_w, lookup_table3_w,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table3, lookup_table3,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_out, lookup_table3_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_12, eltwise_add_12,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_12_out, eltwise_add_12_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
emb_eltwise_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
emb_eltwise_layernorm_pattern);
auto get_persist_tensor_dims = [&](std::string name) -> framework::DDim {
auto* var = scope->FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::PreconditionNotMet(
"Cant not found the %d var in scope.", name));
return var->GetMutable<LoDTensor>()->dims();
};
// Check the weight dims.
auto word_emb_dims = get_persist_tensor_dims(lookup_table1_w->Name());
auto pos_emb_dims = get_persist_tensor_dims(lookup_table2_w->Name());
auto sent_emb_dims = get_persist_tensor_dims(lookup_table3_w->Name());
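    // Fuse only if all three embedding tables are 2-D and agree on the
    // hidden size; otherwise leave the subgraph untouched.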
if (word_emb_dims.size() != 2 || pos_emb_dims.size() != 2 ||
sent_emb_dims.size() != 2 || word_emb_dims[1] != pos_emb_dims[1] ||
word_emb_dims[1] != sent_emb_dims[1]) {
return;
}
OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("WordId", {lookup_table1_x->Name()});
new_op_desc.SetInput("PosId", {lookup_table2_x->Name()});
new_op_desc.SetInput("SentId", {lookup_table3_x->Name()});
new_op_desc.SetInput("WordEmb", {lookup_table1_w->Name()});
new_op_desc.SetInput("PosEmb", {lookup_table2_w->Name()});
new_op_desc.SetInput("SentEmb", {lookup_table3_w->Name()});
new_op_desc.SetInput("Bias", {layer_norm_bias->Name()});
new_op_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_op_desc.SetOutput("Out", {layer_norm_out->Name()});
new_op_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(lookup_table1_x, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(lookup_table2_x, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(lookup_table3_x, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(lookup_table1_w, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(lookup_table2_w, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(lookup_table3_w, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(layer_norm_bias, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(layer_norm_scale, embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, layer_norm_out);
std::unordered_set<const Node*> marked_nodes(
{lookup_table1, lookup_table2, lookup_table3, lookup_table1_out,
lookup_table2_out, lookup_table3_out, eltwise_add_12,
eltwise_add_12_out, eltwise_add, eltwise_add_out, layer_norm,
layer_norm_mean, layer_norm_variance});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
PDNode* EmbeddingEltwiseLayerNormPattern::operator()() {
// Create shared nodes.
auto create_emb_vars = [&](const std::string& name, const std::string& arg,
bool is_persist = false) -> PDNode* {
PDNode* node = pattern->NewNode(name)
->assert_is_op_input("lookup_table", arg)
->AsInput();
if (is_persist) return node->assert_is_persistable_var();
return node;
};
auto create_emb_out_vars = [&](const std::string& name,
const std::string& arg) -> PDNode* {
PDNode* node = pattern->NewNode(name)
->AsIntermediate()
->assert_is_op_output("lookup_table")
->assert_is_op_input("elementwise_add", arg);
return node;
};
auto* lookup_table1_x = create_emb_vars(lookup_table1_x_repr(), "Ids");
auto* lookup_table2_x = create_emb_vars(lookup_table2_x_repr(), "Ids");
auto* lookup_table3_x = create_emb_vars(lookup_table3_x_repr(), "Ids");
auto* lookup_table1_w = create_emb_vars(lookup_table1_w_repr(), "W", true);
auto* lookup_table2_w = create_emb_vars(lookup_table2_w_repr(), "W", true);
auto* lookup_table3_w = create_emb_vars(lookup_table3_w_repr(), "W", true);
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
auto* lookup_table2 =
pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
auto* lookup_table3 =
pattern->NewNode(lookup_table3_repr())->assert_is_op("lookup_table");
auto* lookup_table1_out = create_emb_out_vars(lookup_table1_out_repr(), "X");
auto* lookup_table2_out = create_emb_out_vars(lookup_table2_out_repr(), "Y");
auto* lookup_table3_out = create_emb_out_vars(lookup_table3_out_repr(), "Y");
auto* eltwise_add_12 =
pattern->NewNode(eltwise_add_12_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_12_out = pattern->NewNode(eltwise_add_12_out_repr())
->AsIntermediate()
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add", "X");
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->AsIntermediate()
->assert_is_op_output("elementwise_add");
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y")
->AsOutput();
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
// Link all nodes together
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
lookup_table3->LinksFrom({lookup_table3_x, lookup_table3_w})
.LinksTo({lookup_table3_out});
eltwise_add_12->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({eltwise_add_12_out});
eltwise_add->LinksFrom({lookup_table3_out, eltwise_add_12_out})
.LinksTo({eltwise_add_out});
layer_norm
->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
return layer_norm_out;
}
} // namespace patterns
void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::PreconditionNotMet(
"The scope is null, please initialize the scope first."));
int fusion_count = patterns::BuildFusion(graph, name_scope_, scope);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass,
paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct EmbeddingEltwiseLayerNormPattern : public PatternBase {
EmbeddingEltwiseLayerNormPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding_eltwise_layernorm") {}
PDNode* operator()();
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table3_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table3_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(lookup_table3);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(lookup_table3_out);
PATTERN_DECL_NODE(eltwise_add_12);
PATTERN_DECL_NODE(eltwise_add_12_out);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
  // The mean and variance outputs of layer_norm are deleted from the graph
  // after fusion.
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
} // namespace patterns
// The EmbeddingEltwiseLayerNormFusePass detects the following pattern:
//
// inputs                            operator           output
// --------------------------------------------------------------------
// (word, weights_0)                 lookup_table    -> word_emb
// (pos, weights_1)                  lookup_table    -> pos_emb
// (sent, weights_2)                 lookup_table    -> sent_emb
// (word_emb, pos_emb)               elementwise_add -> elementwise_out_0
// (elementwise_out_0, sent_emb)     elementwise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias)  layer_norm      -> layer_norm_out
//
// and then converts the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias) embedding_eltwise_layernorm -> layer_norm_out
class EmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public:
virtual ~EmbeddingEltwiseLayerNormFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"embedding_eltwise_layernorm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -327,6 +327,16 @@ struct Layers {
     return outs;
   }

+  VarDesc* embedding(VarDesc* x, VarDesc* weights) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("lookup_table");
+    op->SetInput("Ids", {x->Name()});
+    op->SetInput("W", {weights->Name()});
+    op->SetOutput("Out", {out->Name()});
+    return out;
+  }
+
   void backward(std::vector<VarDesc*> targets) {
     // This function is designed to simulate the structure of a training
     // program, but is constructed differently from the actual program.
......
@@ -107,7 +107,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_eltwiseadd_affine_channel_fuse_pass",  //
         "conv_bn_fuse_pass",                         //
         "conv_eltwiseadd_bn_fuse_pass",              //
-        "multihead_matmul_fuse_pass_v2",
+        "embedding_eltwise_layernorm_fuse_pass",     //
+        "multihead_matmul_fuse_pass_v2",             //
         "fc_fuse_pass",                              //
         "fc_elementwise_layernorm_fuse_pass",        //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
......
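Because the pass is appended to GpuPassStrategy above, it runs automatically during GPU inference whenever IR optimization is enabled. Below is a sketch of a config that exercises it through the C++ inference API; the method names come from the public AnalysisConfig header, while the model path and memory-pool size are illustrative:

#include <string>
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// A sketch, assuming a standard inference setup: with GPU execution and IR
// optimization on, the pass list above (now including
// embedding_eltwise_layernorm_fuse_pass) is applied to the model graph.
paddle::AnalysisConfig MakeGpuConfig(const std::string& model_dir) {
  paddle::AnalysisConfig config;
  config.SetModel(model_dir);
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  config.SwitchIrOptim(true);  // run the GpuPassStrategy passes
  return config;
}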
@@ -6,6 +6,7 @@ register_operators(EXCLUDES
     fusion_conv_inception_op
     fused_fc_elementwise_layernorm_op
     multihead_matmul_op
+    fused_embedding_eltwise_layernorm_op
     fusion_group_op)

 if (WITH_GPU)
@@ -33,6 +34,8 @@ if (WITH_GPU)
     # multihead_matmul_op
     op_library(multihead_matmul_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
+    op_library(fused_embedding_eltwise_layernorm_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n")
     # fusion_group
     if(NOT APPLE AND NOT WIN32)
       op_library(fusion_group_op DEPS device_code)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace operators {
class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("WordId"), true,
platform::errors::InvalidArgument(
"Input(WordId) of EmbeddingEltWiseLayerNormOp should "
"not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("PosId"), true,
platform::errors::InvalidArgument(
"Input(PosId) of EmbeddingEltWiseLayerNormOp should not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("SentId"), true,
platform::errors::InvalidArgument(
"Input(SentId) of EmbeddingEltWiseLayerNormOp should "
"not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("WordEmb"), true,
platform::errors::InvalidArgument(
"Input(WordEmb) of EmbeddingEltWiseLayerNormOp "
"should not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("PosEmb"), true,
platform::errors::InvalidArgument(
"Input(PosEmb) of EmbeddingEltWiseLayerNormOp should "
"not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("SentEmb"), true,
platform::errors::InvalidArgument(
"Input(SentEmb) of EmbeddingEltWiseLayerNormOp "
"should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("Bias"), true,
platform::errors::InvalidArgument(
"Input(Bias) of EmbeddingEltWiseLayerNormOp should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("Scale"), true,
platform::errors::InvalidArgument(
"Input(Scale) of EmbeddingEltWiseLayerNormOp should not be null."));
PADDLE_ENFORCE_EQ(
context->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EmbeddingEltWiseLayerNormOp should not be null."));
// batch * seq_len * 1
auto dims_word_id = context->GetInputDim("WordId");
// word_num * hidden
auto dims_word_emb = context->GetInputDim("WordEmb");
auto dims_pos_emb = context->GetInputDim("PosEmb");
auto dims_sent_emb = context->GetInputDim("SentEmb");
// hidden
auto dims_bias = context->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_bias[0],
        platform::errors::InvalidArgument(
            "The second dim (%d) of the Word Embedding should be equal "
            "to the Bias's size (%d).",
            dims_word_emb[1], dims_bias[0]));
    PADDLE_ENFORCE_EQ(dims_word_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The WordEmb dims' size should be 2, but found %d.",
                          dims_word_emb.size()));
    PADDLE_ENFORCE_EQ(dims_pos_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The PosEmb dims' size should be 2, but found %d.",
                          dims_pos_emb.size()));
    PADDLE_ENFORCE_EQ(dims_sent_emb.size(), 2,
                      platform::errors::InvalidArgument(
                          "The SentEmb dims' size should be 2, but found %d.",
                          dims_sent_emb.size()));
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_pos_emb[1],
        platform::errors::InvalidArgument(
            "The second dim of WordEmb (%d) should be equal to that of "
            "PosEmb (%d).",
            dims_word_emb[1], dims_pos_emb[1]));
    PADDLE_ENFORCE_EQ(
        dims_word_emb[1], dims_sent_emb[1],
        platform::errors::InvalidArgument(
            "The second dim of WordEmb (%d) should be equal to that of "
            "SentEmb (%d).",
            dims_word_emb[1], dims_sent_emb[1]));
int batch = dims_word_id[0];
int seq_len = dims_word_id[1];
int hidden = dims_word_emb[1];
auto dim_output = framework::make_ddim({batch, seq_len, hidden});
context->SetOutputDim("Out", dim_output);
context->ShareLoD("WordId", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "WordEmb");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class EmbeddingEltWiseLayerNormOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("WordId", "The word id input of EmbeddingEltWiseLayerNorm op");
AddInput("PosId", "The position id input of EmbeddingEltWiseLayerNorm op");
AddInput("SentId", "The sentence id input of EmbeddingEltWiseLayerNorm op");
AddInput("WordEmb",
"The Word embedding input of EmbeddingEltWiseLayerNorm op");
AddInput("PosEmb",
"The Position embedding input of EmbeddingEltWiseLayerNorm op");
AddInput("SentEmb",
"The Sent embedding input of EmbeddingEltWiseLayerNorm op");
AddInput("Bias", "The LayerNorm Bias of EmbeddingEltWiseLayerNorm op");
AddInput("Scale", "The LayerNorm Scale of EmbeddingEltWiseLayerNorm op");
AddOutput("Out", "The output of EmbeddingEltWiseLayerNorm op");
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float& epsilon) {
PADDLE_ENFORCE_GE(
epsilon, 0.0f,
platform::errors::InvalidArgument(
"'epsilon' is %f, but it should be between 0.0 and 0.001",
epsilon));
PADDLE_ENFORCE_LE(
epsilon, 0.001f,
platform::errors::InvalidArgument(
"'epsilon' is %f, but it should be between 0.0 and 0.001.",
epsilon));
});
AddComment(R"DOC(
EmbeddingEltWiseLayerNorm Operator.

This op is used to optimize the following structure in the Ernie model:

  wordid -> lookup_table_op -> word
  posid  -> lookup_table_op -> pos
  sentid -> lookup_table_op -> sent
  word + pos + sent -> Y
  Y -> layer_norm -> Out

Using this op outside models with the same structure as Ernie is not
recommended.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fused_embedding_eltwise_layernorm,
ops::EmbeddingEltWiseLayerNormOp,
ops::EmbeddingEltWiseLayerNormOpMaker);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
using kvp = cub::KeyValuePair<T, T>;
template <typename T>
using cv2 = cub::CubVector<T, 2>;
template <typename T, int TPB>
__device__ inline void LayerNorm(const cv2<T> &thread_data, const int ld,
const int offset, const float *bias,
const float *scale, T *output, float eps) {
using BlockReduce = cub::BlockReduce<cv2<T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
if (threadIdx.x == 0) {
mu = sum_kv.x;
rsigma = rsqrt(sum_kv.y - mu * mu + eps);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(scale[i]);
const T b(bias[i]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
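// Note on the reduction above: each thread contributes the pair
//   (sum_i x_i / ld, sum_i x_i^2 / ld),
// so after the block reduce sum_kv.x = E[x] and sum_kv.y = E[x^2]. Hence
// mu = E[x], rsigma = 1 / sqrt(E[x^2] - mu^2 + eps), and the loop writes
// y_i = scale_i * (x_i - mu) * rsigma + bias_i.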
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(
int hidden, const int64_t *word_id_d, const int64_t *pos_id_d,
const int64_t *sent_id_d, const T *scale, const T *bias, const T *word_emb,
const T *pos_emb, const T *sent_emb, T *output, float eps) {
cub::Sum pair_sum;
// blockIdx.x: position in the sequence
// blockIdx.y: batch
// gridDim.x: Seq
// gridDim.y: Batch
__shared__ int64_t word_id;
__shared__ int64_t pos_id;
__shared__ int64_t sent_id;
const T rhidden = T(1.f) / T(hidden);
const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
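  // seq_pos enumerates every (batch, position) token exactly once; the id
  // tensors and the output are addressed with the same flat index, so the
  // particular traversal order does not matter for this per-token op.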
if (threadIdx.x == 0) {
word_id = word_id_d[seq_pos];
pos_id = pos_id_d[seq_pos];
sent_id = sent_id_d[seq_pos];
}
__syncthreads();
  // load the word, pos and sentence embeddings and add them together
const int64_t woffset = word_id * hidden;
const int64_t poffset = pos_id * hidden;
const int64_t soffset = sent_id * hidden;
const int64_t out_offset = seq_pos * hidden;
cv2<T> thread_data;
thread_data.x = 0;
thread_data.y = 0;
#pragma unroll
for (int it = threadIdx.x; it < hidden; it += TPB) {
const T w(word_emb[woffset + it]);
const T p(pos_emb[poffset + it]);
const T s(sent_emb[soffset + it]);
const T val = w + s + p;
output[out_offset + it] = val;
const T rhiddenval = rhidden * val;
cv2<T> temp_data;
temp_data.x = rhiddenval;
temp_data.y = rhiddenval * val;
thread_data = pair_sum(thread_data, temp_data);
}
LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}
template <typename DeviceContext, typename T>
class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
using Tensor = framework::Tensor;
auto *word_id = context.Input<framework::Tensor>("WordId");
auto *pos_id = context.Input<framework::Tensor>("PosId");
auto *sent_id = context.Input<framework::Tensor>("SentId");
auto *word_emb = context.Input<framework::Tensor>("WordEmb");
auto *pos_emb = context.Input<framework::Tensor>("PosEmb");
auto *sent_emb = context.Input<framework::Tensor>("SentEmb");
auto *bias = context.Input<framework::Tensor>("Bias");
auto *scale = context.Input<framework::Tensor>("Scale");
auto *out = context.Output<framework::Tensor>("Out");
auto *word_id_d = word_id->data<int64_t>();
auto *pos_id_d = pos_id->data<int64_t>();
auto *sent_id_d = sent_id->data<int64_t>();
auto *word_emb_d = word_emb->data<T>();
auto *pos_emb_d = pos_emb->data<T>();
auto *sent_emb_d = sent_emb->data<T>();
auto *bias_d = bias->data<T>();
auto *scale_d = scale->data<T>();
auto *output_d = out->mutable_data<T>(context.GetPlace());
    // sum the three embeddings and layer-normalize them in one fused kernel
auto &device_ctx = context.template device_context<DeviceContext>();
float eps = context.Attr<float>("epsilon");
// should be (B * S * hidden)
auto word_id_dims = word_id->dims();
auto word_emb_dims = word_emb->dims();
int batch = word_id_dims[0];
int seq_len = word_id_dims[1];
int hidden = word_emb_dims[1];
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
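    // One thread block per token: grid = (seq_len, batch); the 256 threads of
    // each block stride over the hidden dimension inside the kernel.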
EmbEltwiseLayernormKernel<T, tpb><<<grid, block, 0, device_ctx.stream()>>>(
hidden, word_id_d, pos_id_d, sent_id_d, scale_d, bias_d, word_emb_d,
pos_emb_d, sent_emb_d, output_d, eps);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_embedding_eltwise_layernorm,
ops::EmbeddingEltWiseLayerNormKernel<
paddle::platform::CUDADeviceContext, float>);
@@ -52,7 +52,7 @@ __global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0,
                                                  const T* scale, T* out,
                                                  T* mean, T* variance, int M,
                                                  int N, float epsilon) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ T shared_mem[BlockDim + 2];
@@ -63,8 +63,8 @@ __global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0,
   int save_index = threadIdx.x;
   T* save_ptr = shared_mem;
-  double sum_i = 0;
-  double square_sum_i = 0;
+  T sum_i = 0;
+  T square_sum_i = 0;
   for (int j = threadIdx.x; j < N; j += blockDim.x) {
     T tmp_0 = out[index];
     // Add bias
@@ -87,8 +87,8 @@ __global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0,
   }
   auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<double>(sum_i, square_sum_i),
-                          PairForLayerNormAddFunctor<double>());
+                  .Reduce(PairForLayerNorm<T>(sum_i, square_sum_i),
+                          PairForLayerNormAddFunctor<T>());
   if (threadIdx.x == 0) {
     T mean_i = static_cast<T>(pair.first_ / N);
@@ -197,5 +197,4 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm,
-                        ops::FusedFCElementwiseLayerNormOpKernel<float>,
-                        ops::FusedFCElementwiseLayerNormOpKernel<double>);
+                        ops::FusedFCElementwiseLayerNormOpKernel<float>);
@@ -20,18 +20,56 @@ namespace paddle {
 namespace operators {
 namespace math {

+template <typename T>
+struct FcTypeTraits;
+
+template <>
+struct FcTypeTraits<float> {
+  typedef float4 Type;
+};
+
+template <>
+struct FcTypeTraits<double> {
+  typedef double4 Type;
+};
+
 template <typename T, bool DoRelu>
-__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-    int index = i * N + threadIdx.x;
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      T tmp = data[index] + bias[j];
-      if (DoRelu) {
-        data[index] = (tmp > 0) ? tmp : 0;
-      } else {
-        data[index] = tmp;
-      }
-      index += blockDim.x;
+__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < num) {
+    int bias_idx = tid % K;
+    const T bias_ptr = bias[bias_idx];
+    const T in_ptr = data[tid];
+    T packed_val;
+    packed_val.x = in_ptr.x + bias_ptr.x;
+    packed_val.y = in_ptr.y + bias_ptr.y;
+    packed_val.z = in_ptr.z + bias_ptr.z;
+    packed_val.w = in_ptr.w + bias_ptr.w;
+    if (DoRelu) {
+      packed_val.x = fmaxf(0.f, packed_val.x);
+      packed_val.y = fmaxf(0.f, packed_val.y);
+      packed_val.z = fmaxf(0.f, packed_val.z);
+      packed_val.w = fmaxf(0.f, packed_val.w);
+    }
+    data[tid] = packed_val;
+  }
+}
+
+template <typename T, bool DoRelu, int BlockDim>
+__global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
+  int offset = blockIdx.x * N;
+  for (int i = threadIdx.x; i < N; i += BlockDim) {
+    T temp;
+#if __CUDA_ARCH__ >= 350
+    temp = __ldg(data + offset + i) + __ldg(bias + i);
+#else
+    temp = data[offset + i] + bias[i];
+#endif
+    if (DoRelu) {
+      data[offset + i] = static_cast<int>(temp > 0) * temp;
+    } else {
+      data[offset + i] = temp;
     }
   }
 }
@@ -54,18 +92,35 @@ class FCFunctor<platform::CUDADeviceContext, T> {
       return;
     }

-    const int kThreadsPerBlock = 1024;
-    int max_threads = context.GetMaxPhysicalThreadCount();
-    int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
-    int num_blocks = std::max(max_threads / num_threads, 1);
-    if (relu) {
-      InplaceAddReluKernel<
-          T, true><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
-                                                                     N);
-    } else {
-      InplaceAddReluKernel<
-          T, false><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
-                                                                      N);
+    // M * N
+    if (N % 4 == 0) {
+      const int threads = 256;
+      const int num = M * N / 4;
+      const int blocks = (num + threads - 1) / threads;
+      typedef typename FcTypeTraits<T>::Type trans_type;
+      auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
+      auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
+      if (relu) {
+        bias_relu_v4<trans_type,
+                     true><<<blocks, threads, 0, context.stream()>>>(
+            num, bias_ptr_v4, data_ptr_v4, N / 4);
+      } else {
+        bias_relu_v4<trans_type,
+                     false><<<blocks, threads, 0, context.stream()>>>(
+            num, bias_ptr_v4, data_ptr_v4, N / 4);
+      }
+    } else {
+      const int threads = 256;
+      const int blocks = M;
+      if (relu) {
+        InplaceAddReluKernel<
+            T, true, threads><<<blocks, threads, 0, context.stream()>>>(N, B,
+                                                                        Y);
+      } else {
+        InplaceAddReluKernel<
+            T, false, threads><<<blocks, threads, 0, context.stream()>>>(N, B,
+                                                                         Y);
+      }
     }
   }
 };
......
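To make the new dispatch in FCFunctor concrete, here is a worked example of the launch arithmetic for the vectorized path; M and N are illustrative values (roughly an Ernie-sized FC output) and are not taken from the PR:

#include <cstdio>

// A sketch of the float4-path launch arithmetic above, with made-up sizes.
int main() {
  const int M = 128, N = 768;  // rows x columns of the FC output; N % 4 == 0
  const int threads = 256;
  const int num = M * N / 4;                         // 24576 float4 elements
  const int blocks = (num + threads - 1) / threads;  // 96 blocks
  // Each thread adds one float4 of bias to one float4 of data; the bias index
  // is tid % (N / 4), so bias is also read in float4 chunks (here N / 4 = 192).
  std::printf("num=%d blocks=%d bias_stride=%d\n", num, blocks, N / 4);
  return 0;
}

Vectorizing to float4 quarters the number of load/store instructions per element while keeping accesses coalesced, which is why the scalar InplaceAddReluKernel is kept only as the fallback for N % 4 != 0.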
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from pass_test import PassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
class EmbEltwiseLayerNormFusePassTest(PassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
word_id = fluid.layers.data(
name="word_id",
shape=[1, 128, 1],
dtype="int64",
append_batch_size=False)
pos_id = fluid.layers.data(
name="pos_id",
shape=[1, 128, 1],
dtype="int64",
append_batch_size=False)
sent_id = fluid.layers.data(
name="sent_id",
shape=[1, 128, 1],
dtype="int64",
append_batch_size=False)
word_emb = fluid.layers.embedding(
input=word_id, size=(128, 768), dtype='float32')
pos_emb = fluid.layers.embedding(
input=pos_id, size=(128, 768), dtype='float32')
sent_emb = fluid.layers.embedding(
input=sent_id, size=(128, 768), dtype='float32')
add1 = fluid.layers.elementwise_add(word_emb, pos_emb)
add2 = fluid.layers.elementwise_add(add1, sent_emb)
hidden1 = fluid.layers.layer_norm(input=add2, begin_norm_axis=2)
self.feeds = {
"word_id": np.random.randint(
low=0, high=128, size=(1, 128, 1)).astype("int64"),
"pos_id": np.random.randint(
low=0, high=128, size=(1, 128, 1)).astype("int64"),
"sent_id": np.random.randint(
low=0, high=128, size=(1, 128, 1)).astype("int64"),
}
self.fetch_list = [hidden1]
self.pass_names = "embedding_eltwise_layernorm_fuse_pass"
self.fused_op_type = "fused_embedding_eltwise_layernorm"
self.num_fused_ops = 1
def test_check_output(self):
use_gpu_set = [True]
if not core.is_compiled_with_cuda():
return
self.pass_attrs = {
"embedding_eltwise_layernorm_fuse_pass": {
"use_gpu": True
}
}
place = fluid.CUDAPlace(0)
self.check_output_with_place(place, startup_on_cpu=True)
if __name__ == "__main__":
unittest.main()