Unverified commit ec672e88, authored by Shang Zhizhou, committed by GitHub

Skip layernorm to 1.8 (#28583)

* Add TensorRT support for the pruned transformer model; fix the bug that TensorRT did not support DeletePass (#28517)

* skip_layernorm_op done

* add unittest

* slice op converter supports TRT < 6

* skip_layernorm only works in Ernie (gating mechanism sketched below)

* fix unittest

* fix unittest
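The "only works in Ernie" gating relies on the two upstream fuse passes stamping a flag onto the graph when they actually fire, and the skip_layernorm pass checking for both flags. A condensed sketch of the mechanism, using the flag names from the `pass.h` hunk below (an illustration of the flow, not the verbatim pass code):

```cpp
// Flags declared in paddle/fluid/framework/ir/pass.h (see the diff below).
constexpr char kEmbEltwiseLayernormPass[] =
    "embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";

// Upstream pass: mark the graph only when at least one fusion happened.
void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  int fusion_count = patterns::BuildFusion(graph, name_scope_);
  if (fusion_count > 0) {
    graph->Set(kEmbEltwiseLayernormPass, new bool(true));
  }
  AddStatis(fusion_count);
}

// Downstream pass: bail out unless both Ernie/Bert markers are present.
void SkipLayerNormFusePass::ApplyImpl(ir::Graph* graph) const {
  if (!graph->Has(kEmbEltwiseLayernormPass) ||
      !graph->Has(kMultiheadMatmulPass)) {
    return;  // not an Ernie/Bert model; skip the fusion
  }
  // ... match add + layer_norm subgraphs and replace with skip_layernorm ...
}
```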
Parent commit: 0a42986c
@@ -118,7 +118,7 @@ function(op_library TARGET)
        "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
        "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
        "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
-       "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
+       "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
    if ("${TARGET}" STREQUAL "${manual_pybind_op}")
      set(pybind_flag 1)
    endif()
......
@@ -323,6 +323,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  int fusion_count = patterns::BuildFusion(graph, name_scope_);
+  if (fusion_count > 0) {
+    graph->Set(kEmbEltwiseLayernormPass, new bool(true));
+  }
  AddStatis(fusion_count);
}
......
@@ -694,7 +694,11 @@ void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
                    platform::errors::Fatal(
                        "During the multiheadMatmul pass, The scope should not be null."));
-  patterns::BuildFusionV2(graph, name_scope_, scope);
+  int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
+  if (fusion_count > 0) {
+    graph->Set(kMultiheadMatmulPass, new bool(true));
+  }
+  AddStatis(fusion_count);
}
}  // namespace ir
......
@@ -34,6 +34,9 @@ struct PassRegistrar;
typedef std::unordered_set<std::string> PassRecorder;
constexpr char kPassRecorder[] = "pass_recorder";
+constexpr char kEmbEltwiseLayernormPass[] =
+    "embedding_eltwise_layernorm_fuse_pass_flag";
+constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";

class Pass {
 public:
......
@@ -132,6 +132,14 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

+    // Check whether the model is an Ernie/Bert model.
+    if (!graph->Has(kEmbEltwiseLayernormPass) ||
+        !graph->Has(kMultiheadMatmulPass)) {
+      LOG(INFO) << "The skip_layernorm_fuse_pass is only supported in "
+                << "Ernie/Bert models; skipping this pass.";
+      return;
+    }
+
    std::unordered_set<const Node *> del_node_set;

    // Create a SkipLayerNorm op node
......
@@ -35,6 +35,8 @@ TEST(SkipLayerNormFusePass, basic) {
  layers.layer_norm(elementwise_out, scale, bias);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  graph->Set(kEmbEltwiseLayernormPass, new bool(true));
+  graph->Set(kMultiheadMatmulPass, new bool(true));
  auto pass = PassRegistry::Instance().Get("skip_layernorm_fuse_pass");
  int num_nodes_before = graph->Nodes().size();
  VLOG(3) << DebugString(graph);
......
@@ -114,20 +114,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
  block_desc.Proto()->set_idx(0);
  LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes";
-  bool has_fused_embedding_eltwise_layernorm = false;
-  bool has_multihead_matmul = false;
  for (auto *node : subgraph) {
    auto *new_block_op = new_block->AppendOp();
    auto *op = block_desc.AppendOp();
    *new_block_op->Proto() = *node->Op()->Proto();
    *op->Proto() = *node->Op()->Proto();
-    if (!has_fused_embedding_eltwise_layernorm &&
-        op->Type() == "fused_embedding_eltwise_layernorm") {
-      has_fused_embedding_eltwise_layernorm = true;
-    }
-    if (!has_multihead_matmul && op->Type() == "multihead_matmul") {
-      has_multihead_matmul = true;
-    }
  }

  // Then, we will use the input_names_with_id and output_names_with_id to
@@ -310,8 +301,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
          min_input_shape, max_input_shape, opt_input_shape,
          disable_trt_plugin_fp16);
  trt_engine->SetUseOSS(Get<bool>("use_oss"));
-  trt_engine->SetWithErnie(has_multihead_matmul &&
-                           has_fused_embedding_eltwise_layernorm);
+  trt_engine->SetWithErnie(
+      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      graph->Has(framework::ir::kMultiheadMatmulPass));

  bool need_serialize = (use_static_engine && !load_from_memory);
  if (need_serialize) {
......
@@ -170,7 +170,13 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER

-  Update();
+  // Update();
+  // Update() would reset all the passes: if a TensorRT pass was deleted via
+  // other.pass_builder(), it would be registered again, so copy the passes.
+  pass_builder_->ClearPasses();
+  for (const std::string &pass : other.pass_builder()->AllPasses()) {
+    pass_builder_->AppendPass(pass);
+  }
}

void AnalysisConfig::EnableCUDNN() {
......
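The copy-constructor change above is the DeletePass fix from the commit message: copying a config used to re-run Update(), which rebuilt the default pass list and re-registered TensorRT passes the user had removed. A minimal reproduction sketch, assuming a GPU build with TensorRT (the deleted pass name here is illustrative):

```cpp
AnalysisConfig config;
config.EnableUseGpu(100 /* MB */, 0 /* gpu_id */);
config.EnableTensorRtEngine(1 << 30, 1, 3);
// User removes a pass they do not want to run.
config.pass_builder()->DeletePass("skip_layernorm_fuse_pass");

// Before this fix, the copy constructor called Update(), which silently
// restored the deleted pass. With the fix, the copy's pass list is taken
// verbatim from the source config.
AnalysisConfig copy(config);
```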
@@ -78,6 +78,7 @@ class SliceOpConverter : public OpConverter {
    nvinfer1::ILayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
+#if IS_TRT_VERSION_GE(6000)
      if (engine_->use_oss() && engine_->with_ernie()) {
        std::vector<nvinfer1::ITensor*> plugin_inputs;
        // plugin_inputs.emplace_back(trans_layer->getOutput(0));
@@ -92,17 +93,16 @@ class SliceOpConverter : public OpConverter {
        layer = engine_->AddPluginV2(plugin_inputs.data(), plugin_inputs.size(),
                                     plugin);
      } else {
-#if IS_TRT_VERSION_GE(6000)
        bool ban_fp16 = engine_->disable_trt_plugin_fp16();
        plugin::SlicePluginDynamic* plugin =
            new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16);
        layer = engine_->AddPluginV2(&input, 1, plugin);
+      }
#else
      PADDLE_THROW(platform::errors::Fatal(
          "You are running the TRT Dynamic Shape mode, need to confirm that "
          "your TRT version is no less than 6.0"));
#endif
-    }
    } else {
      bool ban_fp16 = engine_->disable_trt_plugin_fp16();
      plugin::SlicePlugin* plugin =
......
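The net effect of the slice-converter change: the OSS/Ernie branch also depends on TRT-6-only plugin APIs, so the `IS_TRT_VERSION_GE(6000)` guard is hoisted to wrap the entire dynamic-shape path instead of only the generic branch. A condensed skeleton of the resulting structure (details elided):

```cpp
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
  if (engine_->use_oss() && engine_->with_ernie()) {
    // special slice plugin path used for Ernie with OSS enabled
  } else {
    // generic SlicePluginDynamic path
  }
#else
  // With TRT < 6 the dynamic-shape plugin headers do not exist, so fail
  // loudly at runtime instead of failing to compile the Ernie branch.
  PADDLE_THROW(platform::errors::Fatal(
      "You are running the TRT Dynamic Shape mode, need to confirm that "
      "your TRT version is no less than 6.0"));
#endif
} else {
  // static-shape SlicePlugin path (works on older TRT versions)
}
```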
@@ -425,6 +425,15 @@ if(WITH_GPU AND TENSORRT_FOUND)
          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
          ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)

+  set(TEST_TRT_TRANSFORMER_PRUNE_MODEL "${TRT_MODEL_INSTALL_DIR}/transformer_prune")
+  if (NOT EXISTS ${TEST_TRT_TRANSFORMER_PRUNE_MODEL}/transformer_prune.tar.gz)
+    inference_download_and_uncompress(${TEST_TRT_TRANSFORMER_PRUNE_MODEL} ${INFERENCE_URL}/tensorrt_test "transformer_prune.tar.gz")
+  endif()
+  inference_analysis_test(test_trt_dynamic_shape_transformer_prune SRCS trt_dynamic_shape_transformer_prune_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TEST_TRT_TRANSFORMER_PRUNE_MODEL}/transformer_prune)
+
  set(TEST_TRT_ERNIE_UNSER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_unserialized/")
  if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL}/ernie_model_4_unserialized.tgz)
    inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
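// Runs one batch through the predictor: feeds four Ernie-style inputs
// (token ids, sentence-type ids, position ids, and a float mask) and
// copies the first output tensor into out_data.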
void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int run_batch = 1;
const int run_seq_len = 128;
std::vector<int64_t> tmp_input;
std::vector<float> tmp_four_input;
tmp_input.reserve(run_batch * run_seq_len);
tmp_four_input.reserve(run_batch * run_seq_len);
int64_t i0[run_seq_len] = {
1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321,
4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2,
75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2};
int64_t i1[run_seq_len] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int64_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
float i3[run_seq_len] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
// first input
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({run_batch, run_seq_len, 1});
input_t->copy_from_cpu(i0);
// second input
auto input_t2 = predictor->GetInputTensor(input_names[1]);
input_t2->Reshape({run_batch, run_seq_len, 1});
input_t2->copy_from_cpu(i1);
// third input.
auto input_t3 = predictor->GetInputTensor(input_names[2]);
input_t3->Reshape({run_batch, run_seq_len, 1});
input_t3->copy_from_cpu(i2);
auto input_t4 = predictor->GetInputTensor(input_names[3]);
input_t4->Reshape({run_batch, run_seq_len, 1});
input_t4->copy_from_cpu(i3);
ASSERT_TRUE(predictor->ZeroCopyRun());
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data->resize(out_num);
output_t->copy_to_cpu(out_data->data());
}
void trt_ernie(bool with_fp16, std::vector<float> result) {
AnalysisConfig config;
std::string model_dir = FLAGS_infer_model;
SetConfig(&config, model_dir, true);
config.SwitchUseFeedFetchOps(false);
int batch = 32;
int min_seq_len = 1;
int max_seq_len = 128;
int opt_seq_len = 128;
std::vector<int> min_shape = {1, min_seq_len, 1};
std::vector<int> max_shape = {batch, max_seq_len, 1};
std::vector<int> opt_shape = {batch, opt_seq_len, 1};
// Set the input's min, max, opt shape
std::map<std::string, std::vector<int>> min_input_shape = {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"read_file_0.tmp_3", min_shape}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"read_file_0.tmp_3", max_shape}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"read_file_0.tmp_3", opt_shape}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf;
}
config.EnableTensorRtEngine(1 << 30, 1, 12, precision, false, false);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
std::vector<float> out_data;
run(config, &out_data);
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-4);
}
}
TEST(AnalysisPredictor, no_fp16) {
std::vector<float> result = {0.498667, 0.501333};
trt_ernie(false, result);
}
} // namespace inference
} // namespace paddle
@@ -6,6 +6,7 @@ register_operators(EXCLUDES
        fusion_conv_inception_op
        fused_fc_elementwise_layernorm_op
        multihead_matmul_op
+       skip_layernorm_op
        fused_embedding_eltwise_layernorm_op
        fusion_group_op)
@@ -34,6 +35,8 @@ if (WITH_GPU)
    # multihead_matmul_op
    op_library(multihead_matmul_op)
    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
+   op_library(skip_layernorm_op)
+   file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n")
    op_library(fused_embedding_eltwise_layernorm_op)
    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n")
    # fusion_group
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace operators {
class SkipLayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
                      platform::errors::InvalidArgument(
                          "Input(Y) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("Scale"), true,
        platform::errors::InvalidArgument(
            "Input(Scale) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasInput("Bias"), true,
        platform::errors::InvalidArgument(
            "Input(Bias) of SkipLayerNorm should not be null."));
    PADDLE_ENFORCE_EQ(
        context->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of SkipLayerNorm should not be null."));
auto dim_input = context->GetInputDim("X");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("X", "Out");
}
};
class SkipLayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The X input of SkipLayerNorm op");
AddInput("Y", "The Y input of SkipLayerNorm op");
AddInput("Scale", "The scale input of SkipLayerNorm op");
AddInput("Bias", "The bias input of SkipLayerNorm op");
AddOutput("Out", "The output of SkipLayerNorm op");
AddAttr<float>("epsilon",
"param epsilon of layer_norm op in "
"skip_layernorm_fuse_pass");
AddAttr<int>("begin_norm_axis",
"param begin_norm_axis of "
"layer_norm op in skip_layernorm_fuse_pass");
AddComment(R"DOC(
SkipLayerNorm Operator.
This op is used by skip_layernorm_fuse_pass, which fuses the following op pattern:

      |           |                            |           |
  other_op1   other_op2                    other_op1   other_op2
      |           |              fuse           \         /
      |------elementwise_add      ->          skip_layernorm
                  |                                 |
             layer_norm                        other_op3
                  |                                 |
             other_op3
                  |
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(skip_layernorm, ops::SkipLayerNormOp,
ops::SkipLayerNormOpMaker);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SkipLayerNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
using Tensor = framework::Tensor;
auto *X = context.Input<framework::Tensor>("X");
auto *Y = context.Input<framework::Tensor>("Y");
auto *scale = context.Input<framework::Tensor>("Scale");
auto *bias = context.Input<framework::Tensor>("Bias");
auto *X_d = X->data<T>();
auto *Y_d = Y->data<T>();
auto *scale_d = scale->data<T>();
auto *bias_d = bias->data<T>();
float epsilon = context.Attr<float>("epsilon");
int begin_norm_axis = context.Attr<int>("begin_norm_axis");
auto *out = context.Output<framework::Tensor>("Out");
out->Resize(X->dims());
auto *output_d = out->mutable_data<T>(context.GetPlace());
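    // num is the total element count of X; the functor treats the input as
    // (num / hidden) rows of width hidden, taken from dim 2 of the
    // (batch, seq_len, hidden) input.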
size_t num = 1;
for (size_t i = 0; i < X->dims().size(); i++) {
num *= X->dims()[i];
}
int hidden = X->dims()[2];
auto &device_ctx = context.template device_context<DeviceContext>();
operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
skip_layer_norm_func(num, hidden, X_d, Y_d, scale_d, bias_d, output_d,
epsilon, device_ctx.stream());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
skip_layernorm,
ops::SkipLayerNormKernel<paddle::platform::CUDADeviceContext, float>);
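For reference, the fused computation is out = layer_norm(X + Y), with scale and bias applied per hidden unit. A plain-C++ sketch of the semantics the CUDA functor implements (illustrative reference code, not the optimized kernel):

```cpp
#include <cmath>

// Reference for skip_layernorm: out = layer_norm(x + y).
// x, y, out hold num elements as (num / hidden) rows; scale, bias: hidden.
void SkipLayerNormRef(int num, int hidden, const float* x, const float* y,
                      const float* scale, const float* bias, float* out,
                      float epsilon) {
  for (int row = 0; row < num / hidden; ++row) {
    const float* xr = x + row * hidden;
    const float* yr = y + row * hidden;
    float* outr = out + row * hidden;
    // Mean and variance of the elementwise sum over the hidden dimension.
    float mean = 0.f, var = 0.f;
    for (int h = 0; h < hidden; ++h) mean += xr[h] + yr[h];
    mean /= hidden;
    for (int h = 0; h < hidden; ++h) {
      float d = xr[h] + yr[h] - mean;
      var += d * d;
    }
    var /= hidden;
    float rstd = 1.f / std::sqrt(var + epsilon);
    // Normalize, then apply the per-hidden-unit scale and bias.
    for (int h = 0; h < hidden; ++h) {
      outr[h] = scale[h] * ((xr[h] + yr[h] - mean) * rstd) + bias[h];
    }
  }
}
```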
@@ -36,6 +36,7 @@ class PassTest(unittest.TestCase):
        self.fetch_list = None
        self.pass_names = None
        self.pass_attrs = {}
+       self.graph_attrs = {}
        self.fused_op_type = None
        self.num_fused_ops = -1
@@ -85,6 +86,8 @@ class PassTest(unittest.TestCase):
    def _apply_ir_passes(self):
        graph = core.Graph(self.main_program.desc)
        graph.set_not_owned("__param_scope__", fluid.global_scope())
+       for attr_name, attr_value in self.graph_attrs.items():
+           graph.set(attr_name, attr_value)
        if not isinstance(self.pass_names, list):
            self.pass_names = [self.pass_names]
......
@@ -16,6 +16,7 @@ import unittest
import numpy as np
from pass_test import PassTest

+import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -34,6 +35,10 @@ class SkipLayerNormFusePassTest(PassTest):
        self.pass_names = "skip_layernorm_fuse_pass"
        self.fused_op_type = "skip_layernorm"
        self.num_fused_ops = 1
+       self.graph_attrs = {
+           "embedding_eltwise_layernorm_fuse_pass_flag": True,
+           "multihead_matmul_fuse_pass_flag": True
+       }

    def test_check_program(self):
        use_gpu_set = [False]
......