// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace inference {
namespace analysis {

using framework::ir::Node;

void analysis::DlnneSubgraphPass::InferShapeForDlnneMainGraph() const {
  // copy from paddle2onnx
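  // Ops in this set have no kernel, so InferShape is skipped for them in the
  // loop below.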
  static std::unordered_set<std::string> OP_WITHOUT_KERNEL_SET = {
      "feed",
      "fetch",
      "recurrent",
      "go",
      "rnn_memory_helper_grad",
      "conditional_block",
      "while",
      "send",
      "recv",
      "listen_and_serv",
      "fl_listen_and_serv",
      "ncclInit",
      "select",
      "checkpoint_notify",
      "gen_bkcl_id",
      "c_gen_bkcl_id",
      "gen_nccl_id",
      "c_gen_nccl_id",
      "c_comm_init",
      "c_sync_calc_stream",
      "c_sync_comm_stream",
      "queue_generator",
      "dequeue",
      "enqueue",
      "heter_listen_and_serv",
      "c_wait_comm",
      "c_wait_compute"};

  std::string bilinear_interp_v2_type = "bilinear_interp_v2";
  auto input_dict =
      Get<std::map<std::string, std::vector<int64_t>>>("input_shape_dict");

  framework::ProgramDesc *global_program =
      Get<framework::ProgramDesc *>("program");
  auto block = global_program->MutableBlock(framework::kRootBlockIndex);
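  // Override the shapes of the graph input vars with the user-provided
  // shapes.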
  for (auto kv : input_dict) {
    auto var = block->FindVar(kv.first);
    if (var != nullptr) {
      var->SetShape(kv.second);
    } else {
      VLOG(4) << "input_name:" << kv.first << " not find in all input vars";
    }
  }

  std::vector<framework::OpDesc *> all_ops = block->AllOps();

  for (size_t i = 0; i < block->OpSize(); i++) {
    // The output shape of bilinear_interp_v2 cannot be inferred from the
    // input shape alone; it also depends on the value of the input tensor,
    // so after calling OpDesc->InferShape the output shape of
    // bilinear_interp_v2 is still dynamic. Here we try to infer it by
    // tracing the shape->slice->bilinear_interp_v2 pattern.
    if (block->Op(i)->Type() == bilinear_interp_v2_type) {
      framework::VariableNameMap input_name_map = block->Op(i)->Inputs();
      std::vector<std::string> input_name_vec = input_name_map["OutSize"];
      PADDLE_ENFORCE_EQ(
          input_name_vec.size(),
          1,
          platform::errors::PreconditionNotMet(
              "The 'bilinear_interp_v2 op' input 'OutSize' size must be 1 "));

      // find shape->slice->bilinear_interp_v2 pattern
      int start_id = 0;
      int end_id = 0;
      std::vector<std::string> slice_input_name_vec;
      for (auto *i_op : all_ops) {
        if (i_op->HasOutput("Out")) {
          auto it = find(i_op->Output("Out").begin(),
                         i_op->Output("Out").end(),
                         input_name_vec[0]);
          if (it != i_op->Output("Out").end()) {
            slice_input_name_vec = i_op->Input("Input");
            PADDLE_ENFORCE_EQ(
                slice_input_name_vec.size(),
                1,
                platform::errors::PreconditionNotMet(
                    "The 'slice op' input 'Input' size must be 1 "));

            auto start_vec = i_op->GetAttrIfExists<std::vector<int>>("starts");
            start_id = start_vec[0];
            auto end_vec = i_op->GetAttrIfExists<std::vector<int>>("ends");
            end_id = end_vec[0];
            break;
          }
        }
      }

      std::vector<std::string> shape_input_name_vec;
      for (auto *i_op : all_ops) {
        if (i_op->HasOutput("Out")) {
          auto it = find(i_op->Output("Out").begin(),
                         i_op->Output("Out").end(),
                         slice_input_name_vec[0]);
          if (it != i_op->Output("Out").end()) {
            shape_input_name_vec = i_op->Input("Input");
            PADDLE_ENFORCE_EQ(
                shape_input_name_vec.size(),
                1,
                platform::errors::PreconditionNotMet(
                    "The 'shape op' input 'Input' size must be 1"));
            break;
          }
        }
      }
      auto target_var = block->FindVarRecursive(shape_input_name_vec[0]);
      std::vector<int64_t> target_shape = target_var->GetShape();
      size_t target_shape_len = target_shape.size();
      if (start_id < 0) {
        start_id = target_shape_len + start_id;
      } else if (start_id > static_cast<int>(target_shape_len)) {
        start_id = target_shape_len;
      }

      if (end_id < 0) {
        end_id = target_shape_len + end_id;
      } else if (end_id > static_cast<int>(target_shape_len)) {
        end_id = target_shape_len;
      }

      if (start_id < end_id) {
        std::vector<int64_t> OutSize_dims(target_shape.begin() + start_id,
                                          target_shape.begin() + end_id);

        framework::VariableNameMap output_name_map = block->Op(i)->Outputs();
        std::vector<std::string> output_name_vec = output_name_map["Out"];
        auto out_var = block->FindVarRecursive(output_name_vec[0]);
        PADDLE_ENFORCE_NOT_NULL(
            out_var,
            platform::errors::NotFound(
                "bilinear_interp_v2 op's output %s is not found in the block.",
                output_name_vec[0]));
        std::vector<int64_t> ori_shape = out_var->GetShape();
        std::string data_layout =
            block->Op(i)->GetAttrIfExists<std::string>("data_layout");
        size_t start_dim = 0;
        size_t end_dim = 0;

        if (data_layout == "NCHW") {
          start_dim = 2;
          end_dim = ori_shape.size();
        } else {
          start_dim = 1;
          end_dim = ori_shape.size() - 1;
        }
        for (size_t i_dim = start_dim; i_dim < end_dim; i_dim++) {
          ori_shape[i_dim] = OutSize_dims[i_dim - start_dim];
        }

        VLOG(4) << "Set bilinear_interp_v2 shape: " << ori_shape[2] << ", "
                << ori_shape[3];
        out_var->SetShape(ori_shape);
      }

    } else {
      if (OP_WITHOUT_KERNEL_SET.find(block->Op(i)->Type()) ==
          OP_WITHOUT_KERNEL_SET.end())
        block->Op(i)->InferShape(*block);
    }
  }
}

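// Returns false only when some checked dim of var_name is unknown (< 1);
// when use_static_batch is true, the leading (batch) dim is skipped.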
bool analysis::DlnneSubgraphPass::IsDynamicOp(std::string var_name,
                                              bool use_static_batch) const {
  framework::ProgramDesc *global_program =
      Get<framework::ProgramDesc *>("program");
  auto block = global_program->MutableBlock(framework::kRootBlockIndex);
  auto var = block->FindVar(var_name);

  if (var != nullptr) {
    std::vector<int64_t> var_shape = var->GetShape();
    size_t start_idx = use_static_batch ? 1 : 0;
    for (; start_idx < var_shape.size(); start_idx++) {
      if (var_shape[start_idx] < 1) {
        return false;
      }
    }
  }
  return true;
}

void analysis::DlnneSubgraphPass::ApplyImpl(framework::ir::Graph *graph) const {
  framework::ir::FusePassBase::Init("dlnne_subgraph_pass", graph);

  InferShapeForDlnneMainGraph();

  static std::unordered_set<std::string> teller_set{
      "nearest_interp_v2",
      "mul",
      "matmul",
      "matmul_v2",
      "flatten_contiguous_range",
      "conv2d",
      "pool2d",
      "relu",
      "softmax",
      "sigmoid",
      "softplus",
      "hard_swish",
      "hard_sigmoid",
      "depthwise_conv2d",
      "batch_norm",
      "exp",
      "concat",
      "clip",
      "cast",
      "tanh",
      "pad",
      "elementwise_add",
      "elementwise_mul",
      "elementwise_sub",
      "elementwise_div",
      "elementwise_pow",
      "dropout",
      // "deformable_conv",

      "prelu",
      "conv2d_transpose",
      "leaky_relu",
      "log",
      "fc",
      "shuffle_channel",
      "swish",
      "split",
      "instance_norm",
      "gelu",
      "layer_norm",
      "scale",
      "slice",
      "stack",
      "relu6",
      "reshape2",
      "transpose2",
      "concat",
      "slice",
      "fill_constant",
      "fill_constant_batch_size_like",
      "shape",
      "unsqueeze2",
      "pad3d",
      "squeeze2",
      "bilinear_interp_v2"
      // "yolo_box"
  };

  // ops whose outputs are special and need special processing
  static std::unordered_set<std::string> special_output_op_set{
      "transpose2",
      "fill_constant_batch_size_like",
      "flatten_contiguous_range",
      "batch_norm",
      "unsqueeze2",
  };

  // ops that can still be fused by dlnne_engine_op even when their shapes
  // are dynamic
  static std::unordered_set<std::string> dynamic_pass_op_set{
      "reshape2",
  };
  auto disable_nodes_by_outputs =
      Get<std::unordered_set<std::string>>("disable_nodes_by_outputs");
  bool use_static_batch = Get<bool>("use_static_batch");

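  // Accept a node only if its op type is in teller_set, its input/output
  // shapes are static (except for ops whitelisted above), and none of its
  // outputs are in disable_nodes_by_outputs.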
  auto teller = [&](const framework::ir::Node *node) {
    if (!node->IsOp() || !node->Op()) {
      return false;
    }
    if (teller_set.find(node->Op()->Type()) == teller_set.end()) {
      VLOG(4) << "don't support op:" << node->Op()->Type();
      return false;
    } else {
      bool flag = true;
      // check node output
      if (dynamic_pass_op_set.find(node->Op()->Type()) !=
          dynamic_pass_op_set.end()) {
        flag = true;
      } else if (special_output_op_set.find(node->Op()->Type()) ==
                 special_output_op_set.end()) {
        for (auto *x : node->outputs) {
          std::string var_name = x->Var()->Name();
          flag = IsDynamicOp(var_name, use_static_batch);
          if (!flag) break;
        }
      } else {
        std::string var_name = node->outputs[0]->Var()->Name();
        flag = IsDynamicOp(var_name, use_static_batch);
      }
      // check node input
      if (flag) {
        for (auto *x : node->inputs) {
          std::string var_name = x->Var()->Name();
          flag = IsDynamicOp(var_name, use_static_batch);
          if (!flag) break;
        }
      }
      if (!flag) {
        VLOG(4) << "don't support dynamic shape:" << node->Op()->Type();
      }
      bool flag2 = true;
      for (auto *x : node->outputs) {
        if (disable_nodes_by_outputs.find(x->Name()) !=
            disable_nodes_by_outputs.end()) {
          flag2 = false;
        }
      }
      if (!flag2) {
        VLOG(4) << "user don't use " << node->Name() << "...";
      }
      return flag && flag2;
    }
  };

  framework::ir::SubGraphFuser fuser(
      graph,
      teller,
      Get<int>("min_subgraph_size") /*min subgraph size*/,
      "dlnne_engine");
  fuser();

  std::vector<std::string> graph_param_names =
      ExtractParameters(graph->Nodes());
  // These parameters already exist in DLNNE and should not have another copy
  // in fluid.
  std::vector<std::string> repetitive_params;

  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
      CreateDlnneOp(node, graph, graph_param_names, &repetitive_params);

      std::unordered_set<const Node *> nodes2remove(
          framework::ir::Agent(node).subgraph()->begin(),
          framework::ir::Agent(node).subgraph()->end());
      framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
    }
  }

  std::unordered_set<const Node *> nodes2remove;
  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && framework::ir::Agent(node).deleted()) {
      nodes2remove.insert(node);
    }
  }
  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
}

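// The engine key is the hash of all input/output names concatenated with the
// predictor id; it identifies the generated engine.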
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                              const std::set<std::string> &engine_outputs,
                              const std::string &predictor_id) {
  std::string engine_hash_key = "";
  for (auto name : engine_inputs) {
    engine_hash_key += name;
  }
  for (auto name : engine_outputs) {
    engine_hash_key += name;
  }
  engine_hash_key += predictor_id;
  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
  return engine_key;
}
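
// Replace every occurrence of the single-character string raw in name with
// new_char (used to map '/' in var names to '.').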
std::string replace_name(std::string name,
                         const char *raw,
                         const char *new_char) {
  std::string r_name = name;
  size_t pos = r_name.find(raw);
  while (pos != std::string::npos) {
    r_name = r_name.replace(pos, 1, new_char);
    pos = r_name.find(raw);
  }
  return r_name;
}

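// When use_static_batch is set, rewrite the dynamic batch dim (-1) of every
// valid subgraph input to 1 and return the adjusted shapes keyed by var name.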
auto fix_batch_as_one(
    std::unordered_map<std::string, framework::VarDesc *> *name_var_desc,
    std::set<std::string> *valid_input_names,
    bool use_static_batch = false) {
  std::unordered_map<std::string, std::vector<int64_t>> name_var_shape;

  if (use_static_batch) {
    std::set<std::string> names;
    names.insert(valid_input_names->begin(), valid_input_names->end());

    for (auto name : names) {
      if (name_var_desc->find(name) != name_var_desc->end()) {
        auto var_desc = (*name_var_desc)[name];
        auto sp = var_desc->GetShape();
        if (sp[0] == -1) {
          sp[0] = 1;
          name_var_shape[name] = sp;
          std::stringstream sp_str;
          std::copy(sp.begin(),
                    sp.end(),
                    std::ostream_iterator<int64_t>(sp_str, ","));

          LOG(INFO)
              << "Warning: fixing the batch dim of var " << name
              << " to 1, shape is [" << sp_str.str() << "]. We assume the "
              << "first dim of a subgraph's inputs/outputs is the batch dim; "
              << "if it is not, we suggest using a fixed-shape model.";
        }
      }
    }
  }
  return name_var_shape;
}
/*
There are two ProgramDescs in this function: global_program is used to
generate the dlnne_engine op, and dump_program is used to dump the subgraph
that the dlnne_engine op later loads.
*/
void DlnneSubgraphPass::CreateDlnneOp(
    framework::ir::Node *node,
    framework::ir::Graph *graph,
    const std::vector<std::string> &graph_params,
    std::vector<std::string> *repetitive_params) const {
  auto *op_desc = node->Op();
  auto &subgraph = *framework::ir::Agent(node).subgraph();
  PADDLE_ENFORCE_EQ(subgraph.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "The subgraph should not be empty."));

  // A fake block desc.
  framework::proto::BlockDesc block_proto;
  framework::BlockDesc block_desc(nullptr, &block_proto);
  block_desc.Proto()->set_parent_idx(-1);
  block_desc.Proto()->set_idx(0);
  LOG(INFO) << "---  detect a sub-graph with " << subgraph.size() << " nodes";
  // for debug
  framework::ProgramDesc *global_program =
      Get<framework::ProgramDesc *>("program");
  const framework::BlockDesc &main_block =
      global_program->Block(framework::kRootBlockIndex);

  std::set<std::string> input_names;
  std::set<std::string> input_names_with_id;
  std::vector<std::string> params;
  std::set<std::string> valid_input_names;
  // If we delete the fluid copy of params shared by more than one op, there
  // will be problems, so we filter them out.

  // The node->inputs contains input tensors and parameters.
  for (auto *x : node->inputs) {
    input_names.insert(x->Name());
    input_names_with_id.insert(x->Name() + std::to_string(x->id()));
    if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
      params.push_back(x->Name());
    }
    if (std::find(graph_params.begin(), graph_params.end(), x->Name()) ==
        graph_params.end()) {
      valid_input_names.insert(x->Name());
    }
  }

  std::set<std::string> output_names;
  std::set<std::string> output_names_with_id;
  std::vector<int> origin_output_dims;
  std::set<std::string> valid_output_names;
  for (auto *x : node->outputs) {
    origin_output_dims.push_back(x->Var()->GetShape().size());
    output_names.insert(x->Name());
    output_names_with_id.insert(x->Name() + std::to_string(x->id()));
    if (std::find(graph_params.begin(), graph_params.end(), x->Name()) ==
        graph_params.end()) {
      valid_output_names.insert(x->Name());
    }
  }

  auto *child_block = global_program->AppendBlock(main_block);
  framework::ProgramDesc dump_program;
  auto *export_block = dump_program.MutableBlock(framework::kRootBlockIndex);

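  // Collect the VarDesc of every input/output of the subgraph nodes,
  // preferring the shape-inferred descs from the main block.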
  std::unordered_map<std::string, framework::VarDesc *> name_var_desc;
  for (auto *node : subgraph) {
    auto *op = block_desc.AppendOp();
    *op->Proto() = *node->Op()->Proto();
    auto *child_op = child_block->AppendOp();
    *child_op->Proto() = *node->Op()->Proto();
    // generate an op from the node and append it to the export block
    {
      auto *export_op = export_block->AppendOp();

      framework::OpDesc op_desc;
      op_desc.CopyFrom(*node->Op());

      for (auto argument_name : op_desc.InputArgumentNames()) {
        if (std::count(
                graph_params.begin(), graph_params.end(), argument_name) > 0) {
          op_desc.Rename(argument_name, replace_name(argument_name, "/", "."));
        }
      }
      for (auto argument_name : op_desc.OutputArgumentNames()) {
        if (std::count(
                graph_params.begin(), graph_params.end(), argument_name) > 0) {
          op_desc.Rename(argument_name, replace_name(argument_name, "/", "."));
        }
      }
      *export_op->Proto() = *op_desc.Proto();

      for (auto *x : node->inputs) {
        if (x->IsVar()) {
          auto var_desc_infer = main_block.FindVarRecursive(x->Name());
          if (var_desc_infer != nullptr) {
            name_var_desc[x->Name()] = var_desc_infer;
          } else {
            name_var_desc[x->Name()] = x->Var();
          }
        }
      }

      for (auto *x : node->outputs) {
        if (x->IsVar()) {
          auto var_desc_infer = main_block.FindVarRecursive(x->Name());
          if (var_desc_infer != nullptr) {
            name_var_desc[x->Name()] = var_desc_infer;
          } else {
            name_var_desc[x->Name()] = x->Var();
          }
        }
      }
    }
  }

  // start fixing batch as one
  bool use_static_batch = Get<bool>("use_static_batch");
  auto name_shape_table =
      fix_batch_as_one(&name_var_desc, &valid_input_names, use_static_batch);

  for (const auto &name_shape : name_shape_table) {
    VLOG(4) << "Fix batch shape as one var name: " << name_shape.first;
  }

  // Then, we will use the input_names_with_id and output_names_with_id to
  // generate the engine key.
  // So, we use set instead of unordered_set here to ensure that the engine
  // key is deterministic.
  auto engine_key = GenerateEngineKey(
      input_names_with_id, output_names_with_id, std::to_string(0));
  auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
  bool enable_int8 = false;
  if (precision_mode == AnalysisConfig::Precision::kInt8) {
    enable_int8 = true;
  }
  auto use_calib_mode = Get<bool>("use_calib_mode");

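  // int8 with calibration: if no calibration data has been generated yet,
  // switch to calibration mode and create the calibration directories.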
  std::string calibration_data_path = "./calibration/dlnne_calib_" + engine_key;
  bool calibration_mode = false;
  if (enable_int8 && use_calib_mode && !PathExists(calibration_data_path)) {
    calibration_mode = true;
    MKDIR("./calibration");
    MKDIR(calibration_data_path.c_str());
  }
  VLOG(4) << "calibration_mode: " << calibration_mode;
  std::stringstream ss;
  ss << "engine_key:" << engine_key << " outputs:[";
  for (auto name : valid_output_names) {
    ss << name << ",";
  }
  ss << "]";
  VLOG(4) << ss.str();

  // Set attrs
  op_desc->SetType("dlnne_engine");
  op_desc->SetInput("Xs",
                    std::vector<std::string>(valid_input_names.begin(),
                                             valid_input_names.end()));

  op_desc->SetOutput("Ys",
                     std::vector<std::string>(valid_output_names.begin(),
                                              valid_output_names.end()));
  op_desc->SetBlockAttr("sub_block", child_block);

  op_desc->SetAttr("parameters", params);
  op_desc->SetAttr("engine_key", engine_key);
  op_desc->SetAttr("max_batch_size", Get<int>("max_batch_size"));
  op_desc->SetAttr("use_static_batch", Get<bool>("use_static_batch"));
  op_desc->SetAttr("weight_share_mode", Get<std::string>("weight_share_mode"));
  op_desc->SetAttr("enable_int8", enable_int8);
  op_desc->SetAttr("use_calib_mode", use_calib_mode);
  op_desc->SetAttr("calibration_mode", calibration_mode);
  op_desc->SetAttr("calibration_data_path", calibration_data_path);

  std::string subgraph_root_path = "./dump/" + engine_key;
  op_desc->SetAttr("subgraph_root_path", subgraph_root_path);

  std::stringstream ins_stream;
  for (auto name : valid_input_names) {
    ins_stream << "," << name;
  }
  op_desc->SetAttr("valid_input_names", ins_stream.str().substr(1));

  std::stringstream outs_stream;
  for (auto name : valid_output_names) {
    outs_stream << "," << name;
  }
  op_desc->SetAttr("valid_output_names", outs_stream.str().substr(1));

  auto *scope = param_scope();
  {
    // add feed to subgraph:
    int input_idx = 0;
    for (auto input_name : valid_input_names) {
      auto *feed1 = export_block->AppendOp();
      feed1->SetType("feed");
      feed1->SetInput("X", {"feed"});
      feed1->SetOutput("Out", {input_name});
      feed1->SetAttr("col", input_idx);
      input_idx++;
    }
    // add fetch to subgraph:
    int output_idx = 0;
    for (auto output_name : valid_output_names) {
      auto *fetch1 = export_block->AppendOp();
      fetch1->SetType("fetch");
      fetch1->SetInput("X", {output_name});
      fetch1->SetOutput("Out", {"out"});
      fetch1->SetAttr("col", output_idx);
      output_idx++;
    }

    VLOG(4) << "name_var_desc size:" << name_var_desc.size();

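    // Copy every collected VarDesc into the export block; vars present in
    // the scope are parameters, so rename '/' to '.' and mark them
    // persistable.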
    for (auto &kv : name_var_desc) {
      auto *new_add_var1 = export_block->Proto()->add_vars();
      paddle::framework::VarDesc copy_var_desc(*(kv.second->Proto()));

      if (name_shape_table.find(kv.first) != name_shape_table.end()) {
        copy_var_desc.SetShape(name_shape_table[kv.first]);
      }
      *new_add_var1 = *(copy_var_desc.Proto());

      auto *variable_tmp1 = scope->FindVar(kv.first);
      if (variable_tmp1 != nullptr) {
        *new_add_var1->mutable_name() = replace_name(kv.first, "/", ".");
        new_add_var1->set_persistable(true);
      } else {
        new_add_var1->set_persistable(false);
      }
    }

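    // Serialize the dump program into the op's "subgraph" attribute. Unless
    // running in calibration mode, also dump the model and its params to
    // ./dump/<engine_key>/ for the dlnne_engine op to load.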
    std::string model_str;
    dump_program.Proto()->SerializeToString(&model_str);
    op_desc->SetAttr("subgraph", model_str);
    op_desc->Flush();

    if (calibration_mode) {
      return;
    }

    MKDIR("./dump");
    MKDIR(subgraph_root_path.c_str());
    std::ofstream m_stream;
    m_stream.open(subgraph_root_path + "/__model__", std::ios::out);

    for (auto param_name : params) {
      auto *var = scope->FindVar(param_name);
      if (var != nullptr) {
        auto *var_t = var->GetMutable<framework::LoDTensor>();
        std::ofstream p_stream;
        p_stream.open(
            subgraph_root_path + "/" + replace_name(param_name, "/", "."),
            std::ios::out);
        platform::DeviceContextPool &pool =
            platform::DeviceContextPool::Instance();
        auto &dev_ctx = *pool.Get(var_t->place());
        framework::SerializeToStream(p_stream, *var_t, dev_ctx);
        p_stream.close();
      }
    }

    m_stream << model_str;
    m_stream.close();
  }
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle

REGISTER_PASS(dlnne_subgraph_pass,
              paddle::inference::analysis::DlnneSubgraphPass);