api_anakin_engine.cc 14.8 KB
Newer Older
1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Y
Yan Chunwei 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

T
Tao Luo 已提交
15 16 17
#include <map>
#include <string>
#include <utility>
L
Luo Tao 已提交
18
#include <vector>
Y
Yan Chunwei 已提交
19

20 21 22
#include "paddle/fluid/inference/api/api_anakin_engine.h"
#include "paddle/fluid/inference/api/paddle_api.h"

T
Tao Luo 已提交
23 24 25 26
#include "framework/core/net/net.h"
#include "framework/operators/ops.h"
#include "saber/funcs/timer.h"

Y
Yan Chunwei 已提交
27 28
namespace paddle {

Y
Yan Chunwei 已提交
29
using paddle::contrib::AnakinConfig;
30 31 32 33
// Out-of-class declarations of the static members shared by every predictor
// of the same <T, P, R> instantiation: mutex_ is locked while an executor is
// constructed (InitNet), and init_anakin_ is the once_flag that guards the
// one-time Anakin environment initialisation (InitEnv).
// NOTE(review): `extern` turns these into declarations rather than
// definitions; confirm the header provides the actual definitions, otherwise
// this relies on compiler leniency and may fail to link.
template <typename T, Precision P, OpRunType R>
extern std::mutex PaddleInferenceAnakinPredictor<T, P, R>::mutex_;
template <typename T, Precision P, OpRunType R>
extern std::once_flag PaddleInferenceAnakinPredictor<T, P, R>::init_anakin_;
Y
Yan Chunwei 已提交
34

35 36 37 38 39 40
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitEnv() {
  // Select the configured device for this predictor.
  anakin::TargetWrapper<T>::set_device(this->config_.device_id);
  // Anakin's environment is initialised at most once per <T, P, R>
  // instantiation: init_anakin_ is a static once_flag shared by all
  // predictors of that type.
  auto init_env = [this]() {
    anakin::Env<T>::env_init(this->config_.max_stream);
  };
  std::call_once(this->init_anakin_, init_env);
}
42 43 44 45
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitNet() {
  // Build a new executor from the current graph. Construction is serialised
  // through the shared static mutex_. Any executor previously stored in
  // executor_p_ is NOT released here; callers must free it first.
  std::lock_guard<std::mutex> guard(this->mutex_);
  this->executor_p_ = new anakin::Net<T, P, R>(*this->graph_p_, true);
}
47 48 49 50 51 52 53 54 55 56 57 58
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::SetContext() {
  // Create the device context from the configured device id and the
  // data/compute stream ids.
  const auto &cfg = this->config_;
  this->ctx_p_ = std::make_shared<anakin::Context<T>>(
      cfg.device_id, cfg.data_stream_id, cfg.compute_stream_id);
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitGraph() {
  // Load the model from config_.model_file and reshape every graph input to
  // the shape configured in config_.init_inputs_shape.
  // Aborts via LOG(FATAL) if the model cannot be loaded or if a graph input
  // has no configured shape.
  //
  // The graph precision follows the template parameter P (rather than a
  // hard-coded FP32), so the created graph has the same Graph<T, P> type that
  // ResetExecuter() accepts.
  this->graph_p_ = std::make_shared<anakin::graph::Graph<T, P>>();
  if (!(this->graph_p_->load(this->config_.model_file))) {
    LOG(FATAL) << "fail to load graph from " << this->config_.model_file;
  }
  const auto &init_shapes = this->config_.init_inputs_shape;
  auto inputs = this->graph_p_->get_ins();
  for (auto &input_str : inputs) {
    // One lookup per input instead of calling find() twice.
    auto it = init_shapes.find(input_str);
    if (it == init_shapes.end()) {
      LOG(FATAL) << input_str << " is not implemented.";
    }
    std::vector<int> shape = it->second;
    this->graph_p_->Reshape(input_str, shape);
  }
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::OptimizeGraph() {
  // Run Anakin's graph optimisation; a failure is fatal.
  const bool failed = !this->graph_p_->Optimize();
  if (failed) {
    LOG(FATAL) << "Graph optimization error.";
  }
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitPredictor() {
  // Full initialisation pipeline; the order is significant:
  //   device/environment -> context -> graph load & reshape
  //   -> graph optimisation -> executor built from the optimised graph.
  // (Presumably these steps are virtual so device-specific predictors can
  // override them; confirm against the header.)
  InitEnv();
  SetContext();
  InitGraph();
  OptimizeGraph();
  InitNet();
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::Predict() {
  // Run one forward pass, synchronising the device immediately before and
  // after it.
  using Target = anakin::TargetWrapper<T>;
  Target::device_sync();
  this->executor_p_->prediction();
  Target::device_sync();
}
template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor<T, P, R>::Run(
    const std::vector<PaddleTensor> &inputs,
    std::vector<PaddleTensor> *output_data, int batch_size) {
  // Runs inference and writes results into *output_data, whose tensors must
  // already carry the requested output names.
  //
  // With config_.re_allocable the network may grow its buffers, so the whole
  // request goes to RunImpl() at once. Otherwise a request whose batch
  // exceeds config_.init_batch_size is split into chunks of at most
  // init_batch_size sequences (LoD inputs only). Each chunk runs separately
  // and the per-chunk outputs are concatenated along dimension 0.
  //
  // batch_size == -1 asks for the batch size to be inferred from inputs[0]:
  // the number of sequences in its first LoD level, or shape[0] when it has
  // no LoD.
  //
  // Returns false for inputs that cannot be split (non-LoD, multi-level LoD,
  // LoD inconsistent with batch_size, missing shape), for a non-positive
  // init_batch_size, or when any chunk fails in RunImpl().
  if (this->config_.re_allocable) {
    return this->RunImpl(inputs, output_data);
  }

  // 1. Resolve the batch size.
  if (batch_size == -1) {
    if (inputs.empty()) {
      // Nothing to infer a batch size from, and nothing to split.
      return this->RunImpl(inputs, output_data);
    }
    const PaddleTensor &first = inputs[0];
    if (!first.lod.empty()) {
      if (first.lod[0].empty()) {
        return false;  // Malformed LoD: empty offset list.
      }
      batch_size = static_cast<int>(first.lod[0].size()) - 1;
    } else {
      if (first.shape.empty()) {
        return false;  // No leading dimension to read a batch size from.
      }
      batch_size = first.shape[0];
    }
  }

  // 2. If the data doesn't need to be split, run it directly.
  const int chunk_size = static_cast<int>(this->config_.init_batch_size);
  if (batch_size <= chunk_size) {
    return this->RunImpl(inputs, output_data);
  }
  if (chunk_size <= 0) {
    // The chunk loop below would never advance.
    return false;
  }
  const size_t num_seqs = static_cast<size_t>(batch_size);
  const size_t step = static_cast<size_t>(chunk_size);

  // 3. Validate the inputs and prepare per-chunk input/output templates.
  std::vector<PaddleTensor> cur_inputs;
  cur_inputs.reserve(inputs.size());
  for (const auto &input : inputs) {
    if (input.lod.empty()) {
      LOG(INFO) << "Non-lod mode to be implemented.";
      return false;
    }
    if (input.lod.size() != 1 || input.lod[0].size() != num_seqs + 1) {
      return false;
    }
    PaddleTensor tensor;
    tensor.name = input.name;
    tensor.dtype = PaddleDType::FLOAT32;
    cur_inputs.push_back(std::move(tensor));
  }
  std::vector<PaddleTensor> outputs_master;
  outputs_master.reserve(output_data->size());
  for (const auto &output : *output_data) {
    PaddleTensor tensor;
    tensor.name = output.name;
    outputs_master.push_back(std::move(tensor));
  }

  // 4. Chunked execution. outputs_vec is indexed [chunk][output].
  std::vector<std::vector<PaddleTensor>> outputs_vec;
  for (size_t start_batch = 0; start_batch < num_seqs;) {
    size_t end_batch = start_batch + step;
    if (end_batch > num_seqs) {
      end_batch = num_seqs;
    }
    auto cur_outputs = outputs_master;
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto &seq_offsets = inputs[i].lod[0];
      const size_t start = seq_offsets[start_batch];
      const size_t end = seq_offsets[end_batch];
      // Re-base the chunk's LoD offsets so they start at 0.
      std::vector<size_t> offsets;
      offsets.reserve(end_batch - start_batch + 1);
      for (size_t j = start_batch; j <= end_batch; ++j) {
        offsets.push_back(seq_offsets[j] - start);
      }
      // Non-owning view into the caller's input buffer for this chunk.
      auto *mem_start = static_cast<float *>(inputs[i].data.data()) + start;
      cur_inputs[i].data =
          PaddleBuf(mem_start, (end - start) * sizeof(float));
      cur_inputs[i].lod = std::vector<std::vector<size_t>>({offsets});
      cur_inputs[i].shape =
          std::vector<int>({static_cast<int>(end - start), 1, 1, 1});
    }
    if (!this->RunImpl(cur_inputs, &cur_outputs)) {
      return false;
    }
    outputs_vec.push_back(std::move(cur_outputs));
    start_batch = end_batch;
  }

  // 5. Concatenate the chunk results along dim 0 into contiguous memory.
  // The element count is the sum over chunks (not batch_size times the row
  // size), so outputs whose dim 0 differs from the chunk's sequence count
  // are neither truncated nor written past the end of the buffer.
  auto count = [](const std::vector<int> &dims) {
    size_t cnt = 1;
    for (int d : dims) {
      cnt *= static_cast<size_t>(d);
    }
    return cnt;
  };
  for (size_t i = 0; i < output_data->size(); ++i) {
    std::vector<int> shape = outputs_vec[0][i].shape;
    size_t total_cnt = 0;
    int total_rows = 0;
    for (const auto &single_out : outputs_vec) {
      total_cnt += count(single_out[i].shape);
      if (!single_out[i].shape.empty()) {
        total_rows += single_out[i].shape[0];
      }
    }
    if (!shape.empty()) {
      shape[0] = total_rows;
    }
    PaddleTensor &dst = (*output_data)[i];
    dst.shape = std::move(shape);
    dst.data.Resize(total_cnt * sizeof(float));
    float *addr = static_cast<float *>(dst.data.data());
    for (const auto &single_out : outputs_vec) {
      const size_t cnt = count(single_out[i].shape);
      memcpy(addr, single_out[i].data.data(), cnt * sizeof(float));
      addr += cnt;
    }
  }
  return true;
}
188 189
template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
    const std::vector<PaddleTensor> &inputs,
    std::vector<PaddleTensor> *output_data) {
  // Copy `inputs` into the executor, run one forward pass, and copy every
  // output named in *output_data back out (resizing its buffer if needed).
  //
  // The following are fatal (LOG(FATAL)):
  //   - non-FLOAT32 inputs;
  //   - a rank mismatch with the net;
  //   - an input larger than the net's buffer when re-allocation is disabled;
  //   - LoD with more than one level;
  //   - an empty *output_data.
  // Otherwise returns true.
  //
  // Pass 1 validates every input. When re-allocation is allowed, it reshapes
  // the graph for each input that no longer fits, and the executor is then
  // rebuilt ONCE. This must happen before any data is copied: rebuilding
  // after some inputs were already copied into the old executor would
  // silently discard them.
  bool need_rebuild = false;
  for (const auto &input : inputs) {
    if (input.dtype != PaddleDType::FLOAT32) {
      LOG(FATAL) << "Only support float type inputs. " << input.name
                 << "'s type is not float";
    }
    auto net_shape = this->executor_p_->get_in(input.name)->shape();
    if (net_shape.size() != input.shape.size()) {
      LOG(FATAL) << " input  " << input.name
                 << "'s shape size should be equal to that of net";
    }
    int sum = 1;
    for (int n : input.shape) {
      sum *= n;
    }
    if (sum > net_shape.count()) {
      if (this->config_.re_allocable) {
        this->graph_p_->Reshape(input.name, input.shape);
        need_rebuild = true;
      } else {
        LOG(FATAL)
            << "Run failed because Anakin was expected not to reallocate "
               "memory.";
      }
    }
  }
  if (need_rebuild) {
    delete this->executor_p_;
    this->InitNet();
  }

  // Pass 2: copy every input (and its LoD, if any) into the executor.
  for (const auto &input : inputs) {
    auto d_tensor_p = this->executor_p_->get_in(input.name);
    std::vector<int> tmp_shape(input.shape.begin(), input.shape.end());
    auto *data = static_cast<float *>(input.data.data());
    anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
        h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
                 tmp_shape);
    d_tensor_p->reshape(tmp_shape);

    if (!input.lod.empty()) {
      if (input.lod.size() > 1) {
        LOG(FATAL) << " input lod first dim should <=1, but you set "
                   << input.lod.size();
      }
      std::vector<int> lod(input.lod[0].begin(), input.lod[0].end());
      std::vector<std::vector<int>> offset({lod});
      d_tensor_p->set_seq_offset(offset);
      VLOG(3) << "offset.size(): " << offset[0].size();
      for (int off : offset[0]) {
        VLOG(3) << off;
      }
    }
    d_tensor_p->copy_from(h_tensor);
  }

  this->Predict();

  if (output_data->empty()) {
    LOG(FATAL) << "At least one output should be set with tensors' names.";
  }
  for (auto &output : *output_data) {
    auto *d_tensor_p = this->executor_p_->get_out(output.name);
    output.shape = d_tensor_p->valid_shape();
    // Grow (never shrink) the caller's buffer to hold the valid region.
    if (output.data.length() < d_tensor_p->valid_size() * sizeof(float)) {
      output.data.Resize(d_tensor_p->valid_size() * sizeof(float));
    }
    auto *data = static_cast<float *>(output.data.data());
    anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
        h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
                 d_tensor_p->valid_shape());
    h_tensor.copy_from(*d_tensor_p);
  }
  return true;
}
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor<T, P, R>::ResetConfig(
    const AnakinConfig &config) {
  // Overwrite the stored configuration. Nothing (graph, context, executor)
  // is rebuilt here. Always reports success.
  config_ = config;
  return true;
}
template <typename T, Precision P, OpRunType R>
anakin::Net<T, P, R> &PaddleInferenceAnakinPredictor<T, P, R>::ResetExecuter(
    std::shared_ptr<anakin::graph::Graph<T, P>> graph_p) {
  // Adopt `graph_p` (shared with whoever passed it, e.g. the predictor being
  // cloned), build a fresh context and executor for it, and return the
  // executor.
  this->graph_p_ = std::move(graph_p);
  // Reuse SetContext() rather than duplicating it, so context setup
  // overridden by device-specific predictors is applied here as well.
  this->SetContext();
  // InitNet() overwrites executor_p_ without freeing it, so release any
  // executor built earlier (delete on nullptr is a no-op, matching the
  // destructor's unconditional delete).
  delete this->executor_p_;
  this->executor_p_ = nullptr;
  this->InitNet();
  return *this->executor_p_;
}
// the cloned new Predictor of anakin share the same net weights from original
// Predictor
278
template <typename T, Precision P, OpRunType R>
C
cuichaowen 已提交
279
std::unique_ptr<PaddlePredictor>
280
PaddleInferenceAnakinPredictor<T, P, R>::Clone() {
C
cuichaowen 已提交
281
  VLOG(3) << "Anakin Predictor::clone";
C
cuichaowen 已提交
282
  std::unique_ptr<PaddlePredictor> cls(
283
      new PaddleInferenceAnakinPredictor<T, P, R>());
C
cuichaowen 已提交
284 285
  // construct executer from other graph
  auto anakin_predictor_p =
286
      dynamic_cast<PaddleInferenceAnakinPredictor<T, P, R> *>(cls.get());
C
cuichaowen 已提交
287
  if (!anakin_predictor_p) {
288
    LOG(FATAL) << "fail to call Init";
C
cuichaowen 已提交
289
  }
290 291 292 293
  anakin_predictor_p->ResetConfig(this->config_);
  anakin_predictor_p->ResetExecuter(this->graph_p_);
  return cls;
}
C
cuichaowen 已提交
294

295 296 297 298 299 300 301 302
#ifdef ANAKIN_MLU_PLACE
// MLU predictor overrides:
//   - the context also carries the model-parallel and op-fusion settings;
//   - graph optimisation, executor construction and prediction use Anakin's
//     fusion_* entry points.
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::SetContext() {
  const auto &cfg = this->config_;
  auto ctx = std::make_shared<anakin::Context<anakin::MLU>>(
      cfg.device_id, cfg.data_stream_id, cfg.compute_stream_id);
  ctx->set_model_parallel(cfg.model_parallel);
  ctx->set_fusion(cfg.op_fuse);
  this->ctx_p_ = std::move(ctx);
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::OptimizeGraph() {
  // Fusion-aware optimisation; a failure is fatal.
  const bool failed = !this->graph_p_->fusion_optimize(this->config_.op_fuse);
  if (failed) {
    LOG(FATAL) << "Graph optimization error.";
  }
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::InitNet() {
  // Construction is serialised through the shared mutex_. As in the base
  // class, any previously stored executor is not released here.
  std::lock_guard<std::mutex> guard(this->mutex_);
  auto *net = new anakin::Net<anakin::MLU, P, R>();
  this->executor_p_ = net;
  net->fusion_init(*this->graph_p_, this->ctx_p_, true);
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
  // Fused forward pass, with a device sync immediately before and after.
  using Target = anakin::TargetWrapper<anakin::MLU>;
  Target::device_sync();
  this->executor_p_->fusion_prediction();
  Target::device_sync();
}
#endif
Y
Yan Chunwei 已提交
323

324
#ifdef PADDLE_WITH_CUDA
// NVIDIA GPU target, FP32, asynchronous op execution.
template class PaddleInferenceAnakinPredictor<anakin::NV,
                                              anakin::Precision::FP32,
                                              ::anakin::OpRunType::ASYNC>;
#endif
#ifdef ANAKIN_X86_PLACE
// X86 CPU target, FP32, asynchronous op execution.
template class PaddleInferenceAnakinPredictor<anakin::X86,
                                              anakin::Precision::FP32,
                                              ::anakin::OpRunType::ASYNC>;
#endif
#ifdef ANAKIN_MLU_PLACE
// MLU target, FP32, synchronous op execution.
template class PaddleInferenceAnakinMLUPredictor<anakin::Precision::FP32,
                                                 ::anakin::OpRunType::SYNC>;
#endif
C
cuichaowen 已提交
336

Y
Yan Chunwei 已提交
337 338
// A factory to help create difference predictor.
template <>
Y
Yan Chunwei 已提交
339 340 341
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
    const contrib::AnakinConfig &config) {
342
#ifdef PADDLE_WITH_CUDA
343 344 345 346 347
  if (config.target_type == contrib::AnakinConfig::NVGPU) {
    return std::unique_ptr<PaddlePredictor>(
        new PaddleInferenceAnakinPredictor<anakin::NV, anakin::Precision::FP32,
                                           ::anakin::OpRunType::ASYNC>(config));
  }
348
#endif
349 350 351 352 353
#ifdef ANAKIN_X86_PLACE
  if (config.target_type == contrib::AnakinConfig::X86) {
    return std::unique_ptr<PaddlePredictor>(
        new PaddleInferenceAnakinPredictor<anakin::X86, anakin::Precision::FP32,
                                           ::anakin::OpRunType::ASYNC>(config));
C
cuichaowen 已提交
354
  }
355 356 357 358 359 360 361 362 363 364 365
#endif
#ifdef ANAKIN_MLU_PLACE
  if (config.target_type == contrib::AnakinConfig::MLU) {
    return std::unique_ptr<PaddlePredictor>(
        new PaddleInferenceAnakinMLUPredictor<anakin::Precision::FP32,
                                              ::anakin::OpRunType::SYNC>(
            config));
  }
#endif
  LOG(FATAL) << "Anakin Predictor create on unknown platform.";
  return nullptr;
T
Tao Luo 已提交
366
}
367 368
template <typename T, Precision P, OpRunType R>
void DisplayOpTimer(anakin::Net<T, P, R> *net_executor, int epoch) {
  // When built with PADDLE_ANAKIN_ENABLE_OP_TIMER, log each op's recorded
  // time divided by `epoch`, then the per-op_param totals divided by `epoch`.
  // Compiles to a no-op otherwise.
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
  // The destructor passes executor_p_, which may be null (never built or
  // already released); there is nothing to report then.
  if (net_executor == nullptr) {
    return;
  }
  std::vector<float> op_time = net_executor->get_op_time();
  auto exec_funcs = net_executor->get_exec_funcs();
  auto op_param = net_executor->get_op_param();
  for (size_t i = 0; i < op_time.size(); ++i) {
    LOG(INFO) << "name: " << exec_funcs[i].name
              << " op_type: " << exec_funcs[i].op_name
              << " op_param: " << op_param[i] << " time " << op_time[i] / epoch;
  }
  // Accumulate time per op_param; operator[] value-initialises new entries
  // to 0, so a single lookup suffices.
  std::map<std::string, float> op_map;
  for (size_t i = 0; i < op_time.size(); ++i) {
    op_map[op_param[i]] += op_time[i];
  }
  for (const auto &entry : op_map) {
    LOG(INFO) << entry.first << "  " << (entry.second) / epoch << " ms";
  }
#endif
}
template <typename T, Precision P, OpRunType R>
PaddleInferenceAnakinPredictor<T, P, R>::~PaddleInferenceAnakinPredictor() {
  // Report per-op timings (only when op timing is compiled in), then free
  // the executor held through the raw executor_p_ pointer.
  auto *executor = this->executor_p_;
  DisplayOpTimer<T, P, R>(executor, this->config_.init_batch_size);
  this->executor_p_ = nullptr;
  delete executor;
}
Y
Yan Chunwei 已提交
397 398

}  // namespace paddle