diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 7aa2766763ce9441b0e4de969930af50fb7a55e0..715d324c357fb32d214b47c740101efb2eb52276 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -127,7 +127,7 @@ function(op_library TARGET)
     "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
     "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
     "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
-"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op"
+"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op"
 "fused_bn_add_activation_op")
   if ("${TARGET}" STREQUAL "${manual_pybind_op}")
     set(pybind_flag 1)
diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
index 51861b402d58aa1224fbbfbc1476ed848716d5f7..19662a04f541d778db4c06be0e8402db295b4c0a 100644
--- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
@@ -326,6 +326,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
 void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
   int fusion_count = patterns::BuildFusion(graph, name_scope_);
+  if (fusion_count > 0) {
+    graph->Set(kEmbEltwiseLayernormPass, new bool(true));
+  }
   AddStatis(fusion_count);
 }
diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index d1fbc8396ba55523f3769a26ceaf9ef4e7fcf65e..cd6d1d57034d7ca5e849c98884c6435d6394eebd 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -696,7 +696,11 @@ void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
                     platform::errors::Fatal(
                         "During the multiheadMatmul pass, The scope should not be null."));
 
-  patterns::BuildFusionV2(graph, name_scope_, scope);
+  int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
+  if (fusion_count > 0) {
+    graph->Set(kMultiheadMatmulPass, new bool(true));
+  }
+  AddStatis(fusion_count);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 668dc74eab20a17d3697ebe778a1a5bb63cdab48..a3b1b33d2685b06e16702943eb4a51d0ae3648da 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -36,6 +36,9 @@ struct PassRegistrar;
 
 typedef std::unordered_set<std::string> PassRecorder;
 constexpr char kPassRecorder[] = "pass_recorder";
+constexpr char kEmbEltwiseLayernormPass[] =
+    "embedding_eltwise_layernorm_fuse_pass_flag";
+constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
 
 class Pass {
  public:
diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
index e5f348dfeb13e97632aa4901b6109576a21f67af..b708f2eff10e7506a08a7bfefc4bc84cd1b937cf 100644
--- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
@@ -134,6 +134,14 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
   GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                             fused_pattern);
 
+  // check whether the graph comes from an Ernie/Bert model
+  if (!graph->Has(kEmbEltwiseLayernormPass) ||
+      !graph->Has(kMultiheadMatmulPass)) {
+    LOG(INFO) << "The skip_layernorm_fuse_pass is only supported in "
+              << "Ernie/Bert models. Just skip this pass.";
+    return;
+  }
+
   std::unordered_set<const Node *> del_node_set;
 
   // Create an SkipLayerNorm op node
diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
index eff5dcddf54ee49be5b14a7bdfa609079f925036..29be2c3cb09a7f659efaad0dfd197514d13d96a6 100644
--- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
@@ -36,6 +36,8 @@ TEST(SkipLayerNormFusePass, basic) {
   layers.layer_norm(elementwise_out, scale, bias);
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
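+  // Manually set the attributes that the embedding_eltwise_layernorm and
+  // multihead_matmul fuse passes would normally record, so that the pass
+  // under test does not skip this toy graph.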
+  graph->Set(kEmbEltwiseLayernormPass, new bool(true));
+  graph->Set(kMultiheadMatmulPass, new bool(true));
   auto pass = PassRegistry::Instance().Get("skip_layernorm_fuse_pass");
   int num_nodes_before = graph->Nodes().size();
   VLOG(3) << DebugString(graph);
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 29c5ce7e59b4122ca0980db3b23125729cf2b868..08f3d609fa3e6ad32c7751fe9178bc8a83463f43 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -117,20 +117,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   block_desc.Proto()->set_idx(0);
   LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes";
 
-  bool has_fused_embedding_eltwise_layernorm = false;
-  bool has_multihead_matmul = false;
   for (auto *node : subgraph) {
     auto *new_block_op = new_block->AppendOp();
     auto *op = block_desc.AppendOp();
     *new_block_op->Proto() = *node->Op()->Proto();
     *op->Proto() = *node->Op()->Proto();
-    if (!has_fused_embedding_eltwise_layernorm &&
-        op->Type() == "fused_embedding_eltwise_layernorm") {
-      has_fused_embedding_eltwise_layernorm = true;
-    }
-    if (!has_multihead_matmul && op->Type() == "multihead_matmul") {
-      has_multihead_matmul = true;
-    }
   }
 
   // Then, we will use the input_names_with_id and output_names_with_id to
@@ -318,8 +309,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       min_input_shape, max_input_shape, opt_input_shape,
       disable_trt_plugin_fp16);
   trt_engine->SetUseOSS(Get<bool>("use_oss"));
-  trt_engine->SetWithErnie(has_multihead_matmul &&
-                           has_fused_embedding_eltwise_layernorm);
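+  // Both graph attributes are recorded by the Ernie/Bert-specific fuse
+  // passes, so their joint presence identifies an Ernie/Bert model.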
+  trt_engine->SetWithErnie(
+      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      graph->Has(framework::ir::kMultiheadMatmulPass));
 
   bool need_serialize = (use_static_engine && !load_from_memory);
   if (need_serialize) {
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 37000594b57d4fb80d4b2cf2571d028fe502638d..e16774e5422b27c78a43d3e2c2a23685ac2f4670 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -174,7 +174,13 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
 
 #undef CP_MEMBER
 
-  Update();
+  // Update();
+  // Update() would reset all the passes: a TensorRT pass that was deleted
+  // from other.pass_builder() would be set again. So just copy the passes.
+  pass_builder_->ClearPasses();
+  for (const std::string &pass : other.pass_builder()->AllPasses()) {
+    pass_builder_->AppendPass(pass);
+  }
 }
 
 void AnalysisConfig::EnableCUDNN() {
diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc
index ee4716bb56bc299ce2d57a06c71d3f41191bce25..f516d605cc1e2e01e2d5b2827744788a34881f92 100644
--- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc
@@ -78,6 +78,7 @@ class SliceOpConverter : public OpConverter {
 
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
+#if IS_TRT_VERSION_GE(6000)
       if (engine_->use_oss() && engine_->with_ernie()) {
         std::vector<nvinfer1::ITensor*> plugin_inputs;
         // plugin_inputs.emplace_back(trans_layer->getOutput(0));
@@ -92,17 +93,16 @@ class SliceOpConverter : public OpConverter {
         layer = engine_->AddPluginV2(plugin_inputs.data(),
                                      plugin_inputs.size(), plugin);
       } else {
-#if IS_TRT_VERSION_GE(6000)
         bool ban_fp16 = engine_->disable_trt_plugin_fp16();
         plugin::SlicePluginDynamic* plugin =
             new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16);
         layer = engine_->AddPluginV2(&input, 1, plugin);
+      }
 #else
-      PADDLE_THROW(platform::errors::Fatal(
-          "You are running the TRT Dynamic Shape mode, need to confirm that "
-          "your TRT version is no less than 6.0"));
+      PADDLE_THROW(platform::errors::Fatal(
+          "You are running the TRT Dynamic Shape mode, need to confirm that "
+          "your TRT version is no less than 6.0"));
 #endif
-    }
     } else {
       bool ban_fp16 = engine_->disable_trt_plugin_fp16();
       plugin::SlicePlugin* plugin =
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 1907bb93ccbfbcdb127e8b28de26fb499ab170b4..6d4dc69de07a84579f9588fd1fde190e94e945a2 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -525,6 +525,15 @@ if(WITH_GPU AND TENSORRT_FOUND)
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
           ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)
 
+  set(TEST_TRT_TRANSFORMER_PRUNE_MODEL "${TRT_MODEL_INSTALL_DIR}/transformer_prune")
+  if (NOT EXISTS ${TEST_TRT_TRANSFORMER_PRUNE_MODEL}/transformer_prune.tar.gz)
+    inference_download_and_uncompress(${TEST_TRT_TRANSFORMER_PRUNE_MODEL} ${INFERENCE_URL}/tensorrt_test "transformer_prune.tar.gz")
+  endif()
+
+  inference_analysis_test(test_trt_dynamic_shape_transformer_prune SRCS trt_dynamic_shape_transformer_prune_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TEST_TRT_TRANSFORMER_PRUNE_MODEL}/transformer_prune)
+
   set(TEST_TRT_ERNIE_UNSER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_unserialized/")
   if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL}/ernie_model_4_unserialized.tgz)
     inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3916cf361c4b87602d9abc996788566da0488bbf
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc
@@ -0,0 +1,139 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+
+namespace paddle {
+namespace inference {
+
+void run(const AnalysisConfig& config, std::vector<float>* out_data) {
+  auto predictor = CreatePaddlePredictor(config);
+  auto input_names = predictor->GetInputNames();
+
+  int run_batch = 1;
+  const int run_seq_len = 128;
+
+  std::vector<int64_t> tmp_input;
+  std::vector<float> tmp_four_input;
+  tmp_input.reserve(run_batch * run_seq_len);
+  tmp_four_input.reserve(run_batch * run_seq_len);
+
+  int64_t i0[run_seq_len] = {
+      1,    3558, 4,   75,  491, 89, 340, 313, 93,   4,   255,   10, 75,  321,
+      4095, 1902, 4,   134, 49,  75, 311, 14,  44,   178, 543,   15, 12043, 2,
+      75,   201,  340, 9,   14,  44, 486, 218, 1140, 279, 12043, 2};
+  int64_t i1[run_seq_len] = {
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  int64_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+                             10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+                             20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+                             30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
+  float i3[run_seq_len] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                           1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+
+  // first input
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape({run_batch, run_seq_len, 1});
+  input_t->copy_from_cpu(i0);
+
+  // second input
+  auto input_t2 = predictor->GetInputTensor(input_names[1]);
+  input_t2->Reshape({run_batch, run_seq_len, 1});
+  input_t2->copy_from_cpu(i1);
+
+  // third input.
+  auto input_t3 = predictor->GetInputTensor(input_names[2]);
+  input_t3->Reshape({run_batch, run_seq_len, 1});
+  input_t3->copy_from_cpu(i2);
+
+  // fourth input
+  auto input_t4 = predictor->GetInputTensor(input_names[3]);
+  input_t4->Reshape({run_batch, run_seq_len, 1});
+  input_t4->copy_from_cpu(i3);
+
+  ASSERT_TRUE(predictor->ZeroCopyRun());
+
+  auto output_names = predictor->GetOutputNames();
+  auto output_t = predictor->GetOutputTensor(output_names[0]);
+  std::vector<int> output_shape = output_t->shape();
+  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                                std::multiplies<int>());
+  out_data->resize(out_num);
+  output_t->copy_to_cpu(out_data->data());
+}
+
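+// Build a TensorRT dynamic-shape config for the four transformer inputs, run
+// one prediction, and compare the first output against the expected values.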
+void trt_ernie(bool with_fp16, std::vector<float> result) {
+  AnalysisConfig config;
+  std::string model_dir = FLAGS_infer_model;
+  SetConfig(&config, model_dir, true);
+
+  config.SwitchUseFeedFetchOps(false);
+
+  int batch = 32;
+  int min_seq_len = 1;
+  int max_seq_len = 128;
+  int opt_seq_len = 128;
+
+  std::vector<int> min_shape = {1, min_seq_len, 1};
+  std::vector<int> max_shape = {batch, max_seq_len, 1};
+  std::vector<int> opt_shape = {batch, opt_seq_len, 1};
+  // Set the input's min, max, opt shape
+  std::map<std::string, std::vector<int>> min_input_shape = {
+      {"read_file_0.tmp_0", min_shape},
+      {"read_file_0.tmp_1", min_shape},
+      {"read_file_0.tmp_2", min_shape},
+      {"read_file_0.tmp_3", min_shape}};
+  std::map<std::string, std::vector<int>> max_input_shape = {
+      {"read_file_0.tmp_0", max_shape},
+      {"read_file_0.tmp_1", max_shape},
+      {"read_file_0.tmp_2", max_shape},
+      {"read_file_0.tmp_3", max_shape}};
+  std::map<std::string, std::vector<int>> opt_input_shape = {
+      {"read_file_0.tmp_0", opt_shape},
+      {"read_file_0.tmp_1", opt_shape},
+      {"read_file_0.tmp_2", opt_shape},
+      {"read_file_0.tmp_3", opt_shape}};
+
+  auto precision = AnalysisConfig::Precision::kFloat32;
+  if (with_fp16) {
+    precision = AnalysisConfig::Precision::kHalf;
+  }
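+  // Args: workspace size, max batch size, min subgraph size, precision,
+  // use_static, use_calib_mode.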
+  config.EnableTensorRtEngine(1 << 30, 1, 12, precision, false, false);
+  config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                opt_input_shape);
+  std::vector<float> out_data;
+  run(config, &out_data);
+
+  for (size_t i = 0; i < out_data.size(); i++) {
+    EXPECT_NEAR(result[i], out_data[i], 1e-4);
+  }
+}
+
+TEST(AnalysisPredictor, no_fp16) {
+  std::vector<float> result = {0.498667, 0.501333};
+  trt_ernie(false, result);
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 477a9162fe3f779d4006deb2e20b3a16f70cdf47..97d6e696b137dc7b4110efa9b9a25a34c6e6fdbb 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -6,6 +6,7 @@ register_operators(EXCLUDES
 	fusion_conv_inception_op
 	fused_fc_elementwise_layernorm_op
 	multihead_matmul_op
+	skip_layernorm_op
 	fused_embedding_eltwise_layernorm_op
 	fusion_group_op
 	fusion_gru_op
@@ -40,6 +41,8 @@ if (WITH_GPU)
     # multihead_matmul_op
     op_library(multihead_matmul_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
+    op_library(skip_layernorm_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n")
     op_library(fused_embedding_eltwise_layernorm_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n")
     # fusion_group
diff --git a/paddle/fluid/operators/fused/skip_layernorm_op.cc b/paddle/fluid/operators/fused/skip_layernorm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..442f359c0dac59a5d6ee6d071d1d1b63838b4963
--- /dev/null
+++ b/paddle/fluid/operators/fused/skip_layernorm_op.cc
@@ -0,0 +1,91 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/errors.h"
+
+namespace paddle {
+namespace operators {
+
+class SkipLayerNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of SkipLayerNorm should not be null."));
+    PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Y) of SkipLayerNorm should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("Scale"), true,
+        platform::errors::InvalidArgument(
+            "Input(Scale) of SkipLayerNorm should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("Bias"), true,
+        platform::errors::InvalidArgument(
+            "Input(Bias) of SkipLayerNorm should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of SkipLayerNorm should not be null."));
+
+    auto dim_input = context->GetInputDim("X");
+    context->SetOutputDim("Out", dim_input);
+    context->ShareLoD("X", "Out");
+  }
+};
+
+class SkipLayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The X input of SkipLayerNorm op");
+    AddInput("Y", "The Y input of SkipLayerNorm op");
+    AddInput("Scale", "The scale input of SkipLayerNorm op");
+    AddInput("Bias", "The bias input of SkipLayerNorm op");
+    AddOutput("Out", "The output of SkipLayerNorm op");
+    AddAttr<float>("epsilon",
+                   "param epsilon of layer_norm op in "
+                   "skip_layernorm_fuse_pass");
+    AddAttr<int>("begin_norm_axis",
+                 "param begin_norm_axis of "
+                 "layer_norm op in skip_layernorm_fuse_pass");
+    AddComment(R"DOC(
+SkipLayerNorm Operator.
+
+This op is used by skip_layernorm_fuse_pass, which fuses the op pattern below.
+
+      |           |                            |            |
+ other_op1   other_op2                    other_op1    other_op2
+      |           |              fuse           \          /
+      |------elementwise_add      ->           skip_layernorm
+                  |                                  |
+             layer_norm                         other_op3
+                  |                                  |
+             other_op3
+                  |
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(skip_layernorm, ops::SkipLayerNormOp,
+                             ops::SkipLayerNormOpMaker);
diff --git a/paddle/fluid/operators/fused/skip_layernorm_op.cu b/paddle/fluid/operators/fused/skip_layernorm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..856d5e694bdf13d333dbd5c701c8936482cde2a0
--- /dev/null
+++ b/paddle/fluid/operators/fused/skip_layernorm_op.cu
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/operators/math/bert_encoder_functor.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class SkipLayerNormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    using Tensor = framework::Tensor;
+    auto *X = context.Input<framework::Tensor>("X");
+    auto *Y = context.Input<framework::Tensor>("Y");
+    auto *scale = context.Input<framework::Tensor>("Scale");
+    auto *bias = context.Input<framework::Tensor>("Bias");
+
+    auto *X_d = X->data<T>();
+    auto *Y_d = Y->data<T>();
+    auto *scale_d = scale->data<T>();
+    auto *bias_d = bias->data<T>();
+    float epsilon = context.Attr<float>("epsilon");
+    int begin_norm_axis = context.Attr<int>("begin_norm_axis");
+
+    auto *out = context.Output<Tensor>("Out");
+    out->Resize(X->dims());
+    auto *output_d = out->mutable_data<T>(context.GetPlace());
+
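+    // X is expected to be a 3-D tensor [batch, seq_len, hidden]; num is the
+    // total element count and hidden the width of the normalized last axis.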
+    size_t num = 1;
+    for (size_t i = 0; i < X->dims().size(); i++) {
+      num *= X->dims()[i];
+    }
+    int hidden = X->dims()[2];
+    auto &device_ctx = context.template device_context<DeviceContext>();
+    operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
+
+    skip_layer_norm_func(num, hidden, X_d, Y_d, scale_d, bias_d, output_d,
+                         epsilon, device_ctx.stream());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    skip_layernorm,
+    ops::SkipLayerNormKernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/python/paddle/fluid/tests/unittests/ir/pass_test.py b/python/paddle/fluid/tests/unittests/ir/pass_test.py
index c1c05c43359758b4c5fc226a08a2b844e2e721a7..aae1cc65c9220c712655c632de56d2b13244cf86 100644
--- a/python/paddle/fluid/tests/unittests/ir/pass_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/pass_test.py
@@ -36,6 +36,7 @@ class PassTest(unittest.TestCase):
         self.fetch_list = None
         self.pass_names = None
         self.pass_attrs = {}
+        self.graph_attrs = {}
         self.fused_op_type = None
         self.num_fused_ops = -1
@@ -85,6 +86,8 @@ class PassTest(unittest.TestCase):
     def _apply_ir_passes(self):
         graph = core.Graph(self.main_program.desc)
         graph.set_not_owned("__param_scope__", fluid.global_scope())
+        for attr_name, attr_value in self.graph_attrs.items():
+            graph.set(attr_name, attr_value)
 
         if not isinstance(self.pass_names, list):
             self.pass_names = [self.pass_names]
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py
index 888857e5a7246fb58622e05325177e64a3dc99e5..0aac6650f52dde7927cb06d3c621df754200c4ad 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_skip_layernorm_pass.py
@@ -16,12 +16,14 @@
 import unittest
 import numpy as np
 from pass_test import PassTest
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 
 
 class SkipLayerNormFusePassTest(PassTest):
     def setUp(self):
+        paddle.enable_static()
         with fluid.program_guard(self.main_program, self.startup_program):
             x = fluid.data(
                 name="x", shape=[128, 768], dtype="float32", lod_level=0)
@@ -34,6 +36,10 @@ class SkipLayerNormFusePassTest(PassTest):
         self.pass_names = "skip_layernorm_fuse_pass"
         self.fused_op_type = "skip_layernorm"
         self.num_fused_ops = 1
+        self.graph_attrs = {
+            "embedding_eltwise_layernorm_fuse_pass_flag": True,
+            "multihead_matmul_fuse_pass_flag": True
+        }
 
     def test_check_program(self):
         use_gpu_set = [False]