// File: dlnne_engine_op.h (26.0 KB). Most lines were last committed by
// denglin-github.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cuda.h>          // NOTLINT
#include <cuda_runtime.h>  // NOTLINT
#include <dlnne.h>         // NOTLINT

#include <assert.h>

#include <chrono>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_device_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/ddim.h"

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"

namespace dl {
namespace nne {
class Builder;
class Engine;
class Network;
class Parser;
class ExecutionContext;
D
denglin-github 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80

inline unsigned int GetElementSize(DataType type) {
  switch (type) {
    case DataType::kINT64:
    case DataType::kUINT64:
    case DataType::kFLOAT64:
      return 8;
    case DataType::kINT32:
    case DataType::kUINT32:
    case DataType::kFLOAT32:
      return 4;
    case DataType::kINT16:
    case DataType::kUINT16:
    case DataType::kFLOAT16:
      return 2;
    case DataType::kINT8:
    case DataType::kUINT8:
    case DataType::kBOOL:
      return 1;
    case DataType::kUNKNOWN_TYPE:
      return 0;
  }
  return 0;
}

D
denglin-github 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
}  // namespace nne
}  // namespace dl

namespace paddle {
namespace inference {
// Deleter for DLNNE objects, which are released through their own Destroy()
// method instead of `delete`. Suitable as the deleter of a std::unique_ptr.
class NneDeleter {
 public:
  NneDeleter() {}

  template <typename T>
  inline void operator()(T *ptr) const {
    if (ptr != nullptr) {
      ptr->Destroy();
    }
  }
};

// Copies `total_bytes` bytes from `src_ptr` to `dst_ptr`. Defined out of line;
// presumably a device-to-host copy, as the name says.
// NOTE(review): the byte count is an `int`, so buffers of 2 GiB or more cannot
// be copied in one call.
void CopyTensorDeviceToCpu(void *dst_ptr, void *src_ptr, int total_bytes);

// Copies `total_bytes` bytes from `src_ptr` to `dst_ptr`. Defined out of line;
// presumably a host-to-device copy, as the name says. Same `int` size limit as
// CopyTensorDeviceToCpu.
void CopyTensorCpuToDevice(void *dst_ptr, void *src_ptr, int total_bytes);

// Returns a textual name for a Paddle data type (defined out of line; used when
// writing calibration configs).
std::string ConvertType(paddle::experimental::DataType type);

// Returns the per-element byte size of a Paddle data type (defined out of
// line).
int GetDataByte(paddle::experimental::DataType type);

// Returns a random key used to name calibration data files (defined out of
// line).
std::string GenerateRandomKey();

// Exports the subgraph found under `subgraph_root_path` to the ONNX file
// `onnx_file_name` (defined out of line).
void ConvertPaddle2Onnx(std::string onnx_file_name,
                        std::string subgraph_root_path);

// Quantizes an ONNX model with the DLNNE tooling, using the calibration data in
// `dataset_path` and the dataset plugin script `dataset_plugin_path` (defined
// out of line).
void QuantizeOnnx(std::string onnx_file_name,
                  std::string rlym_file_name,
                  std::string quantized_rlym_file_name,
                  std::string dataset_path,
                  std::string dataset_plugin_path);

// Maps a DLNNE data type to the equivalent Paddle data type.
// Throws InvalidArgument for DLNNE types that have no Paddle equivalent.
// `inline` rather than `static` so every translation unit that includes this
// header shares one definition.
inline paddle::experimental::DataType DLNNE2FluidDataType(
    dl::nne::DataType type) {
  switch (type) {
    case dl::nne::DataType::kFLOAT32:
      return paddle::experimental::DataType::FLOAT32;
    case dl::nne::DataType::kFLOAT64:
      return paddle::experimental::DataType::FLOAT64;
    case dl::nne::DataType::kFLOAT16:
      return paddle::experimental::DataType::FLOAT16;
    case dl::nne::DataType::kINT64:
      return paddle::experimental::DataType::INT64;
    case dl::nne::DataType::kINT32:
      return paddle::experimental::DataType::INT32;
    case dl::nne::DataType::kINT16:
      return paddle::experimental::DataType::INT16;
    case dl::nne::DataType::kINT8:
      return paddle::experimental::DataType::INT8;
    case dl::nne::DataType::kUINT64:
      return paddle::experimental::DataType::UINT64;
    case dl::nne::DataType::kUINT32:
      return paddle::experimental::DataType::UINT32;
    case dl::nne::DataType::kUINT16:
      return paddle::experimental::DataType::UINT16;
    case dl::nne::DataType::kUINT8:
      return paddle::experimental::DataType::UINT8;
    case dl::nne::DataType::kBOOL:
      return paddle::experimental::DataType::BOOL;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported DLNNE data type %d: it cannot be converted to a "
          "Paddle data type.",
          static_cast<int>(type)));
      // Unreachable; keeps compilers that cannot see through PADDLE_THROW
      // from warning about a missing return.
      return paddle::experimental::DataType::FLOAT32;
  }
}

}  // namespace inference
}  // namespace paddle

namespace paddle {

namespace operators {

D
denglin-github 已提交
148 149
std::mutex static dlnne_create_lock;

D
denglin-github 已提交
150 151 152 153 154
class DlnneEngineOp : public framework::OperatorBase {
 private:
  std::vector<std::string> input_names_;
  std::unordered_set<std::string> param_names_;
  std::string engine_key_;
D
denglin-github 已提交
155 156 157 158 159 160 161 162 163
  bool use_static_batch_;
  bool calibration_mode_;
  std::string calibration_data_path_;
  std::string subgraph_root_path_;
  bool enable_int8_;
  bool use_calib_mode_;

  std::string weight_share_mode_;
  int max_batch_size_;
D
denglin-github 已提交
164 165
  int num_inputs;
  int num_outputs;
D
denglin-github 已提交
166 167
  // std::vector<std::string> output_names;
  // std::vector<std::string> input_names;
D
denglin-github 已提交
168 169 170 171 172 173 174 175 176 177

  dl::nne::Builder *builder;
  dl::nne::Parser *parser;
  dl::nne::Network *network;
  dl::nne::ExecutionContext *context;
  dl::nne::Engine *engine;

  unsigned int engine_input_size;
  std::vector<int> InputIndexToBindIndex_;

D
denglin-github 已提交
178 179 180 181
  char *dump_flag_;
  char *dlnne_log_flag_;
  char *dl_sdk_dir_;

D
denglin-github 已提交
182 183 184 185 186 187 188 189
 public:
  DlnneEngineOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {
    input_names_ = Inputs("Xs");
    engine_key_ = Attr<std::string>("engine_key");
D
denglin-github 已提交
190 191 192 193 194 195 196 197 198 199 200 201 202 203
    use_static_batch_ = Attr<bool>("use_static_batch");
    max_batch_size_ = Attr<int32_t>("max_batch_size");
    weight_share_mode_ = Attr<std::string>("weight_share_mode");
    calibration_mode_ = Attr<bool>("calibration_mode");
    calibration_data_path_ = Attr<std::string>("calibration_data_path");
    subgraph_root_path_ = Attr<std::string>("subgraph_root_path");
    enable_int8_ = Attr<bool>("enable_int8");
    use_calib_mode_ = Attr<bool>("use_calib_mode");

    // dump input/output buffer of dlnne engine
    dump_flag_ = getenv("PADDLE_DUMP_DLNNE_BUFFER");
    dlnne_log_flag_ = getenv("PADDLE_DLNNE_LOG");
    dl_sdk_dir_ = getenv("DL_SDK_DIR");

D
denglin-github 已提交
204 205 206 207 208
    auto params = Attr<std::vector<std::string>>("parameters");
    for (const auto &param : params) {
      param_names_.insert(param);
    }

D
denglin-github 已提交
209 210 211 212
    std::vector<std::string> XsMap;
    num_inputs = Inputs("Xs").size();
    std::string valid_input_name_str = Attr<std::string>("valid_input_names");

D
denglin-github 已提交
213
    for (const auto &x : Inputs("Xs")) {
D
denglin-github 已提交
214 215 216 217 218
      // input_names.push_back(x);
      XsMap.push_back(
          valid_input_name_str.substr(0, valid_input_name_str.find(",")));
      valid_input_name_str =
          valid_input_name_str.substr(valid_input_name_str.find(",") + 1);
D
denglin-github 已提交
219
    }
D
denglin-github 已提交
220
    std::vector<std::string> YsMap;
D
denglin-github 已提交
221 222

    num_outputs = Outputs("Ys").size();
D
denglin-github 已提交
223
    std::string valid_output_name_str = Attr<std::string>("valid_output_names");
D
denglin-github 已提交
224
    for (const auto &y : Outputs("Ys")) {
D
denglin-github 已提交
225 226 227 228 229
      // output_names.push_back(y);
      YsMap.push_back(
          valid_output_name_str.substr(0, valid_output_name_str.find(",")));
      valid_output_name_str =
          valid_output_name_str.substr(valid_output_name_str.find(",") + 1);
D
denglin-github 已提交
230 231
    }

D
denglin-github 已提交
232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
    // TODO(pei.jiang): add dlnne_engine manager to manage dlnne_engine
    if (!calibration_mode_) {
      std::map<std::string, dl::nne::WeightShareMode> weight_share_map;
      weight_share_map.insert(
          std::make_pair("0", dl::nne::WeightShareMode::kSingle));
      weight_share_map.insert(
          std::make_pair("1", dl::nne::WeightShareMode::kSingle));
      weight_share_map.insert(
          std::make_pair("2", dl::nne::WeightShareMode::kSingle));
      weight_share_map.insert(
          std::make_pair("3", dl::nne::WeightShareMode::kSingle));
      weight_share_map.insert(
          std::make_pair("01", dl::nne::WeightShareMode::kShare2));
      weight_share_map.insert(
          std::make_pair("23", dl::nne::WeightShareMode::kShare2));
      weight_share_map.insert(
          std::make_pair("0123", dl::nne::WeightShareMode::kShare4));

      std::map<std::string, dl::nne::ClusterConfig> cluster_config_map;
      cluster_config_map.insert(
          std::make_pair("0", dl::nne::ClusterConfig::kCluster0));
      cluster_config_map.insert(
          std::make_pair("1", dl::nne::ClusterConfig::kCluster1));
      cluster_config_map.insert(
          std::make_pair("2", dl::nne::ClusterConfig::kCluster2));
      cluster_config_map.insert(
          std::make_pair("3", dl::nne::ClusterConfig::kCluster3));
      cluster_config_map.insert(
          std::make_pair("01", dl::nne::ClusterConfig::kCluster01));
      cluster_config_map.insert(
          std::make_pair("23", dl::nne::ClusterConfig::kCluster23));
      cluster_config_map.insert(
          std::make_pair("0123", dl::nne::ClusterConfig::kCluster0123));

      dl::nne::WeightShareMode mode = weight_share_map[weight_share_mode_];
      dl::nne::ClusterConfig cluster_config =
          cluster_config_map[weight_share_mode_];
      if (dlnne_log_flag_) {
        LOG(INFO) << "weight_share_mode: " << mode
                  << " cluster_config: " << cluster_config;
      }
D
denglin-github 已提交
273

D
denglin-github 已提交
274 275 276 277 278 279
      std::string onnx_file_name =
          subgraph_root_path_ + "/" + engine_key_ + ".onnx";
      inference::ConvertPaddle2Onnx(onnx_file_name, subgraph_root_path_);

      std::string rlym_file_name =
          subgraph_root_path_ + "/" + engine_key_ + ".rlym";
S
Shuangchi He 已提交
280
      // quantize don't support set quantized output model path now,
D
denglin-github 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
      // the quantized model file is in current dir
      std::string quantized_rlym_file_name = engine_key_ + ".quantized.rlym";

      std::stringstream filename;
      std::stringstream engine_file_name;

      if (enable_int8_ && use_calib_mode_) {
        std::string dataset_path = calibration_data_path_;
        std::string cnt_dataset_path = dataset_path + "/" + input_names_[0];

        std::stringstream dataset_plugin_path;
        dataset_plugin_path << dl_sdk_dir_
                            << "/python/dleol/quantize/plugin.py";

        inference::QuantizeOnnx(onnx_file_name,
                                rlym_file_name,
                                quantized_rlym_file_name,
                                dataset_path,
                                dataset_plugin_path.str());

        filename << quantized_rlym_file_name;
        engine_file_name << subgraph_root_path_ << "/" << engine_key_
                         << "_quantized"
                         << "_ws_" << weight_share_mode_ << ".engine";
      } else {
        filename << onnx_file_name;
        engine_file_name << subgraph_root_path_ << "/" << engine_key_ << "_ws_"
                         << weight_share_mode_ << ".engine";
      }

      dlnne_create_lock.lock();
      if (dlnne_log_flag_) {
        LOG(INFO) << "EngineKey:" << engine_key_
                  << " use_static_batch_:" << use_static_batch_
                  << " max_batch_size_:" << max_batch_size_
                  << " weight_share_mode_: " << weight_share_mode_;
      }

      builder = dl::nne::CreateInferBuilder();
      PADDLE_ENFORCE_NE(
          builder,
          nullptr,
          platform::errors::Unavailable("nne create builder failed"));
      dl::nne::BuilderConfig builder_cfg;
      builder_cfg.max_batch_size = max_batch_size_;
      builder_cfg.ws_mode = weight_share_map[weight_share_mode_];
      builder->SetBuilderConfig(builder_cfg);
      network = builder->CreateNetwork();

      parser = dl::nne::CreateParser();
      PADDLE_ENFORCE_NE(
          parser,
          nullptr,
          platform::errors::Unavailable("nne create parser failed"));
      if (dlnne_log_flag_) {
        LOG(INFO) << "set output for dlnne";
      }
      for (std::string &output_op_name : YsMap) {
        parser->RegisterOutput(output_op_name.c_str());
        if (dlnne_log_flag_) {
          LOG(INFO) << output_op_name;
        }
      }

      std::fstream engine_file;
      engine_file.open(engine_file_name.str().c_str(), std::ios::in);
      if (!engine_file) {
        if (dlnne_log_flag_) {
          LOG(INFO) << "parser model file for dlnne";
        }
        parser->Parse(filename.str().c_str(), *network);
        if (dlnne_log_flag_) {
          LOG(INFO) << "build network";
        }
        engine = builder->BuildEngine(*network);

        auto memory = engine->Serialize();
        std::ofstream out(engine_file_name.str().c_str(),
                          std::ofstream::binary);
        out.write(reinterpret_cast<char *>(memory->Data()), memory->Size());
        out.close();
        memory->Destroy();
      } else {
        engine_file.seekg(0, std::ios::end);
        uint64_t length = static_cast<uint64_t>(engine_file.tellg());
        engine_file.seekg(0, std::ios::beg);
        char *slz_data = new char[length];
        engine_file.read(slz_data, static_cast<int64_t>(length));
        engine = dl::nne::Deserialize(slz_data, length);
        delete[] slz_data;
      }
D
denglin-github 已提交
372

D
denglin-github 已提交
373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
      engine_input_size = num_inputs + num_outputs;
      for (std::string &input_name : XsMap) {
        int BindIndex = engine->GetBindingIndex(input_name.c_str());
        InputIndexToBindIndex_.push_back(BindIndex);
      }
      for (std::string &output_name : YsMap) {
        int BindIndex = engine->GetBindingIndex(output_name.c_str());
        InputIndexToBindIndex_.push_back(BindIndex);
      }

      // context
      context = engine->CreateExecutionContext(
          cluster_config_map[weight_share_mode_]);
      dlnne_create_lock.unlock();
    }
D
denglin-github 已提交
388 389 390
  }

  ~DlnneEngineOp() {
D
denglin-github 已提交
391 392 393 394 395 396 397
    if (!calibration_mode_) {
      network->Destroy();
      context->Destroy();
      engine->Destroy();
      parser->Destroy();
      builder->Destroy();
    }
D
denglin-github 已提交
398 399 400 401 402 403
  }

 protected:
  void RunDlnneOnCreateEngine(const framework::Scope &scope,
                              const platform::Place &dev_place) const {
    PADDLE_ENFORCE_EQ(
404 405
        input_names_.empty(),
        false,
D
denglin-github 已提交
406 407 408 409 410 411 412 413 414 415
        platform::errors::PreconditionNotMet(
            "Dlnne engine needs at least one input, but no input is found. "
            "Please check if you set the input correctly."));

    std::vector<void *> input_buffers(num_inputs);
    std::vector<void *> cpu_input_buffers(num_inputs);
    std::vector<std::vector<int64_t>> input_shapes(num_inputs);
    std::vector<int32_t> input_data_types(num_inputs);
    std::vector<int64_t> input_bytes(num_inputs);

D
denglin-github 已提交
416
    dlnne_create_lock.lock();
D
denglin-github 已提交
417
    int index = 0;
D
denglin-github 已提交
418 419 420 421 422 423 424
    int infer_batch = 1;
    std::vector<int> vec_infer_batch;
    // compute infer_batch
    if (use_static_batch_) {
      for (const auto &x : Inputs("Xs")) {
        if (param_names_.count(x)) continue;
        // convert input and copy to Dlnne engine's buffer
425
        auto &t = inference::analysis::GetFromScope<phi::DenseTensor>(scope, x);
D
denglin-github 已提交
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450

        auto t_shape = phi::vectorize<int64_t>(t.dims());
        std::vector<int64_t> runtime_input_shape(t_shape.begin(),
                                                 t_shape.end());
        const int bind_index = index;
        index++;
        dl::nne::Dims in_dim = engine->GetBindingDimensions(bind_index);

        int compute_batch = runtime_input_shape[0] / in_dim.d[0];
        VLOG(4) << "compute batch: " << compute_batch;
        vec_infer_batch.push_back(compute_batch);
      }

      int first_batch = vec_infer_batch[0];
      for (auto batch : vec_infer_batch) {
        PADDLE_ENFORCE_EQ(
            first_batch,
            batch,
            platform::errors::Unavailable(
                "compute infer_batchs is different from each other"));
      }
      infer_batch = first_batch;
    }

    index = 0;
D
denglin-github 已提交
451 452 453
    for (const auto &x : Inputs("Xs")) {
      if (param_names_.count(x)) continue;
      // convert input and copy to Dlnne engine's buffer
454
      auto &t = inference::analysis::GetFromScope<phi::DenseTensor>(scope, x);
D
denglin-github 已提交
455 456 457

      const int bind_index = index;
      index++;
D
denglin-github 已提交
458
      int64_t data_bytes, ele_num;
D
denglin-github 已提交
459
      int32_t dtype;
D
denglin-github 已提交
460
      auto type = t.type();
D
denglin-github 已提交
461
      data_bytes = 1;
D
denglin-github 已提交
462
      ele_num = 1;
D
denglin-github 已提交
463
      void *buffer = nullptr;
D
denglin-github 已提交
464 465
      // TODO(pei.jiang): add more type
      if (type == paddle::experimental::DataType::FLOAT32) {
D
denglin-github 已提交
466 467 468
        buffer = static_cast<void *>(t.data<float>());
        data_bytes = 4;
        dtype = 0;
D
denglin-github 已提交
469
      } else if (type == paddle::experimental::DataType::INT64) {
D
denglin-github 已提交
470 471 472
        buffer = static_cast<void *>(t.data<int64_t>());
        data_bytes = 8;
        dtype = 1;
D
denglin-github 已提交
473
      } else if (type == paddle::experimental::DataType::INT32) {
D
denglin-github 已提交
474 475 476
        buffer = static_cast<void *>(t.data<int32_t>());
        data_bytes = 4;
        dtype = 2;
D
denglin-github 已提交
477 478 479 480
      } else if (type == paddle::experimental::DataType::FLOAT16) {
        buffer = static_cast<void *>(t.data<paddle::platform::float16>());
        data_bytes = 2;
        dtype = 3;
D
denglin-github 已提交
481
      } else {
D
denglin-github 已提交
482 483 484
        PADDLE_THROW(
            platform::errors::Fatal("The DLNNE Engine OP only support "
                                    "float/int32_t/int64_t/float16 input."));
D
denglin-github 已提交
485 486 487
      }
      input_buffers[bind_index] = buffer;

488
      auto t_shape = phi::vectorize<int64_t>(t.dims());
D
denglin-github 已提交
489 490 491
      std::vector<int64_t> runtime_input_shape(t_shape.begin(), t_shape.end());
      for (auto &size : t_shape) {
        data_bytes = data_bytes * size;
D
denglin-github 已提交
492
        ele_num = ele_num * size;
D
denglin-github 已提交
493 494 495 496 497 498 499 500
      }

      VLOG(4) << "buffers_size:" << data_bytes;
      cpu_input_buffers[bind_index] =
          input_buffers[bind_index];  // malloc(data_bytes);
      input_shapes[bind_index] = runtime_input_shape;
      input_data_types[bind_index] = dtype;
      input_bytes[bind_index] = data_bytes;
D
denglin-github 已提交
501 502 503 504 505 506 507 508 509 510 511 512 513 514

      if (dump_flag_) {
        std::stringstream dump_input_name;
        dump_input_name << engine_key_ << "_input_" << bind_index << ".txt";
        std::ofstream dump_input_file;
        dump_input_file.open(dump_input_name.str());
        for (int64_t i = 0; i < ele_num; i++) {
          dump_input_file << static_cast<float *>(
                                 cpu_input_buffers[bind_index])[i]
                          << "\n";
        }
        dump_input_file << "\b";
        dump_input_file.close();
      }
D
denglin-github 已提交
515 516 517 518
    }

    // output shape
    std::vector<std::vector<int64_t>> out_shapes;
D
denglin-github 已提交
519 520
    std::vector<dl::nne::DataType> out_types;
    std::vector<int64_t> out_ele_nums;
D
denglin-github 已提交
521 522
    std::vector<int32_t> output_bytes;
    for (int i = 0; i < num_outputs; i++) {
D
denglin-github 已提交
523 524 525
      int index = InputIndexToBindIndex_[i + num_inputs];
      dl::nne::DataType out_type = engine->GetBindingDataType(index);
      out_types.push_back(out_type);
D
denglin-github 已提交
526 527 528
      dl::nne::Dims out_dim = engine->GetBindingDimensions(index);
      std::vector<int64_t> shape(out_dim.nbDims);
      for (int dim = 0; dim < out_dim.nbDims; dim++) {
D
denglin-github 已提交
529 530 531 532 533
        if (use_static_batch_ && dim == 0) {
          shape[dim] = (out_dim.d[dim]) * infer_batch;
        } else {
          shape[dim] = (out_dim.d[dim]);
        }
D
denglin-github 已提交
534 535 536
      }

      out_shapes.push_back(shape);
D
denglin-github 已提交
537 538
      int64_t data_bytes, out_ele_num;
      out_ele_num = 1;
D
denglin-github 已提交
539 540

      // float32
D
denglin-github 已提交
541
      data_bytes = dl::nne::GetElementSize(out_type);
D
denglin-github 已提交
542 543
      for (auto &size : shape) {
        data_bytes = data_bytes * size;
D
denglin-github 已提交
544
        out_ele_num = out_ele_num * size;
D
denglin-github 已提交
545 546 547
      }
      VLOG(4) << "data_bytes: " << data_bytes;
      output_bytes.push_back(data_bytes);
D
denglin-github 已提交
548
      out_ele_nums.push_back(out_ele_num);
D
denglin-github 已提交
549 550 551 552 553 554 555 556 557 558 559 560 561
    }

    int bind_index = 0;
    std::vector<void *> cpu_output_buffers(num_outputs);
    std::vector<void *> output_buffers(num_outputs);

    for (const auto &y : Outputs("Ys")) {
      auto *fluid_v = scope.FindVar(y);
      PADDLE_ENFORCE_NOT_NULL(
          fluid_v,
          platform::errors::NotFound(
              "Output variable %s is not found in DLNNE subgraph.", y));

562
      auto *fluid_t = fluid_v->GetMutable<phi::DenseTensor>();
D
denglin-github 已提交
563

D
denglin-github 已提交
564 565
      VLOG(4) << bind_index << ": out_shapes[bind_index] dim:"
              << out_shapes[bind_index].size();
566
      fluid_t->Resize(phi::make_ddim(out_shapes[bind_index]));
D
denglin-github 已提交
567

D
denglin-github 已提交
568 569 570 571 572 573 574
      dl::nne::DataType dl_type = out_types[bind_index];
      if (dlnne_log_flag_) {
        LOG(INFO) << "output type: " << dl_type;
      }
      output_buffers[bind_index] = static_cast<void *>(fluid_t->mutable_data(
          dev_place, inference::DLNNE2FluidDataType(dl_type)));

D
denglin-github 已提交
575 576 577 578 579 580 581 582 583
      cpu_output_buffers[bind_index] =
          output_buffers[bind_index];  // malloc(data_bytes);
      bind_index++;
    }

    std::vector<void *> engine_input_ptr(engine_input_size);

    // set input_ptr
    for (unsigned int i = 0; i < engine_input_size; i++) {
D
denglin-github 已提交
584 585 586
      if (InputIndexToBindIndex_[i] < 0) {
        continue;
      }
D
denglin-github 已提交
587 588 589 590 591 592 593 594 595 596 597 598

      if (engine->BindingIsInput(InputIndexToBindIndex_[i])) {
        // copy cpu buffer to gpu buffer
        int64_t total_bytes;
        total_bytes = input_bytes[i];
        VLOG(4) << "input_bytes: " << total_bytes;

        void *gpu_ptr;
        cudaMalloc(&gpu_ptr, total_bytes);
        engine_input_ptr[InputIndexToBindIndex_[i]] = gpu_ptr;

        paddle::inference::CopyTensorCpuToDevice(
599 600
            gpu_ptr,
            reinterpret_cast<void *>(cpu_input_buffers[i]),
D
denglin-github 已提交
601 602 603 604
            total_bytes);

      } else {
        int64_t total_size;
D
denglin-github 已提交
605
        total_size = output_bytes[i - input_names_.size()];
D
denglin-github 已提交
606 607 608 609 610 611 612 613 614
        VLOG(4) << "output_bytes: " << total_size;
        void *gpu_ptr;
        cudaMalloc(&gpu_ptr, total_size);
        engine_input_ptr[InputIndexToBindIndex_[i]] = gpu_ptr;
      }
    }

    clock_t startTime, endTime;
    startTime = clock();
D
denglin-github 已提交
615
    context->Execute(infer_batch, engine_input_ptr.data());
D
denglin-github 已提交
616
    endTime = clock();
D
denglin-github 已提交
617 618 619 620 621 622

    if (dlnne_log_flag_) {
      double during_ms =
          static_cast<double>(endTime - startTime) / CLOCKS_PER_SEC * 1000;
      LOG(INFO) << "dlNNE execute time: " << during_ms << " ms";
    }
D
denglin-github 已提交
623 624 625 626 627

    bind_index = 0;
    for (unsigned int i = 0; i < engine_input_size; i++) {
      if (InputIndexToBindIndex_[i] < 0) continue;

D
denglin-github 已提交
628 629
      if (i >= input_names_.size()) {
        void *cpu_ptr = cpu_output_buffers[i - input_names_.size()];
D
denglin-github 已提交
630
        int64_t size;
D
denglin-github 已提交
631
        size = output_bytes[i - input_names_.size()];
D
denglin-github 已提交
632 633 634 635
        paddle::inference::CopyTensorDeviceToCpu(
            cpu_ptr, engine_input_ptr[InputIndexToBindIndex_[i]], size);

        cpu_output_buffers[bind_index] = cpu_ptr;
D
denglin-github 已提交
636 637 638 639 640 641 642 643 644 645 646 647 648 649

        if (dump_flag_) {
          std::stringstream dump_output_name;
          dump_output_name << engine_key_ << "_output_" << bind_index << ".txt";
          std::ofstream dump_output_file;
          dump_output_file.open(dump_output_name.str());
          for (int64_t i = 0; i < out_ele_nums[bind_index]; i++) {
            dump_output_file
                << static_cast<float *>(cpu_output_buffers[bind_index])[i]
                << "\n";
          }
          dump_output_file << "\b";
          dump_output_file.close();
        }
D
denglin-github 已提交
650 651 652 653
        bind_index++;
      }
      cudaFree(engine_input_ptr[InputIndexToBindIndex_[i]]);
    }
D
denglin-github 已提交
654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676
    dlnne_create_lock.unlock();
  }

  void RunNativeImpl(const framework::Scope &scope,
                     const platform::Place &dev_place) const {
    VLOG(4) << "RunNativeImpl";
    framework::Executor executor(dev_place);
    auto *block = Attr<framework::BlockDesc *>("sub_block");
    auto *program = block->Program();
    auto &current_scope = scope.NewScope();
    auto ctx = executor.Prepare(*program, block->ID());
    executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
  }

  void RunCalibration(const framework::Scope &scope,
                      const platform::Place &dev_place) const {
    std::unordered_map<std::string, void *> calib_data_map;
    std::unordered_map<std::string, std::vector<int64_t>> calib_data_shape_map;
    std::unordered_map<std::string, std::string> calib_data_type_map;
    std::unordered_map<std::string, int64_t> calib_buffer_size_map;

    for (auto &x : Inputs("Xs")) {
      if (param_names_.count(x)) continue;
677
      auto &t = inference::analysis::GetFromScope<phi::DenseTensor>(scope, x);
D
denglin-github 已提交
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738
      calib_data_map.emplace(x, t.data());

      // TODO(pei.jiang): refine this code, because when run dlnne create
      // engine, there is same code
      auto t_shape = phi::vectorize<int64_t>(t.dims());
      std::vector<int64_t> input_shape(t_shape.begin(), t_shape.end());
      calib_data_shape_map.emplace(x, input_shape);
      std::string data_type = inference::ConvertType(t.type());
      calib_data_type_map.emplace(x, data_type);

      int data_bytes = inference::GetDataByte(t.type());
      VLOG(4) << "input name: " << x << ", data_type: " << data_type;
      VLOG(4) << "data shape: ";
      int64_t buffer_size = data_bytes;
      for (auto dim : input_shape) {
        buffer_size *= dim;
        VLOG(4) << dim;
      }
      VLOG(4) << "buffer_size: " << buffer_size;
      calib_buffer_size_map.emplace(x, buffer_size);
    }

    std::string random_key = inference::GenerateRandomKey();
    for (auto calib_data : calib_data_map) {
      std::string input_name = calib_data.first;
      std::string input_data_path = calibration_data_path_ + "/" + input_name;
      MKDIR(input_data_path.c_str());

      std::string input_data_item_path =
          input_data_path + "/" + random_key + ".binary";
      auto outfile = std::fstream(input_data_item_path.c_str(),
                                  std::ios::out | std::ios::binary);
      int64_t buffer_size = calib_buffer_size_map[input_name];
      outfile.write(reinterpret_cast<char *>(calib_data.second), buffer_size);
      outfile.close();
    }

    std::stringstream calib_config_ss;
    calib_config_ss << "shape message: " << std::endl;
    for (auto const &shape_item : calib_data_shape_map) {
      calib_config_ss << shape_item.first << ":";
      for (auto const &dim : shape_item.second) {
        calib_config_ss << dim << " ";
      }
      calib_config_ss << std::endl;
    }

    calib_config_ss << "dtype message: " << std::endl;
    for (auto const &dtype_item : calib_data_type_map) {
      calib_config_ss << dtype_item.first << ":" << dtype_item.second
                      << std::endl;
    }

    std::ofstream calib_config_file;
    std::string calib_config_path =
        calibration_data_path_ + "/calib_config.txt";
    calib_config_file.open(calib_config_path);
    calib_config_file << calib_config_ss.str();
    calib_config_file.close();

    RunNativeImpl(scope, dev_place);
D
denglin-github 已提交
739 740 741 742
  }

  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
D
denglin-github 已提交
743 744 745 746 747 748 749
    VLOG(4) << "calibration_mode_: " << calibration_mode_;
    if (calibration_mode_ == true) {
      VLOG(4) << "RunCalibration";
      RunCalibration(scope, dev_place);
      return;
    }

D
denglin-github 已提交
750 751 752 753 754 755
    RunDlnneOnCreateEngine(scope, dev_place);
  }
};

}  // namespace operators
}  // namespace paddle