diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 92dc614e09ffdb89eccd1af40e8643158cfbf94b..28e880fb51e7dceccdfe5e8ddeb4bbe92c460fa9 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -110,7 +110,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op") +"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index cae7e90255696d392c18d6af59f65a0bb1dc501e..9476256b0f0e5ac2290a814e73374fb1552ff5c2 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -73,6 +73,7 @@ pass_library(fillconstant_elementwisemul_fuse inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(simplify_with_basic_ops_pass base) +pass_library(fc_elementwise_layernorm_fuse_pass base) if(WITH_GPU) pass_library(cudnn_placement_pass base DEPS placement_pass_base) endif() @@ -122,6 +123,7 @@ cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_test cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass) +cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass) if(WITH_GPU) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) endif() diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e2c7606c30836f735844b8c6ef81c265ee295606 --- /dev/null +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc @@ -0,0 +1,259 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FCElementwiseLayerNorm : public PatternBase { + FCElementwiseLayerNorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "fc_elementwise_layernorm") {} + + PDNode *operator()(PDNode *x); + + // declare operator node's name + PATTERN_DECL_NODE(fused_fc_elementwise_layernorm); + PATTERN_DECL_NODE(fc); + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(layer_norm); + // declare variable node's name + PATTERN_DECL_NODE(fc_w); + PATTERN_DECL_NODE(fc_bias); + PATTERN_DECL_NODE(fc_out); // (x,fc_w,fc_bias) -> fc_out + PATTERN_DECL_NODE(elementwise_input); + PATTERN_DECL_NODE( + elementwise_out); // (fc_out,elementwise_input) -> elementwise_out + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); +}; + +PDNode *FCElementwiseLayerNorm::operator()(PDNode *x) { + // Create nodes for fc op. + x->assert_is_op_input("fc", "Input"); + auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc"); + auto *fc_w_var = pattern->NewNode(fc_w_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("fc", "W"); + auto *fc_bias_var = pattern->NewNode(fc_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("fc", "Bias"); + auto *fc_out_var = pattern->NewNode(fc_out_repr())->assert_is_op_output("fc"); + + // Add links for fc op. + fc->LinksFrom({x, fc_w_var, fc_bias_var}).LinksTo({fc_out_var}); + + // Create nodes for elementwise_add op. + fc_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + auto *elementwise = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); + auto *elementwise_input_var = pattern->NewNode(elementwise_input_repr()) + ->assert_is_op_input("elementwise_add"); + + auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add"); + + // Add links for elementwise_add op. + elementwise->LinksFrom({fc_out_var, elementwise_input_var}) + .LinksTo({elementwise_out_var}); + + // Create nodes for layer_norm op. + elementwise_out_var->AsIntermediate()->assert_is_op_input("layer_norm"); + auto *layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + + auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); + auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto *layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + + // Add links for layer_norm op. 
+  layer_norm
+      ->LinksFrom(
+          {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+  return layer_norm_out_var;
+}
+
+}  // namespace patterns
+
+template <typename T>
+static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) {
+  if (!(x.size() > 0U && y.size() > 0U) || x.size() != y.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < x.size(); ++i) {
+    if (x[i] != y[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph);
+  FusePassBase::Init("fc_elementwise_layernorm_fuse", graph);
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  auto *x = gpd.mutable_pattern()
+                ->NewNode("fc_elementwise_layernorm_fuse/x")
+                ->AsInput()
+                ->assert_is_op_input("fc", "Input");
+  patterns::FCElementwiseLayerNorm fused_pattern(
+      gpd.mutable_pattern(), "fc_elementwise_layernorm_fuse");
+  fused_pattern(x);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    if (subgraph.count(x) <= 0) {
+      LOG(WARNING) << "The subgraph is empty.";
+      return;
+    }
+
+    VLOG(4) << "handle FCElementwiseLayerNorm fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_input, elementwise_input,
+                              fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
+                              fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
+                              fused_pattern);
+
+    if (!IsEqual(fc_out->Var()->GetShape(),
+                 elementwise_input->Var()->GetShape())) {
+      return;
+    }
+
+    int begin_norm_axis =
+        boost::get<int>(layer_norm->Op()->GetAttr("begin_norm_axis"));
+    auto layer_norm_x_dims = fc_out->Var()->GetShape();
+    auto layer_norm_x_mat_dims = framework::flatten_to_2d(
+        framework::make_ddim(layer_norm_x_dims), begin_norm_axis);
+    if (fc_w->Var()->GetShape()[1] != layer_norm_x_mat_dims[1]) {
+      return;
+    }
+
+    if (fc_out->outputs.size() > 1U || elementwise_out->outputs.size() > 1U) {
+      // When fc_out or elementwise_out is used as input of other operators,
+      // we cannot fuse.
+ return; + } + + std::unordered_set del_node_set; + + // Create an FusedFCElementwiseLayerNorm op node + OpDesc new_desc; + new_desc.SetType("fused_fc_elementwise_layernorm"); + + // inputs + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("W", {fc_w->Name()}); + new_desc.SetInput("Bias0", {fc_bias->Name()}); + new_desc.SetInput("Y", {elementwise_input->Name()}); + new_desc.SetInput("Scale", {layer_norm_scale->Name()}); + new_desc.SetInput("Bias1", {layer_norm_bias->Name()}); + + // outputs + new_desc.SetOutput("Out", {layer_norm_out->Name()}); + if (layer_norm_mean->outputs.size() > 0U) { + new_desc.SetOutput("Mean", {layer_norm_mean->Name()}); + } else { + del_node_set.insert(layer_norm_mean); + } + if (layer_norm_variance->outputs.size() > 0U) { + new_desc.SetOutput("Variance", {layer_norm_variance->Name()}); + } else { + del_node_set.insert(layer_norm_variance); + } + + // attrs + new_desc.SetAttr("x_num_col_dims", fc->Op()->GetAttr("in_num_col_dims")); + new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("begin_norm_axis", + layer_norm->Op()->GetAttr("begin_norm_axis")); + new_desc.SetAttr("activation_type", fc->Op()->GetAttr("activation_type")); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. + + del_node_set.insert(fc); + del_node_set.insert(elementwise); + del_node_set.insert(layer_norm); + del_node_set.insert(fc_out); + del_node_set.insert(elementwise_out); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(fc_w, fused_node); + IR_NODE_LINK_TO(fc_bias, fused_node); + IR_NODE_LINK_TO(elementwise_input, fused_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, layer_norm_out); + if (layer_norm_mean->outputs.size() > 0U) { + IR_NODE_LINK_TO(fused_node, layer_norm_mean); + } + if (layer_norm_variance->outputs.size() > 0U) { + IR_NODE_LINK_TO(fused_node, layer_norm_variance); + } + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fc_elementwise_layernorm_fuse_pass, + paddle::framework::ir::FCElementwiseLayerNormFusePass); diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..ac4d0b39ee267c724636954263aa2dce9d9ec47f --- /dev/null +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FCElementwiseLayerNormFusePass : public FusePassBase { + public: + virtual ~FCElementwiseLayerNormFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1f822d7ca5cdc0a1bba1dbb5c646c61be244810 --- /dev/null +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass_tester.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(FCElementwiseLayerNormFusePass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, weights_0, bias_0) fc -> fc_out_0 + // (fc_out_0, weights_1, bias_1) fc -> fc_out_1 + // (fc_out_1, y) elementwise_add -> elementwise_out + // (elementwise_out, scale, bias_2) layer_norm -> + Layers layers; + auto* x = layers.data("x", {128, 768}); + auto* weights_0 = layers.data("weights_0", {768, 3072}, true); + auto* bias_0 = layers.data("bias_0", {3072}, true); + auto* fc_out_0 = layers.fc(x, weights_0, bias_0); // {128, 3072} + auto* weights_1 = layers.data("weights_1", {3072, 768}, true); + auto* bias_1 = layers.data("bias_1", {768}, true); + auto* fc_out_1 = + layers.fc(fc_out_0, weights_1, bias_1, 1, "relu"); // {128, 768} + fc_out_1->SetShape({128, 768}); + auto* y = layers.data("y", {128, 768}); + auto* elementwise_out = layers.elementwise_add(fc_out_1, y); + auto* scale = layers.data("scale", {768}, true); + auto* bias_2 = layers.data("bias_2", {768}, true); + layers.layer_norm(elementwise_out, scale, bias_2); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("fc_elementwise_layernorm_fuse_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = + GetNumOpNodes(graph, "fused_fc_elementwise_layernorm"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fc_elementwise_layernorm_fuse_pass); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 38faf85cf08dfd7fed7a54999c84703d352df983..8df292b483b2842628de8aa7e92f9fb0d38373ff 
100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -137,6 +137,31 @@ struct Layers { return out; } + std::vector layer_norm(VarDesc* x, VarDesc* scale = nullptr, + VarDesc* bias = nullptr) { + VarDesc* y = lod_tensor(unique_name()); + VarDesc* mean = lod_tensor(unique_name()); + VarDesc* variance = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("layer_norm"); + op->SetInput("X", {x->Name()}); + if (scale) { + op->SetInput("Scale", {scale->Name()}); + } + if (bias) { + op->SetInput("Bias", {bias->Name()}); + } + op->SetOutput("Y", {y->Name()}); + op->SetOutput("Mean", {mean->Name()}); + op->SetOutput("Variance", {variance->Name()}); + op->SetAttr("epsilon", static_cast(1E-05)); + op->SetAttr("begin_norm_axis", static_cast(1)); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); + std::vector outs = {y, mean, variance}; + return outs; + } + private: VarDesc* lod_tensor(std::string name, std::vector shape = {}, bool is_persistable = false) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 83d10abd5f3ef872d71175369e343373c2d07a07..e81a842814a64890e68bcccacf65a7b975aa7de9 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -107,6 +107,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "is_test_pass", // "simplify_with_basic_ops_pass", // "fc_fuse_pass", // + "fc_elementwise_layernorm_fuse_pass", // "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // "conv_bn_fuse_pass", // diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 42ab8e99662e1ec67b7a4061b274e84103a7d5b1..a31531c599a71e7da0697825a12ab86f5d809a51 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -1,5 +1,5 @@ include(operators) -register_operators(EXCLUDES fusion_transpose_flatten_concat_op fusion_conv_inception_op) +register_operators(EXCLUDES fusion_transpose_flatten_concat_op fusion_conv_inception_op fused_fc_elementwise_layernorm_op) if (WITH_GPU) op_library(fusion_transpose_flatten_concat_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n") @@ -7,4 +7,6 @@ if (WITH_GPU) op_library(fusion_conv_inception_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n") endif() + op_library(fused_fc_elementwise_layernorm_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_fc_elementwise_layernorm);\n") endif() diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7c5d0c71226871a3af10c8ddc16269526f0d88b9 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedFCElementwiseLayerNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        "Input(X) of fused_fc_elementwise_layernorm should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("W"), true,
+        "Input(W) of fused_fc_elementwise_layernorm should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Y"), true,
+        "Input(Y) of fused_fc_elementwise_layernorm should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        "Output(Out) of fused_fc_elementwise_layernorm should not be null.");
+
+    auto w_dims = ctx->GetInputDim("W");
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      "The weight Input(W) of fc should be a 2-D tensor.");
+
+    if (ctx->HasInput("Bias0")) {
+      auto bias0_dims = ctx->GetInputDim("Bias0");
+      if (bias0_dims.size() == 2) {
+        PADDLE_ENFORCE_EQ(bias0_dims[0], 1,
+                          "The shape of Bias0 must be [1, dim].");
+        PADDLE_ENFORCE_EQ(bias0_dims[1], w_dims[1],
+                          "The shape of Bias0 must be [1, dim].");
+      } else if (bias0_dims.size() == 1) {
+        PADDLE_ENFORCE_EQ(bias0_dims[0], w_dims[1],
+                          "The shape of Bias0 must be [dim].");
+      }
+    }
+
+    auto x_dims = ctx->GetInputDim("X");
+    int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), x_num_col_dims,
+        "The rank of Input(X) should be larger than x_num_col_dims.");
+
+    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+    PADDLE_ENFORCE_EQ(
+        x_mat_dims[1], w_dims[0],
+        "The fc input and weight sizes do not match.");
+
+    std::vector<int64_t> fc_out_dims;
+    for (int i = 0; i < x_num_col_dims; ++i) {
+      fc_out_dims.push_back(x_dims[i]);
+    }
+    fc_out_dims.push_back(w_dims[1]);
+
+    auto y_dims = ctx->GetInputDim("Y");
+    PADDLE_ENFORCE_EQ(framework::make_ddim(fc_out_dims), y_dims);
+
+    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
+    PADDLE_ENFORCE_LT(
+        begin_norm_axis, y_dims.size(),
+        "'begin_norm_axis' must be less than the rank of Input(Y).");
+
+    auto y_mat_dim = framework::flatten_to_2d(y_dims, begin_norm_axis);
+    int64_t dim_0 = y_mat_dim[0];
+    int64_t dim_1 = y_mat_dim[1];
+    if (ctx->HasInput("Scale")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
+
+      if (ctx->IsRuntime()) {
+        PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], dim_1,
+                          "The first dimension of Scale must be equal to the "
+                          "normalized size of Input(Y).");
+      }
+    }
+    if (ctx->HasInput("Bias1")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias1").size(), 1);
+      if (ctx->IsRuntime()) {
+        PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias1")[0], dim_1,
+                          "The first dimension of Bias1 must be equal to the "
+                          "normalized size of Input(Y).");
+      }
+    }
+
+    ctx->SetOutputDim("Out", y_dims);
+    if (ctx->HasOutput("Mean")) {
+      ctx->SetOutputDim("Mean", {dim_0});
+    }
+    if (ctx->HasOutput("Variance")) {
+      ctx->SetOutputDim("Variance", {dim_0});
+    }
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class FusedFCElementwiseLayerNormOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of the fc operation.");
+    AddInput("W",
+             "(Tensor), The weight tensor of the fc operation. It is "
+             "a 2-D Tensor with shape (I, O).");
+    AddInput("Bias0",
+             "(Tensor, optional), The bias tensor of the fc "
+             "operation. It is a 1-D Tensor with shape (O), or a 2-D Tensor "
+             "with shape (1, O).")
+        .AsDispensable();
+    AddInput("Y",
+             "(Tensor), The second input tensor of the elementwise_add "
+             "operation. Note that its shape should be the same as the "
+             "output tensor of the fc operation.");
+    AddInput(
+        "Scale",
+        "(Tensor, optional), It is a 1-D input Tensor of the layer_norm "
+        "operation.")
+        .AsDispensable();
+    AddInput(
+        "Bias1",
+        "(Tensor, optional), It is a 1-D input Tensor of the layer_norm "
+        "operation.")
+        .AsDispensable();
+    AddOutput("Out",
+              "(Tensor), Output after normalization. The shape is the same "
+              "as layer_norm's input.");
+    AddOutput("Mean", "(Tensor, optional), Mean of the current minibatch.")
+        .AsDispensable();
+    AddOutput("Variance",
+              "(Tensor, optional), Variance of the current minibatch.")
+        .AsDispensable();
+    AddAttr<int>("x_num_col_dims",
+                 "(int, default 1), This op can take tensors with more than "
+                 "two dimensions as its inputs.")
+        .SetDefault(1)
+        .EqualGreaterThan(1);
+    AddAttr<std::string>("activation_type",
+                         "Activation type used in the fc operator.")
+        .SetDefault("");
+    AddAttr<float>("epsilon",
+                   "Constant for numerical stability [default 1e-5].")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE_GE(epsilon, 0.0f,
+                            "'epsilon' should be between 0.0 and 0.001.");
+          PADDLE_ENFORCE_LE(epsilon, 0.001f,
+                            "'epsilon' should be between 0.0 and 0.001.");
+        });
+    AddAttr<int>("begin_norm_axis",
+                 "the axis of `begin_norm_axis ... Rank(Y) - 1` will be "
+                 "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
+                 "matrix [N,H]. 
[default 1].") + .SetDefault(1) + .AddCustomChecker([](const int &begin_norm_axis) { + PADDLE_ENFORCE_GT(begin_norm_axis, 0, + "'begin_norm_axis' should be greater than zero."); + }); + AddComment(R"DOC( +fc_out <= fc(X, W, Bias0) +add_out <= elementwise_add(fc_out, Y) +(out, mean, variance) <= layer_norm(add_out, Scale, Bias1) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_fc_elementwise_layernorm, + ops::FusedFCElementwiseLayerNormOp, + ops::FusedFCElementwiseLayerNormOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..74d345257a4f3aeb8eb9db9a7b4e0060e4ba1621 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu @@ -0,0 +1,201 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cuda_device_function.h" + +namespace paddle { +namespace operators { + +template +static __device__ __forceinline__ T Relu(T x) { + return (x > 0) ? x : 0; +} + +static __device__ __forceinline__ float RealSqrt(float x) { return sqrtf(x); } +static __device__ __forceinline__ double RealSqrt(double x) { return sqrt(x); } + +template +struct PairForLayerNorm { + __device__ __forceinline__ PairForLayerNorm() {} + __device__ __forceinline__ PairForLayerNorm(const T& first, const T& second) + : first_(first), second_(second) {} + + T first_; + T second_; +}; + +template +struct PairForLayerNormAddFunctor { + __device__ __forceinline__ PairForLayerNorm operator()( + const PairForLayerNorm& p1, const PairForLayerNorm& p2) { + return PairForLayerNorm(p1.first_ + p2.first_, p1.second_ + p2.second_); + } +}; + +template +__global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0, + const T* bias_1, + const T* scale, T* out, + T* mean, T* variance, int M, + int N, float epsilon) { + using BlockReduce = cub::BlockReduce, BlockDim>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T shared_mem[BlockDim + 2]; + + for (int i = blockIdx.x; i < M; i += gridDim.x) { + int index = i * N + threadIdx.x; + + // The fisrt BlockDim elements will be saved to shared memory. + int save_index = threadIdx.x; + T* save_ptr = shared_mem; + + double sum_i = 0; + double square_sum_i = 0; + for (int j = threadIdx.x; j < N; j += blockDim.x) { + T tmp_0 = out[index]; + // Add bias + T tmp_1 = bias_0 ? tmp_0 + bias_0[j] : tmp_0; + // Relu + T tmp_2 = DoRelu ? 
Relu(tmp_1) : tmp_1; + // elementwise_add + T tmp_3 = tmp_2 + y[index]; + + // Save + save_ptr[save_index] = tmp_3; + save_ptr = out; + + index += blockDim.x; + save_index = index; + + // For layer_norm, reduce to calculate mean and std + sum_i += tmp_3; + square_sum_i += (tmp_3 * tmp_3); + } + + auto pair = BlockReduce(temp_storage) + .Reduce(PairForLayerNorm(sum_i, square_sum_i), + PairForLayerNormAddFunctor()); + + if (threadIdx.x == 0) { + T mean_i = static_cast(pair.first_ / N); + T variance_i = static_cast(pair.second_ / N - mean_i * mean_i); + shared_mem[BlockDim] = mean_i; + shared_mem[BlockDim + 1] = variance_i; + if (mean) { + mean[blockIdx.x] = mean_i; + } + if (variance) { + variance[blockIdx.x] = variance_i; + } + } + __syncthreads(); + T mean_i = shared_mem[BlockDim]; + T std_i = static_cast(RealSqrt(shared_mem[BlockDim + 1] + epsilon)); + + index = i * N + threadIdx.x; + // First BlockDim elements loading from shared memory. + save_index = threadIdx.x; + save_ptr = shared_mem; + + // For layer_norm, calculate out + for (int j = threadIdx.x; j < N; j += blockDim.x) { + T tmp_0 = (save_ptr[save_index] - mean_i) / std_i; + T tmp_1 = scale ? scale[j] * tmp_0 : tmp_0; + out[index] = bias_1 ? tmp_1 + bias_1[j] : tmp_1; + + save_ptr = out; + index += blockDim.x; + save_index = index; + } + } +} + +template +class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* w = ctx.Input("W"); + auto* out = ctx.Output("Out"); + + auto w_dims = w->dims(); + int N = w_dims[1]; + int K = w_dims[0]; + int M = framework::product(x->dims()) / K; + + const T* x_data = x->data(); + const T* w_data = w->data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + blas.GEMM(false, false, M, N, K, static_cast(1.0), x_data, K, w_data, N, + static_cast(0.0), out_data, N); + + auto* y = ctx.Input("Y"); + auto* bias_0 = ctx.Input("Bias0"); + auto* bias_1 = ctx.Input("Bias1"); + auto* scale = ctx.Input("Scale"); + + const T* y_data = y->data(); + const T* bias_0_data = bias_0 ? bias_0->data() : nullptr; + const T* bias_1_data = bias_1 ? bias_1->data() : nullptr; + const T* scale_data = scale ? scale->data() : nullptr; + + auto* mean = ctx.Output("Mean"); + auto* variance = ctx.Output("Variance"); + + T* mean_data = mean ? mean->mutable_data(ctx.GetPlace()) : nullptr; + T* variance_data = + variance ? variance->mutable_data(ctx.GetPlace()) : nullptr; + + bool with_relu = + (ctx.Attr("activation_type") == "relu") ? 
true : false; + float epsilon = ctx.Attr("epsilon"); + + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + if (with_relu) { + switch (platform::RoundToPowerOfTwo(N)) { + CUDA_LAUNCH_KERNEL_HELPER( + InplaceAddReluAddLayerNormKernel< + T, true, + kPowerOfTwoDim><<>>( + y_data, bias_0_data, bias_1_data, scale_data, out_data, + mean_data, variance_data, M, N, epsilon)); + } + } else { + switch (platform::RoundToPowerOfTwo(N)) { + CUDA_LAUNCH_KERNEL_HELPER( + InplaceAddReluAddLayerNormKernel< + T, false, + kPowerOfTwoDim><<>>( + y_data, bias_0_data, bias_1_data, scale_data, out_data, + mean_data, variance_data, M, N, epsilon)); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm, + ops::FusedFCElementwiseLayerNormOpKernel, + ops::FusedFCElementwiseLayerNormOpKernel); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_layer_norm_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_layer_norm_ngraph_op.py index a59eaade1bbb8f14765aea5d3c9b00b95b7078b1..ffdc64a23018521086219e3690a81e9b77aca3a7 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_layer_norm_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_layer_norm_ngraph_op.py @@ -15,16 +15,16 @@ from __future__ import print_function import unittest, sys sys.path.append("../") -from test_layer_norm_op import TestLayerNormdOp +from test_layer_norm_op import TestLayerNormOp -class TestLayerNormNGRAPHOp(TestLayerNormdOp): +class TestLayerNormNGRAPHOp(TestLayerNormOp): def setUp(self): super(TestLayerNormNGRAPHOp, self).setUp() self.use_cudnn = False -del TestLayerNormdOp +del TestLayerNormOp if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_fc_elementwise_layernorm_op.py b/python/paddle/fluid/tests/unittests/test_fused_fc_elementwise_layernorm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9604201e04e1dc0e176fdb899275a9cadc325ad1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_fc_elementwise_layernorm_op.py @@ -0,0 +1,82 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import core +from test_fc_op import fc_refer, MatrixGenerate +from test_layer_norm_op import _reference_layer_norm_naive + +np.random.random(123) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Paddle core is not compiled with CUDA") +class TestFusedFCElementwiseLayerNormOp(OpTest): + def config(self): + self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2) + self.y_shape = [1, 15] + self.begin_norm_axis = 1 + + def setUp(self): + self.op_type = "fused_fc_elementwise_layernorm" + self.config() + + # Attr of layer_norm + epsilon = 0.00001 + + # fc + fc_out = fc_refer(self.matrix, True, True) + # elementwise_add + y = np.random.random_sample(self.y_shape).astype(np.float32) + add_out = fc_out + y + # layer_norm + scale_shape = [np.prod(self.y_shape[self.begin_norm_axis:])] + scale = np.random.random_sample(scale_shape).astype(np.float32) + bias_1 = np.random.random_sample(scale_shape).astype(np.float32) + out, mean, variance = _reference_layer_norm_naive( + add_out, scale, bias_1, epsilon, self.begin_norm_axis) + + self.inputs = { + "X": self.matrix.input, + "W": self.matrix.weights, + "Bias0": self.matrix.bias, + "Y": y, + "Scale": scale, + "Bias1": bias_1 + } + self.attrs = { + "activation_type": "relu", + "epsilon": epsilon, + "begin_norm_axis": self.begin_norm_axis + } + self.outputs = {"Out": out, "Mean": mean, "Variance": variance} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=2e-3) + + +class TestFusedFCElementwiseLayerNormOp2(TestFusedFCElementwiseLayerNormOp): + def config(self): + self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1) + self.y_shape = [4, 6] + self.begin_norm_axis = 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index fdc5d3679e71036cf1e1d813e654815eb03dd45c..ff68599dce6bdb7ba7a6f35cc05f69ec8f543ab4 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -71,7 +71,7 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1): return grad_x, d_scale, d_bias -class TestLayerNormdOp(unittest.TestCase): +class TestLayerNormOp(unittest.TestCase): def setUp(self): self.use_cudnn = True
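
For reference, the end-to-end computation that the new fused_fc_elementwise_layernorm op is meant to reproduce (fc, optional relu, elementwise_add, then layer_norm, as described in the op's AddComment block and checked by the Python unit test above via fc_refer and _reference_layer_norm_naive) can be sketched in a few lines of NumPy. The helper name, the fixed begin_norm_axis = 1, and the shapes below are illustrative assumptions and are not part of this patch:

import numpy as np

def fused_fc_elementwise_layernorm_ref(x, w, bias0, y, scale, bias1,
                                       epsilon=1e-5, activation="relu"):
    # fc: flatten x to 2-D [M, K], multiply by w [K, N], add bias, apply relu
    x2d = x.reshape(x.shape[0], -1)
    fc_out = x2d.dot(w) + bias0
    if activation == "relu":
        fc_out = np.maximum(fc_out, 0.0)
    # elementwise_add with the second input Y
    add_out = fc_out + y
    # layer_norm over the trailing dimension (begin_norm_axis == 1 here)
    mean = add_out.mean(axis=1, keepdims=True)
    variance = add_out.var(axis=1, keepdims=True)
    normed = (add_out - mean) / np.sqrt(variance + epsilon)
    out = normed * scale + bias1
    return out, mean.ravel(), variance.ravel()

# Example with shapes mirroring the C++ pass tester:
x = np.random.rand(128, 3072).astype(np.float32)
w = np.random.rand(3072, 768).astype(np.float32)
bias0 = np.random.rand(768).astype(np.float32)
y = np.random.rand(128, 768).astype(np.float32)
scale = np.random.rand(768).astype(np.float32)
bias1 = np.random.rand(768).astype(np.float32)
out, mean, var = fused_fc_elementwise_layernorm_ref(x, w, bias0, y, scale, bias1)
print(out.shape, mean.shape, var.shape)  # (128, 768) (128,) (128,)

Because the fuse pass is appended to GpuPassStrategy in paddle_pass_builder.cc, eligible fc + elementwise_add + layer_norm subgraphs are rewritten into this single fused op automatically during GPU inference graph optimization; no model-side change is required.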