/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h"

#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

// Forward declaration of the IR graph node type; ApplyImpl below stores
// `const Node *` pointers in its node-deletion set.
namespace paddle {
namespace framework {
namespace ir {

class Node;

}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

// Subgraph pattern: fc -> elementwise_add -> layer_norm, where the fc output
// is one addend of the elementwise_add and the sum is the layer_norm input.
struct FCElementwiseLayerNorm : public PatternBase {
  FCElementwiseLayerNorm(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "fc_elementwise_layernorm") {}

  // Builds the pattern rooted at `x` (the fc "Input" variable) and returns the
  // node for layer_norm's "Y" output.
  PDNode *operator()(PDNode *x);

  // Node name for the fused op.
  // NOTE(review): not referenced by operator() or the pass in this file.
  PATTERN_DECL_NODE(fused_fc_elementwise_layernorm);

  // fc: (x, fc_w, fc_bias) -> fc_out
  PATTERN_DECL_NODE(fc);
  PATTERN_DECL_NODE(fc_w);
  PATTERN_DECL_NODE(fc_bias);
  PATTERN_DECL_NODE(fc_out);

  // elementwise_add: (fc_out, elementwise_input) -> elementwise_out
  PATTERN_DECL_NODE(elementwise);
  PATTERN_DECL_NODE(elementwise_input);
  PATTERN_DECL_NODE(elementwise_out);

  // layer_norm: (elementwise_out, layer_norm_scale, layer_norm_bias)
  //             -> (layer_norm_out, layer_norm_mean, layer_norm_variance)
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};

// Wires up the three-op chain inside `pattern`:
//   fc(Input=x, W, Bias)            -> fc_out          (intermediate)
//   elementwise_add(fc_out, other)  -> elementwise_out (intermediate)
//   layer_norm(X, Bias, Scale)      -> Y, Mean, Variance
// The fc weight/bias and layer_norm bias/scale must be persistable variables.
// Returns the node matching layer_norm's "Y" output.
PDNode *FCElementwiseLayerNorm::operator()(PDNode *x) {
  // ---- Stage 1: fc ----
  x->assert_is_op_input("fc", "Input");
  auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
  auto *weight = pattern->NewNode(fc_w_repr())
                     ->AsInput()
                     ->assert_is_persistable_var()
                     ->assert_is_op_input("fc", "W");
  auto *fc_b = pattern->NewNode(fc_bias_repr())
                   ->AsInput()
                   ->assert_is_persistable_var()
                   ->assert_is_op_input("fc", "Bias");
  auto *fc_result =
      pattern->NewNode(fc_out_repr())->assert_is_op_output("fc");
  fc_op->LinksFrom({x, weight, fc_b}).LinksTo({fc_result});

  // ---- Stage 2: elementwise_add (fc_result is consumed, not exposed) ----
  fc_result->AsIntermediate()->assert_is_op_input("elementwise_add");
  auto *add_op =
      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
  auto *other_addend = pattern->NewNode(elementwise_input_repr())
                           ->assert_is_op_input("elementwise_add");
  auto *add_result = pattern->NewNode(elementwise_out_repr())
                         ->AsOutput()
                         ->assert_is_op_output("elementwise_add");
  add_op->LinksFrom({fc_result, other_addend}).LinksTo({add_result});

  // ---- Stage 3: layer_norm (add_result is consumed, not exposed) ----
  add_result->AsIntermediate()->assert_is_op_input("layer_norm");
  auto *ln_op = pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto *ln_bias = pattern->NewNode(layer_norm_bias_repr())
                      ->AsInput()
                      ->assert_is_persistable_var()
                      ->assert_is_op_input("layer_norm", "Bias");
  auto *ln_scale = pattern->NewNode(layer_norm_scale_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Scale");
  auto *ln_y = pattern->NewNode(layer_norm_out_repr())
                   ->AsOutput()
                   ->assert_is_op_output("layer_norm", "Y");
  auto *ln_mean = pattern->NewNode(layer_norm_mean_repr())
                      ->AsOutput()
                      ->assert_is_op_output("layer_norm", "Mean");
  auto *ln_var = pattern->NewNode(layer_norm_variance_repr())
                     ->AsOutput()
                     ->assert_is_op_output("layer_norm", "Variance");
  ln_op->LinksFrom({add_result, ln_bias, ln_scale})
      .LinksTo({ln_y, ln_mean, ln_var});

  return ln_y;
}

}  // namespace patterns

// Returns true when `x` and `y` are both non-empty and hold the same elements
// in the same order. Two empty vectors deliberately compare as NOT equal, so
// an unknown (empty) shape never passes a shape-equality check.
template <typename T>
static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) {
  const bool both_known = !x.empty() && !y.empty();
  // operator== compares sizes first, then elements pairwise.
  return both_known && x == y;
}

// Registers the OpCompat constraints that IsCompat() checks on every matched
// subgraph in ApplyImpl; a match whose ops violate any of them is left
// unfused.
FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() {
  // fc: all inputs/outputs must be tensors. activation_type is copied onto
  // the fused op, so only "relu" or no activation ("") is accepted.
  AddOpCompat(OpCompat("fc"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("W")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("in_num_col_dims")
      .IsNumGE(1)
      .End()
      .AddAttr("activation_type")
      .IsStringIn({"relu", ""})
      .End();

  // layer_norm: Mean/Variance may be absent; epsilon must lie in [0, 0.001]
  // and begin_norm_axis must be positive.
  AddOpCompat(OpCompat("layer_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("Mean")
      .IsOptional()
      .End()
      .AddOutput("Variance")
      .IsOptional()
      .End()
      .AddAttr("epsilon")
      .IsNumGE(0.0f)
      .IsNumLE(0.001f)
      .End()
      .AddAttr("begin_norm_axis")
      .IsNumGT(0)
      .End();

  // elementwise_add: axis restricted to -1 or 0.
  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsIntIn({-1, 0})
      .End();
}

// Replaces every matched fc -> elementwise_add -> layer_norm chain in `graph`
// with a single fused_fc_elementwise_layernorm op.
//
// A match is left unfused when any of the following holds:
//   * the matched ops fail the registered OpCompat constraints;
//   * fc_out or elementwise_out has more than one consumer;
//   * fc_out and the other elementwise_add input do not have identical,
//     non-empty shapes;
//   * layer_norm's begin_norm_axis is not smaller than fc_out's rank, or fc's
//     W has fewer than two dimensions (either would make the width check
//     below flatten at an invalid axis or index W's shape out of range);
//   * W's second dimension differs from the product of fc_out's dims from
//     begin_norm_axis onward.
//
// layer_norm's Mean/Variance are kept as outputs of the fused op only when
// some op consumes them; otherwise their var nodes are removed.
void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "Pointer to graph argument should not be NULL."));
  FusePassBase::Init("fc_elementwise_layernorm_fuse", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("fc_elementwise_layernorm_fuse/x")
                ->AsInput()
                ->assert_is_op_input("fc", "Input");
  patterns::FCElementwiseLayerNorm fused_pattern(
      gpd.mutable_pattern(), "fc_elementwise_layernorm_fuse");
  fused_pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    // count() is unsigned; "x not matched" is simply == 0.
    if (subgraph.count(x) == 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }

    VLOG(4) << "handle FCElementwiseLayerNorm fuse";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_input, elementwise_input,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

    // fc_out and elementwise_out are removed by the fusion, so they must not
    // be used as inputs of any other operator; otherwise we cannot fuse.
    if (fc_out->outputs.size() > 1U || elementwise_out->outputs.size() > 1U) {
      return;
    }

    // Both addends must have the same known (non-empty) shape.
    const auto layer_norm_x_dims = fc_out->Var()->GetShape();
    if (!IsEqual(layer_norm_x_dims, elementwise_input->Var()->GetShape())) {
      return;
    }

    const int begin_norm_axis =
        BOOST_GET_CONST(int, layer_norm->Op()->GetAttr("begin_norm_axis"));
    // begin_norm_axis must split the layer_norm input at a valid axis
    // (positivity is already enforced by the OpCompat check).
    if (begin_norm_axis >= static_cast<int>(layer_norm_x_dims.size())) {
      return;
    }

    // Guard the fc_w_dims[1] access below against an unknown/short shape.
    const auto fc_w_dims = fc_w->Var()->GetShape();
    if (fc_w_dims.size() < 2U) {
      return;
    }

    // The fc output width must equal the number of elements layer_norm
    // normalizes over.
    auto layer_norm_x_mat_dims = framework::flatten_to_2d(
        framework::make_ddim(layer_norm_x_dims), begin_norm_axis);
    if (fc_w_dims[1] != layer_norm_x_mat_dims[1]) {
      return;
    }

    std::unordered_set<const Node *> del_node_set;

    // Create a FusedFCElementwiseLayerNorm op node.
    OpDesc new_desc;
    new_desc.SetType("fused_fc_elementwise_layernorm");

    // inputs
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("W", {fc_w->Name()});
    new_desc.SetInput("Bias0", {fc_bias->Name()});
    new_desc.SetInput("Y", {elementwise_input->Name()});
    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
    new_desc.SetInput("Bias1", {layer_norm_bias->Name()});

    // outputs: Mean/Variance are only kept when something consumes them;
    // otherwise their dangling var nodes are deleted.
    new_desc.SetOutput("Out", {layer_norm_out->Name()});
    const bool lnm_has_output = layer_norm_mean->outputs.size() > 0U;
    if (lnm_has_output) {
      new_desc.SetOutput("Mean", {layer_norm_mean->Name()});
    } else {
      del_node_set.insert(layer_norm_mean);
    }
    const bool lnv_has_output = layer_norm_variance->outputs.size() > 0U;
    if (lnv_has_output) {
      new_desc.SetOutput("Variance", {layer_norm_variance->Name()});
    } else {
      del_node_set.insert(layer_norm_variance);
    }

    // attrs
    new_desc.SetAttr("x_num_col_dims", fc->Op()->GetAttr("in_num_col_dims"));
    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
    new_desc.SetAttr("begin_norm_axis",
                     layer_norm->Op()->GetAttr("begin_norm_axis"));
    new_desc.SetAttr("activation_type", fc->Op()->GetAttr("activation_type"));

    auto *fused_node = graph->CreateOpNode(&new_desc);  // OpDesc is copied.

    del_node_set.insert(fc);
    del_node_set.insert(elementwise);
    del_node_set.insert(layer_norm);
    del_node_set.insert(fc_out);
    del_node_set.insert(elementwise_out);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(fc_w, fused_node);
    IR_NODE_LINK_TO(fc_bias, fused_node);
    IR_NODE_LINK_TO(elementwise_input, fused_node);
    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, layer_norm_out);
    if (lnm_has_output) {
      IR_NODE_LINK_TO(fused_node, layer_norm_mean);
    }
    if (lnv_has_output) {
      IR_NODE_LINK_TO(fused_node, layer_norm_variance);
    }

    ++found_subgraph_count;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers the pass under the name fc_elementwise_layernorm_fuse_pass.
REGISTER_PASS(fc_elementwise_layernorm_fuse_pass,
              paddle::framework::ir::FCElementwiseLayerNormFusePass);

// Op versions this pass declares it can handle: fc version 0,
// elementwise_add version <= 1, layer_norm version 0.
REGISTER_PASS_CAPABILITY(fc_elementwise_layernorm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("fc", 0)
            .LE("elementwise_add", 1)
            .EQ("layer_norm", 0));