// inference_process_pass.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/inference_process_pass.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

// Configures the IPU backend for inference from attributes stored on `graph`,
// then applies the graph-level passes followed by the compile passes.
//
// Graph attributes read here: kParamScopeAttr, "replica_num", "num_ipus",
// "micro_batch_size", "enable_pipelining", "batches_per_step" (only when
// pipelining is enabled), "enable_fp16", "enable_half_partial" (only when
// FP16 is enabled) and "available_memory_proportion".
//
// Raises InvalidArgument (via PADDLE_ENFORCE_*) when pipelining is enabled
// with batches_per_step < num_ipus, or when a feed/fetch op carries a "col"
// attribute outside the range of collected feed/fetch ops, or has no
// output/input variable to take the name from.
void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceProcessPass::ApplyImpl";

  // Get the ipu_backend instance
  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();

  // Set scope
  auto& scope = graph->Get<Scope>(kParamScopeAttr);
  ipu_backend->SetScope(scope);

  // Set ipu_strategy
  // The strategy object is function-static, so it outlives this call and is
  // shared by every invocation of this pass.
  // NOTE(review): options that are only written on the "enabled" branch below
  // (replication, batches_per_step, partialsTypeMatMuls) keep whatever value
  // an earlier invocation left in them — confirm this is intended.
  static std::shared_ptr<platform::ipu::IpuStrategy> ipu_strategy_instance_(
      new platform::ipu::IpuStrategy());
  ipu_strategy_instance_->is_training = false;

  // Set graph replication
  const auto replica_num = graph->Get<int>("replica_num");
  if (replica_num > 1) {
    ipu_strategy_instance_->popart_options.enableReplicatedGraphs = true;
    ipu_strategy_instance_->popart_options.replicatedGraphCount = replica_num;
  }

  // Set the num of IPUs
  const auto num_ipus = graph->Get<int>("num_ipus");

  // Set sharding: manual virtual graphs when more than one IPU is requested.
  if (num_ipus > 1) {
    ipu_strategy_instance_->need_avg_shard = true;
    ipu_strategy_instance_->popart_options.virtualGraphMode =
        popart::VirtualGraphMode::Manual;
  } else {
    ipu_strategy_instance_->need_avg_shard = false;
    ipu_strategy_instance_->popart_options.virtualGraphMode =
        popart::VirtualGraphMode::Off;
  }
  // total num IPUs = num_ipus * replica_num
  ipu_strategy_instance_->num_ipus = num_ipus * replica_num;

  // Set micro_batch_size for shape inference
  ipu_strategy_instance_->micro_batch_size =
      graph->Get<int>("micro_batch_size");

  // Set pipelining
  const auto enable_pipelining = graph->Get<bool>("enable_pipelining");
  ipu_strategy_instance_->popart_options.enablePipelining = enable_pipelining;
  if (enable_pipelining) {
    const auto batches_per_step = graph->Get<int>("batches_per_step");
    PADDLE_ENFORCE_GE(
        batches_per_step,
        num_ipus,
        platform::errors::InvalidArgument("Batches per step should be equal or "
                                          "greater than the number of IPUs"));
    ipu_strategy_instance_->batches_per_step = batches_per_step;
  }

  // Set FP16
  const auto enable_fp16 = graph->Get<bool>("enable_fp16");
  ipu_strategy_instance_->enable_fp16 = enable_fp16;
  if (enable_fp16 && graph->Get<bool>("enable_half_partial")) {
    ipu_strategy_instance_->popart_options.partialsTypeMatMuls = "half";
  }

  // Set available memory proportion for matmul/conv
  ipu_strategy_instance_->available_memory_proportion =
      graph->Get<float>("available_memory_proportion");

  // Set tiles_per_ipu for IPUMODEL
  ipu_strategy_instance_->tiles_per_ipu = 128;

  ipu_backend->SetIpuStrategy(*ipu_strategy_instance_);

  // Get feed_list and fetch_list.
  // First pass: count the feed/fetch ops so both lists can be sized up front.
  size_t num_feed_ops = 0;
  size_t num_fetch_ops = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    if (node->Name() == "feed") {
      ++num_feed_ops;
    } else if (node->Name() == "fetch") {
      ++num_fetch_ops;
    }
  }
  std::vector<std::string> feed_list(num_feed_ops);
  std::vector<std::string> fetch_list(num_fetch_ops);

  // Returns the slot of `list` selected by an op's "col" attribute, rejecting
  // indices that would fall outside the list instead of writing out of bounds.
  auto slot_for_col = [](std::vector<std::string>* list,
                         int col) -> std::string& {
    PADDLE_ENFORCE_GE(col,
                      0,
                      platform::errors::InvalidArgument(
                          "The col attribute of a feed/fetch op should be "
                          "non-negative, but received %d.",
                          col));
    PADDLE_ENFORCE_LT(static_cast<size_t>(col),
                      list->size(),
                      platform::errors::InvalidArgument(
                          "The col attribute of a feed/fetch op (%d) should be "
                          "less than the number of such ops (%d).",
                          col,
                          list->size()));
    return (*list)[static_cast<size_t>(col)];
  };

  // Second pass: place each variable name at the position given by "col".
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    if (node->Name() == "feed") {
      PADDLE_ENFORCE_GE(node->outputs.size(),
                        static_cast<size_t>(1),
                        platform::errors::InvalidArgument(
                            "A feed op should have at least one output."));
      const int col = BOOST_GET_CONST(int, node->Op()->GetAttr("col"));
      slot_for_col(&feed_list, col) = node->outputs[0]->Name();
    } else if (node->Name() == "fetch") {
      PADDLE_ENFORCE_GE(node->inputs.size(),
                        static_cast<size_t>(1),
                        platform::errors::InvalidArgument(
                            "A fetch op should have at least one input."));
      const int col = BOOST_GET_CONST(int, node->Op()->GetAttr("col"));
      slot_for_col(&fetch_list, col) = node->inputs[0]->Name();
    }
  }

  // Run passes
  const std::vector<std::string> graph_pass = {"forward_graph_extract_pass",
                                               "infer_shape_pass",
                                               "avg_shard_pass",
                                               "popart_canonicalization_pass",
                                               "inference_dtype_transfer_pass"};
  const std::vector<std::string> compile_pass = {"ipu_inplace_pass",
                                                 "ipu_graph_builder_pass",
                                                 "ipu_runtime_replacer_pass",
                                                 "inference_postprocess_pass"};

  // Graph-level passes; only infer_shape_pass is given the feed list.
  for (const auto& pass_name : graph_pass) {
    auto pass = PassRegistry::Instance().Get(pass_name);
    if (pass_name == "infer_shape_pass") {
      // Each pass receives its own heap-allocated copy of the list.
      pass->Set("feed_list", new std::vector<std::string>(feed_list));
    }
    pass->Apply(graph);
  }

  // Compile passes; every one is given both the feed and the fetch list.
  for (const auto& pass_name : compile_pass) {
    auto pass = PassRegistry::Instance().Get(pass_name);
    pass->Set("feed_list", new std::vector<std::string>(feed_list));
    pass->Set("fetch_list", new std::vector<std::string>(fetch_list));
    pass->Apply(graph);
  }

  VLOG(10) << "leave InferenceProcessPass::ApplyImpl";
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Register InferenceProcessPass under the name "inference_process_pass".
// NOTE(review): presumably this is the name PassRegistry::Instance().Get()
// resolves, as it does for the passes fetched by name in ApplyImpl — confirm
// against the REGISTER_PASS macro definition.
REGISTER_PASS(inference_process_pass,
              paddle::framework::ir::InferenceProcessPass);