diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index f60a6dc3f0c89dd345b04ea3a1e213de770e5760..9ae1f0134ef61c8e7279b3ce244772eae10f08d3 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -126,7 +126,7 @@ function(op_library TARGET)
     foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
 "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
 "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
-"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
+"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "fused_fc_reshape_elementwise_layernorm_op"
 "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op"
 "fused_embedding_eltwise_layernorm_op" "fusion_gru_op")
       if ("${TARGET}" STREQUAL "${manual_pybind_op}")
         set(pybind_flag 1)
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 8787aa8a94a44c2c36868fea4b88ede5f91b19f4..15972d696e0a0d7e4cb10556ad400802db75751f 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -86,6 +86,7 @@ pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
+pass_library(fc_reshape_elementwise_layernorm_fuse_pass base)
 pass_library(skip_layernorm_fuse_pass base)
 pass_library(multihead_matmul_fuse_pass inference)
 if(WITH_GPU)
diff --git a/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..02ef2e2595953f3e3cd740aa31307e17323a9131
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.cc
@@ -0,0 +1,302 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.h"
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct FCReshapeElementwiseLayerNorm : public PatternBase {
+  FCReshapeElementwiseLayerNorm(PDPattern *pattern,
+                                const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "fc_reshape_elementwise_layernorm") {}
+
+  PDNode *operator()(PDNode *x);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fused_fc_reshape_elementwise_layernorm);
+  PATTERN_DECL_NODE(fc);
+  PATTERN_DECL_NODE(reshape2);
+  PATTERN_DECL_NODE(elementwise);
+  PATTERN_DECL_NODE(layer_norm);
+  // declare variable node's name
+  PATTERN_DECL_NODE(fc_w);
+  PATTERN_DECL_NODE(fc_bias);
+  PATTERN_DECL_NODE(fc_out);  // (x,fc_w,fc_bias) -> fc_out
+
+  PATTERN_DECL_NODE(reshape_input);
+  PATTERN_DECL_NODE(reshape_out);
+
+  PATTERN_DECL_NODE(elementwise_input);
+  PATTERN_DECL_NODE(
+      elementwise_out);  // (reshape_out,elementwise_input) -> elementwise_out
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_out);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+};
+
+PDNode *FCReshapeElementwiseLayerNorm::operator()(PDNode *x) {
+  // Create nodes for fc op.
+  x->assert_is_op_input("fc", "Input");
+  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
+  auto *fc_w_var = pattern->NewNode(fc_w_repr())
+                       ->AsInput()
+                       ->assert_is_persistable_var()
+                       ->assert_is_op_input("fc", "W");
+  auto *fc_bias_var = pattern->NewNode(fc_bias_repr())
+                          ->AsInput()
+                          ->assert_is_persistable_var()
+                          ->assert_is_op_input("fc", "Bias");
+  auto *fc_out_var = pattern->NewNode(fc_out_repr())->assert_is_op_output("fc");
+
+  // Add links for fc op.
+  fc->LinksFrom({x, fc_w_var, fc_bias_var}).LinksTo({fc_out_var});
+
+  // Create nodes for reshape2 op.
+  fc_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto *reshape2 = pattern->NewNode(reshape2_repr())->assert_is_op("reshape2");
+  auto *reshape_out_var = pattern->NewNode(reshape_out_repr())
+                              ->AsOutput()
+                              ->assert_is_op_output("reshape2");
+  // Add links for reshape2 op.
+  reshape2->LinksFrom({fc_out_var}).LinksTo({reshape_out_var});
+  reshape_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  // Create nodes for elementwise_add op.
+  auto *elementwise =
+      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
+  auto *elementwise_input_var = pattern->NewNode(elementwise_input_repr())
+                                    ->assert_is_op_input("elementwise_add");
+
+  auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("elementwise_add");
+
+  // Add links for elementwise_add op.
+  elementwise->LinksFrom({reshape_out_var, elementwise_input_var})
+      .LinksTo({elementwise_out_var});
+
+  // Create nodes for layer_norm op.
+  elementwise_out_var->AsIntermediate()->assert_is_op_input("layer_norm");
+  auto *layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("layer_norm", "Bias");
+  auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("layer_norm", "Scale");
+
+  auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
+                                 ->AsOutput()
+                                 ->assert_is_op_output("layer_norm", "Y");
+  auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("layer_norm", "Mean");
+  auto *layer_norm_variance_var =
+      pattern->NewNode(layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+
+  // Add links for layer_norm op.
+  layer_norm
+      ->LinksFrom(
+          {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+  return layer_norm_out_var;
+}
+
+}  // namespace patterns
+
+template <typename T>
+static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) {
+  if (!(x.size() > 0U && y.size() > 0U) || x.size() != y.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < x.size(); ++i) {
+    if (x[i] != y[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void FCReshapeElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
+  FusePassBase::Init("fc_reshape_elementwise_layernorm_fuse", graph);
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  auto *x = gpd.mutable_pattern()
+                ->NewNode("fc_reshape_elementwise_layernorm_fuse/x")
+                ->AsInput()
+                ->assert_is_op_input("fc", "Input");
+  patterns::FCReshapeElementwiseLayerNorm fused_pattern(
+      gpd.mutable_pattern(), "fc_reshape_elementwise_layernorm_fuse");
+  fused_pattern(x);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    if (subgraph.count(x) <= 0) {
+      LOG(WARNING) << "The subgraph is empty.";
+      return;
+    }
+
+    VLOG(4) << "handle FCReshapeElementwiseLayerNorm fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fused_pattern);
+
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2, reshape2, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, fused_pattern);
+
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_input, elementwise_input,
+                              fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
+                              fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
+                              fused_pattern);
+
+    // if (!IsEqual(reshape_out->Var()->GetShape(),
+    //              elementwise_input->Var()->GetShape())) {
+    //   VLOG(4) << "shape check failed";
+    //   VLOG(4) << "reshape_out shape: ";
"reshape_out shape: "; + // for (auto dim : reshape_out->Var()->GetShape()) { + // VLOG(4) << "dim: " << dim; + // } + // VLOG(4) << "elementwise_input shape: "; + // for (auto dim : elementwise_input->Var()->GetShape()) { + // VLOG(4) << "dim: " << dim; + // } + // return; + // } + // + // int begin_norm_axis = + // BOOST_GET_CONST(int, + // layer_norm->Op()->GetAttr("begin_norm_axis")); + // auto layer_norm_x_dims = fc_out->Var()->GetShape(); + // auto layer_norm_x_mat_dims = framework::flatten_to_2d( + // framework::make_ddim(layer_norm_x_dims), begin_norm_axis); + // if (fc_w->Var()->GetShape()[1] != layer_norm_x_mat_dims[1]) { + // return; + // } + + if (reshape_out->outputs.size() > 1U || + elementwise_out->outputs.size() > 1U) { + VLOG(4) << "output check failed!!!!!"; + VLOG(4) << "reshape_out->outputs.size(): " << reshape_out->outputs.size(); + VLOG(4) << "elementwise_out->outputs.size(): " + << elementwise_out->outputs.size(); + // When reshape_out or elementwise_out are used as input of other + // operators, we + // cannon fuse. + return; + } + + std::unordered_set del_node_set; + + // Create an FusedFCReshapeElementwiseLayerNorm op node + OpDesc new_desc; + new_desc.SetType("fused_fc_reshape_elementwise_layernorm"); + + // inputs + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("W", {fc_w->Name()}); + new_desc.SetInput("Bias0", {fc_bias->Name()}); + new_desc.SetInput("Y", {elementwise_input->Name()}); + new_desc.SetInput("Scale", {layer_norm_scale->Name()}); + new_desc.SetInput("Bias1", {layer_norm_bias->Name()}); + + // outputs + new_desc.SetOutput("Out", {layer_norm_out->Name()}); + bool lnm_has_output = layer_norm_mean->outputs.size() > 0U; + if (lnm_has_output) { + new_desc.SetOutput("Mean", {layer_norm_mean->Name()}); + } else { + del_node_set.insert(layer_norm_mean); + } + bool lnv_has_output = layer_norm_variance->outputs.size() > 0U; + if (lnv_has_output) { + new_desc.SetOutput("Variance", {layer_norm_variance->Name()}); + } else { + del_node_set.insert(layer_norm_variance); + } + + // attrs + new_desc.SetAttr("x_num_col_dims", fc->Op()->GetAttr("in_num_col_dims")); + new_desc.SetAttr("shape", reshape2->Op()->GetAttr("shape")); + new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("begin_norm_axis", + layer_norm->Op()->GetAttr("begin_norm_axis")); + new_desc.SetAttr("activation_type", fc->Op()->GetAttr("activation_type")); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. 
+
+    del_node_set.insert(fc);
+    del_node_set.insert(reshape2);
+    del_node_set.insert(elementwise);
+    del_node_set.insert(layer_norm);
+    del_node_set.insert(fc_out);
+    del_node_set.insert(reshape_out);
+    del_node_set.insert(elementwise_out);
+    GraphSafeRemoveNodes(graph, del_node_set);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
+    IR_NODE_LINK_TO(fc_w, fused_node);
+    IR_NODE_LINK_TO(fc_bias, fused_node);
+    IR_NODE_LINK_TO(elementwise_input, fused_node);
+    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
+    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
+    IR_NODE_LINK_TO(fused_node, layer_norm_out);
+    if (lnm_has_output) {
+      IR_NODE_LINK_TO(fused_node, layer_norm_mean);
+    }
+    if (lnv_has_output) {
+      IR_NODE_LINK_TO(fused_node, layer_norm_variance);
+    }
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(fc_reshape_elementwise_layernorm_fuse_pass,
+              paddle::framework::ir::FCReshapeElementwiseLayerNormFusePass);
diff --git a/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..c89f6ab6c539fb128b9578f9a5e9ba6ddd50fba8
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_reshape_elementwise_layernorm_fuse_pass.h
@@ -0,0 +1,33 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FCReshapeElementwiseLayerNormFusePass : public FusePassBase {
+ public:
+  virtual ~FCReshapeElementwiseLayerNormFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 98a36a3308dc539ee5aecad9e71f50be310e584c..38ea594988224851e9f96dbe7d4cbcd8e0f8aff2 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -103,16 +103,17 @@ const std::vector<std::string> kLiteSubgraphPasses({
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       //   "identity_scale_op_clean_pass",              //
-      "is_test_pass",                                    //
-      "simplify_with_basic_ops_pass",                    //
-      "conv_affine_channel_fuse_pass",                   //
-      "conv_eltwiseadd_affine_channel_fuse_pass",        //
-      "conv_bn_fuse_pass",                               //
-      "conv_eltwiseadd_bn_fuse_pass",                    //
-      "embedding_eltwise_layernorm_fuse_pass",           //
-      "multihead_matmul_fuse_pass_v2",                   //
-      "fc_fuse_pass",                                    //
-      "fc_elementwise_layernorm_fuse_pass",              //
+      "is_test_pass",                                    //
+      "simplify_with_basic_ops_pass",                    //
+      "conv_affine_channel_fuse_pass",                   //
+      "conv_eltwiseadd_affine_channel_fuse_pass",        //
+      "conv_bn_fuse_pass",                               //
+      "conv_eltwiseadd_bn_fuse_pass",                    //
+      "embedding_eltwise_layernorm_fuse_pass",           //
+      "multihead_matmul_fuse_pass_v2",                   //
+      "fc_fuse_pass",                                    //
+      "fc_elementwise_layernorm_fuse_pass",              //
+      "fc_reshape_elementwise_layernorm_fuse_pass",      //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                            // guaranteed at least v7
       "conv_elementwise_add_act_fuse_pass",              //
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 3fc5f3bfc6b1633ffe835606bbac6118e6b32ca6..4fd0253bd77451e54f9bac464f9eb75c3d1a3128 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -5,6 +5,7 @@ register_operators(EXCLUDES
 	fusion_transpose_flatten_concat_op
 	fusion_conv_inception_op
 	fused_fc_elementwise_layernorm_op
+	fused_fc_reshape_elementwise_layernorm_op
 	multihead_matmul_op
 	fused_embedding_eltwise_layernorm_op
 	fusion_group_op
@@ -36,6 +37,9 @@ if (WITH_GPU)
     # fused_fc_elementwise_layernorm_op
     op_library(fused_fc_elementwise_layernorm_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_fc_elementwise_layernorm);\n")
+    # fused_fc_reshape_elementwise_layernorm_op
+    op_library(fused_fc_reshape_elementwise_layernorm_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_fc_reshape_elementwise_layernorm);\n")
     # multihead_matmul_op
     op_library(multihead_matmul_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
diff --git a/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bf150409b1db4db6459aa6a3d57f38f02792905d
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cc
@@ -0,0 +1,195 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedFCReshapeElementwiseLayerNormOp
+    : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of fused_fc_reshape_elementwise_layernorm "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
+                      "Input(W) of fused_fc_reshape_elementwise_layernorm "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
+                      "Input(Y) of fused_fc_reshape_elementwise_layernorm "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of fused_fc_reshape_elementwise_layernorm "
+                      "should not be null.");
+
+    auto w_dims = ctx->GetInputDim("W");
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      "The weight W of the fully connected layer should be a "
+                      "2-D tensor.");
+
+    if (ctx->HasInput("Bias0")) {
+      auto bias0_dims = ctx->GetInputDim("Bias0");
+      if (bias0_dims.size() == 2) {
+        PADDLE_ENFORCE_EQ(bias0_dims[0], 1,
+                          "The shape of Bias0 must be [1, dim].");
+        PADDLE_ENFORCE_EQ(bias0_dims[1], w_dims[1],
+                          "The shape of Bias0 must be [1, dim].");
+      } else if (bias0_dims.size() == 1) {
+        PADDLE_ENFORCE_EQ(bias0_dims[0], w_dims[1],
+                          "The shape of Bias0 must be [dim].");
+      }
+    }
+
+    auto x_dims = ctx->GetInputDim("X");
+    int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), x_num_col_dims,
+        "The rank of Input(X) should be larger than x_num_col_dims.");
+
+    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+    PADDLE_ENFORCE_EQ(
+        x_mat_dims[1], w_dims[0],
+        "Fully Connected input and weight size do not match.");
%s, %s"); + + // std::vector fc_out_dims; + // for (int i = 0; i < x_num_col_dims; ++i) { + // fc_out_dims.push_back(x_dims[i]); + // } + // fc_out_dims.push_back(w_dims[1]); + + auto y_dims = ctx->GetInputDim("Y"); + // PADDLE_ENFORCE_EQ(framework::make_ddim(fc_out_dims), y_dims); + + auto begin_norm_axis = ctx->Attrs().Get("begin_norm_axis"); + PADDLE_ENFORCE_LT( + begin_norm_axis, y_dims.size(), + "'begin_norm_axis' must be less than the rank of Input(Y)."); + + auto y_mat_dim = framework::flatten_to_2d(y_dims, begin_norm_axis); + int64_t dim_0 = y_mat_dim[0]; + int64_t dim_1 = y_mat_dim[1]; + if (ctx->HasInput("Scale")) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1); + + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], dim_1, + "scale should with right"); + } + } + if (ctx->HasInput("Bias1")) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias1").size(), 1); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias1")[0], dim_1, + "bias should with right"); + } + } + + ctx->SetOutputDim("Out", y_dims); + if (ctx->HasOutput("Mean")) { + ctx->SetOutputDim("Mean", {dim_0}); + } + if (ctx->HasOutput("Variance")) { + ctx->SetOutputDim("Variance", {dim_0}); + } + ctx->ShareLoD("X", "Out"); + } +}; + +class FusedFCReshapeElementwiseLayerNormOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of fully connected operation"); + AddInput("W", + "(Tensor), The weight tensor of fully connected operation. It is " + "a 2-D Tensor with shape (I, O)"); + AddInput("Bias0", + "(Tensor, optional), The bias tensor of fully connecred " + "operation. It is a 1-D Tensor with shape (O), or a 2-D Tensor " + "with shape (1, O).") + .AsDispensable(); + AddInput("Y", + "(Tensor), The second input tensor of elementwise_add operation. " + "Note that the shape should be the same as fully connect's result " + "tensor."); + AddInput( + "Scale", + "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.") + .AsDispensable(); + AddInput( + "Bias1", + "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.") + .AsDispensable(); + AddOutput("Out", + "(Tensor), Output after normalization. The shape is the shame as " + "layer_norm's input."); + AddOutput("Mean", "(Tensor, optional), Mean of the current minibatch") + .AsDispensable(); + AddOutput("Variance", + "(Tensor, optional), Variance of the current minibatch") + .AsDispensable(); + AddAttr("x_num_col_dims", + "(int, default 1), This op can take tensors with more than " + "two dimensions as its inputs.") + .SetDefault(1) + .EqualGreaterThan(1); + AddAttr("activation_type", + "Activation type used in fully connected operator.") + .SetDefault(""); + AddAttr>( + "shape", + "(std::vector) Target shape of reshape operator." + "It has the lowest priority compare with Input(Shape) and " + " Input(ShapeTensor).") + .SetDefault({}); + + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_GE(epsilon, 0.0f, + "'epsilon' should be between 0.0 and 0.001."); + PADDLE_ENFORCE_LE(epsilon, 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); + AddAttr("begin_norm_axis", + "the axis of `begin_norm_axis ... Rank(Y) - 1` will be " + "normalized. `begin_norm_axis` splits the tensor(`X`) to a " + "matrix [N,H]. 
[default 1].") + .SetDefault(1) + .AddCustomChecker([](const int &begin_norm_axis) { + PADDLE_ENFORCE_GT(begin_norm_axis, 0, + "'begin_norm_axis' should be greater than zero."); + }); + AddComment(R"DOC( +fc_out <= fc(X, W, Bias0) +add_out <= elementwise_add(fc_out, Y) +(out, mean, variance) <= layer_norm(add_out, Scale, Bias1) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_fc_reshape_elementwise_layernorm, + ops::FusedFCReshapeElementwiseLayerNormOp, + ops::FusedFCReshapeElementwiseLayerNormOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2015fa4363b23c39ac4f05ae172e0afe7f48523d --- /dev/null +++ b/paddle/fluid/operators/fused/fused_fc_reshape_elementwise_layernorm_op.cu @@ -0,0 +1,201 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cuda_device_function.h" + +namespace paddle { +namespace operators { + +template +static __device__ __forceinline__ T Relu(T x) { + return (x > 0) ? x : 0; +} + +static __device__ __forceinline__ float RealSqrt(float x) { return sqrtf(x); } +static __device__ __forceinline__ double RealSqrt(double x) { return sqrt(x); } + +template +struct PairForLayerNorm { + __device__ __forceinline__ PairForLayerNorm() {} + __device__ __forceinline__ PairForLayerNorm(const T& first, const T& second) + : first_(first), second_(second) {} + + T first_; + T second_; +}; + +template +struct PairForLayerNormAddFunctor { + __device__ __forceinline__ PairForLayerNorm operator()( + const PairForLayerNorm& p1, const PairForLayerNorm& p2) { + return PairForLayerNorm(p1.first_ + p2.first_, p1.second_ + p2.second_); + } +}; + +template +__global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0, + const T* bias_1, + const T* scale, T* out, + T* mean, T* variance, int M, + int N, float epsilon) { + using BlockReduce = cub::BlockReduce, BlockDim>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T shared_mem[BlockDim + 2]; + + for (int i = blockIdx.x; i < M; i += gridDim.x) { + int index = i * N + threadIdx.x; + + // The fisrt BlockDim elements will be saved to shared memory. + int save_index = threadIdx.x; + T* save_ptr = shared_mem; + + T sum_i = 0; + T square_sum_i = 0; + for (int j = threadIdx.x; j < N; j += blockDim.x) { + T tmp_0 = out[index]; + // Add bias + T tmp_1 = bias_0 ? tmp_0 + bias_0[j] : tmp_0; + // Relu + T tmp_2 = DoRelu ? 
+      // elementwise_add
+      T tmp_3 = tmp_2 + y[index];
+
+      // Save
+      save_ptr[save_index] = tmp_3;
+      save_ptr = out;
+
+      index += blockDim.x;
+      save_index = index;
+
+      // For layer_norm, reduce to calculate mean and std
+      sum_i += tmp_3;
+      square_sum_i += (tmp_3 * tmp_3);
+    }
+
+    auto pair = BlockReduce(temp_storage)
+                    .Reduce(PairForLayerNorm<T>(sum_i, square_sum_i),
+                            PairForLayerNormAddFunctor<T>());
+
+    if (threadIdx.x == 0) {
+      T mean_i = static_cast<T>(pair.first_ / N);
+      T variance_i = static_cast<T>(pair.second_ / N - mean_i * mean_i);
+      shared_mem[BlockDim] = mean_i;
+      shared_mem[BlockDim + 1] = variance_i;
+      if (mean) {
+        mean[blockIdx.x] = mean_i;
+      }
+      if (variance) {
+        variance[blockIdx.x] = variance_i;
+      }
+    }
+    __syncthreads();
+    T mean_i = shared_mem[BlockDim];
+    T std_i = static_cast<T>(RealSqrt(shared_mem[BlockDim + 1] + epsilon));
+
+    index = i * N + threadIdx.x;
+    // First BlockDim elements loading from shared memory.
+    save_index = threadIdx.x;
+    save_ptr = shared_mem;
+
+    // For layer_norm, calculate out
+    for (int j = threadIdx.x; j < N; j += blockDim.x) {
+      T tmp_0 = (save_ptr[save_index] - mean_i) / std_i;
+      T tmp_1 = scale ? scale[j] * tmp_0 : tmp_0;
+      out[index] = bias_1 ? tmp_1 + bias_1[j] : tmp_1;
+
+      save_ptr = out;
+      index += blockDim.x;
+      save_index = index;
+    }
+  }
+}
+
+template <typename T>
+class FusedFCReshapeElementwiseLayerNormOpKernel
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::Tensor>("X");
+    auto* w = ctx.Input<framework::Tensor>("W");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+
+    auto w_dims = w->dims();
+    int N = w_dims[1];
+    int K = w_dims[0];
+    int M = framework::product(x->dims()) / K;
+
+    const T* x_data = x->data<T>();
+    const T* w_data = w->data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), x_data, K, w_data, N,
+              static_cast<T>(0.0), out_data, N);
+
+    auto* y = ctx.Input<framework::Tensor>("Y");
+    auto* bias_0 = ctx.Input<framework::Tensor>("Bias0");
+    auto* bias_1 = ctx.Input<framework::Tensor>("Bias1");
+    auto* scale = ctx.Input<framework::Tensor>("Scale");
+
+    const T* y_data = y->data<T>();
+    const T* bias_0_data = bias_0 ? bias_0->data<T>() : nullptr;
+    const T* bias_1_data = bias_1 ? bias_1->data<T>() : nullptr;
+    const T* scale_data = scale ? scale->data<T>() : nullptr;
+
+    auto* mean = ctx.Output<framework::Tensor>("Mean");
+    auto* variance = ctx.Output<framework::Tensor>("Variance");
+
+    T* mean_data = mean ? mean->mutable_data<T>(ctx.GetPlace()) : nullptr;
+    T* variance_data =
+        variance ? variance->mutable_data<T>(ctx.GetPlace()) : nullptr;
+
+    bool with_relu =
+        (ctx.Attr<std::string>("activation_type") == "relu");
+    float epsilon = ctx.Attr<float>("epsilon");
+
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    if (with_relu) {
+      switch (platform::RoundToPowerOfTwo(N)) {
+        CUDA_LAUNCH_KERNEL_HELPER(
+            InplaceAddReluAddLayerNormKernel<
+                T, true,
+                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
+                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
+                y_data, bias_0_data, bias_1_data, scale_data, out_data,
+                mean_data, variance_data, M, N, epsilon));
+      }
+    } else {
+      switch (platform::RoundToPowerOfTwo(N)) {
+        CUDA_LAUNCH_KERNEL_HELPER(
+            InplaceAddReluAddLayerNormKernel<
+                T, false,
+                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
+                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
+                y_data, bias_0_data, bias_1_data, scale_data, out_data,
+                mean_data, variance_data, M, N, epsilon));
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(fused_fc_reshape_elementwise_layernorm,
+                        ops::FusedFCReshapeElementwiseLayerNormOpKernel<float>);
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 01a33a46521cd81d084f8971c47741b28a105d41..3e87e24cbb67f4b7b8638c066ea88db03ba74c3c 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -253,6 +253,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
         "It has the lowest priority compare with Input(Shape) and "
         " Input(ShapeTensor).")
         .SetDefault({});
+    AddAttr<bool>("inplace", "(bool, default true) Skip the copy in the kernel when Out shares the memory of X.").SetDefault(true);
+
     AddComment(R"DOC(
 Reshape Operator.
@@ -327,6 +329,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
 class ReshapeKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
+    auto inplace = ctx.Attr<bool>("inplace");
     auto *out = ctx.Output<framework::LoDTensor>("Out");
     auto *in = ctx.Input<framework::LoDTensor>("X");
@@ -360,6 +363,10 @@ class ReshapeKernel {
     out->Resize(out_dims);
     out->mutable_data(ctx.GetPlace(), in->type());
+    if (inplace) {
+      return;
+    }
+
     framework::TensorCopy(
         *in, ctx.GetPlace(),
         ctx.template device_context<platform::DeviceContext>(), out);
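
A minimal sketch (not part of the patch) of how the new pass can be exercised through the C++ inference API. The model path is a placeholder, and the explicit AppendPass call is only needed when assembling a custom pass list, since the patch already adds the pass to the default GpuPassStrategy:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("path/to/inference_model");   // placeholder model directory
  config.EnableUseGpu(100 /*memory pool MB*/, 0 /*device id*/);
  // Optional: only needed for a custom pass list; the default GPU strategy
  // already contains the pass after this patch.
  // config.pass_builder()->AppendPass("fc_reshape_elementwise_layernorm_fuse_pass");
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}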