api_anakin_engine.cc 8.4 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

L
Luo Tao 已提交
15
#include "paddle/fluid/inference/api/api_anakin_engine.h"
T
Tao Luo 已提交
16 17

#ifdef PADDLE_WITH_CUDA
Y
Yan Chunwei 已提交
18
#include <cuda.h>
T
Tao Luo 已提交
19 20 21 22 23 24 25
#endif

#include <mkl_service.h>
#include <omp.h>
#include <map>
#include <string>
#include <utility>
L
Luo Tao 已提交
26
#include <vector>
Y
Yan Chunwei 已提交
27

T
Tao Luo 已提交
28 29 30 31
#include "framework/core/net/net.h"
#include "framework/operators/ops.h"
#include "saber/funcs/timer.h"

Y
Yan Chunwei 已提交
32 33
namespace paddle {

C
cuichaowen 已提交
34 35
// Builds a ready-to-run predictor from `config`.
// Initialization failure (model load or graph optimization) is fatal: CHECK
// aborts rather than leaving a half-constructed predictor behind.
template <typename Target>
PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
    const AnakinConfig &config) {
  CHECK(Init(config));
}
T
Tao Luo 已提交
39 40 41 42 43 44 45 46
// X86 specialization of the constructor: before initializing, pin OpenMP and
// MKL to a single thread and disable OpenMP's dynamic thread adjustment.
// Initialization failure is fatal (CHECK), as in the generic constructor.
template <>
PaddleInferenceAnakinPredictor<anakin::X86>::PaddleInferenceAnakinPredictor(
    const AnakinConfig &config) {
  omp_set_dynamic(0);      // keep the requested thread count fixed
  omp_set_num_threads(1);  // single OpenMP worker
  mkl_set_num_threads(1);  // single MKL worker
  CHECK(Init(config));
}
C
cuichaowen 已提交
47 48
// Loads the model from `config.model_file` and resizes every graph input to
// `config.max_batch_size`. It then optimizes the graph and, if none exists
// yet, creates the FP32 executor.
// Returns false (after a VLOG(3) message for load failures) if loading or
// optimization fails.
template <typename Target>
bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
  if (!graph_.load(config.model_file)) {
    VLOG(3) << "fail to load graph from " << config.model_file;
    return false;
  }
  // Recorded once, independent of how many inputs the graph has. Previously
  // this was assigned inside the loop below, so a graph without inputs left
  // it untouched.
  max_batch_size_ = config.max_batch_size;
  auto inputs = graph_.get_ins();
  for (auto &input_str : inputs) {
    graph_.ResetBatchSize(input_str, config.max_batch_size);
  }
  // optimization for graph
  if (!graph_.Optimize()) {
    return false;
  }
  // Construct the executor only once; a non-null executor_p_ means it
  // already exists.
  if (executor_p_ == nullptr) {
    executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
                                  anakin::Precision::FP32>(graph_, true);
  }
  return true;
}

C
cuichaowen 已提交
70 71
// Runs one inference pass.
//
// inputs:      float32 tensors addressed by name. Each must have the same
//              rank as the corresponding net input, at most one LoD level,
//              and a data buffer holding at least prod(shape) floats.
// output_data: non-null, non-empty list of tensors addressed by name. Their
//              shape is overwritten and their buffer grown as needed, then
//              filled with the results.
// batch_size:  unused; the batch size is taken from each input's shape.
//
// Returns false (with a VLOG(3) message) on any validation or copy failure.
template <typename Target>
bool PaddleInferenceAnakinPredictor<Target>::Run(
    const std::vector<PaddleTensor> &inputs,
    std::vector<PaddleTensor> *output_data, int batch_size) {
  // Reject a missing or empty output list before doing any work.
  if (output_data == nullptr || output_data->empty()) {
    VLOG(3) << "At least one output should be set with tensors' names.";
    return false;
  }

  // Pass 1: validate every input and enlarge the graph for inputs that exceed
  // the net's current capacity. No data is copied here. Rebuilding the
  // executor afterwards therefore cannot discard inputs that were already
  // fed, which happened when the rebuild occurred mid-way through copying.
  bool rebuild_executor = false;
  for (const auto &input : inputs) {
    if (input.dtype != PaddleDType::FLOAT32) {
      VLOG(3) << "Only support float type inputs. " << input.name
              << "'s type is not float";
      return false;
    }
    auto net_shape = executor_p_->get_in(input.name)->shape();
    if (net_shape.size() != input.shape.size()) {
      VLOG(3) << " input  " << input.name
              << "'s shape size should be equal to that of net";
      return false;
    }
    if (input.lod.size() > 1) {
      VLOG(3) << " input lod first dim should <=1, but you set "
              << input.lod.size();
      return false;
    }
    int sum = 1;
    for (int dim : input.shape) {
      sum *= dim;
    }
    if (sum > net_shape.count()) {
      graph_.Reshape(input.name, input.shape);
      rebuild_executor = true;
    }
  }
  if (rebuild_executor) {
    delete executor_p_;
    executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
                                  anakin::Precision::FP32>(graph_, true);
  }

  // Pass 2: shape each input tensor, attach its LoD, and copy its data in.
  for (const auto &input : inputs) {
    auto d_tensor_in_p = executor_p_->get_in(input.name);

    anakin::saber::Shape tmp_shape;
    for (auto s : input.shape) {
      tmp_shape.push_back(s);
    }
    d_tensor_in_p->reshape(tmp_shape);

    if (!input.lod.empty()) {
      std::vector<int> offset(input.lod[0].begin(), input.lod[0].end());
      d_tensor_in_p->set_seq_offset(offset);
      VLOG(3) << "offset.size(): " << offset.size();
      for (int off : offset) {
        VLOG(3) << off;
      }
    }

    // Never read past the end of the caller's buffer.
    const size_t in_bytes = d_tensor_in_p->valid_size() * sizeof(float);
    if (input.data.length() < in_bytes) {
      VLOG(3) << "input " << input.name << "'s data buffer holds "
              << input.data.length() << " bytes, but its shape requires "
              << in_bytes << " bytes";
      return false;
    }

    float *d_data_p = d_tensor_in_p->mutable_data();
#ifdef PADDLE_WITH_CUDA
    if (std::is_same<anakin::NV, Target>::value) {
      if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
                     in_bytes, cudaMemcpyHostToDevice) != 0) {
        VLOG(3) << "copy data from CPU to GPU error";
        return false;
      }
    }
#endif
    if (std::is_same<anakin::X86, Target>::value) {
      memcpy(d_data_p, static_cast<float *>(input.data.data()), in_bytes);
    }
  }

  // Execute the net. Only the device synchronization is CUDA-specific.
  // prediction() used to live inside the CUDA guard, so CPU-only builds never
  // ran inference at all.
#ifdef PADDLE_WITH_CUDA
  cudaDeviceSynchronize();
#endif
  executor_p_->prediction();
#ifdef PADDLE_WITH_CUDA
  cudaDeviceSynchronize();
#endif

  for (auto &output : *output_data) {
    auto *tensor = executor_p_->get_out(output.name);
    output.shape = tensor->valid_shape();
    const size_t out_bytes = tensor->valid_size() * sizeof(float);
    if (output.data.length() < out_bytes) {
      output.data.Resize(out_bytes);
    }

#ifdef PADDLE_WITH_CUDA
    if (std::is_same<anakin::NV, Target>::value) {
      // Copy data from GPU -> CPU
      if (cudaMemcpy(output.data.data(), tensor->mutable_data(), out_bytes,
                     cudaMemcpyDeviceToHost) != 0) {
        VLOG(3) << "copy data from GPU to CPU error";
        return false;
      }
    }
#endif
    if (std::is_same<anakin::X86, Target>::value) {
      memcpy(output.data.data(), tensor->mutable_data(), out_bytes);
    }
  }
  return true;
}

C
cuichaowen 已提交
170 171 172 173
// Exposes the underlying Anakin executor by reference.
// executor_p_ is dereferenced without a null check, so callers must only use
// this once an executor has been created.
template <typename Target>
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
    &PaddleInferenceAnakinPredictor<Target>::get_executer() {
  auto &executor = *executor_p_;
  return executor;
}

// the cloned new Predictor of anakin share the same net weights from original
// Predictor
C
cuichaowen 已提交
178 179 180
template <typename Target>
std::unique_ptr<PaddlePredictor>
PaddleInferenceAnakinPredictor<Target>::Clone() {
C
cuichaowen 已提交
181
  VLOG(3) << "Anakin Predictor::clone";
C
cuichaowen 已提交
182 183
  std::unique_ptr<PaddlePredictor> cls(
      new PaddleInferenceAnakinPredictor<Target>());
C
cuichaowen 已提交
184 185
  // construct executer from other graph
  auto anakin_predictor_p =
C
cuichaowen 已提交
186
      dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
C
cuichaowen 已提交
187
  if (!anakin_predictor_p) {
T
Tao Luo 已提交
188
    VLOG(3) << "fail to call Init";
C
cuichaowen 已提交
189 190 191 192 193
    return nullptr;
  }
  anakin_predictor_p->get_executer().init(graph_);

  return std::move(cls);
Y
Yan Chunwei 已提交
194 195
}

C
cuichaowen 已提交
196 197 198
// Explicitly instantiate the predictor for both supported Anakin targets so
// their member definitions in this file are emitted once, here.
template class PaddleInferenceAnakinPredictor<anakin::X86>;
template class PaddleInferenceAnakinPredictor<anakin::NV>;

Y
Yan Chunwei 已提交
199 200
// A factory to help create difference predictor.
template <>
201 202
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
    AnakinConfig, PaddleEngineKind::kAnakin>(const AnakinConfig &config) {
C
cuichaowen 已提交
203
  VLOG(3) << "Anakin Predictor create.";
C
cuichaowen 已提交
204 205 206 207 208 209 210 211 212 213 214 215 216 217
  if (config.target_type == AnakinConfig::NVGPU) {
    VLOG(3) << "Anakin Predictor create on [ NVIDIA GPU ].";
    std::unique_ptr<PaddlePredictor> x(
        new PaddleInferenceAnakinPredictor<anakin::NV>(config));
    return x;
  } else if (config.target_type == AnakinConfig::X86) {
    VLOG(3) << "Anakin Predictor create on [ Intel X86 ].";
    std::unique_ptr<PaddlePredictor> x(
        new PaddleInferenceAnakinPredictor<anakin::X86>(config));
    return x;
  } else {
    VLOG(3) << "Anakin Predictor create on unknown platform.";
    return nullptr;
  }
T
Tao Luo 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
}

#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
template <typename Target>
using executor_t =
    anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>;

// Logs the per-op timings recorded by `net_executor`, then the timings summed
// per distinct op_param. Every time is divided by `epoch`. A non-positive
// `epoch` (e.g. a predictor whose batch size was never set) is treated as 1,
// so the report never divides by zero.
template <typename Target>
void DisplayOpTimer(executor_t<Target> *net_executor, int epoch) {
  const float divisor = epoch > 0 ? static_cast<float>(epoch) : 1.0f;
  std::vector<float> op_time = net_executor->get_op_time();
  auto exec_funcs = net_executor->get_exec_funcs();
  auto op_param = net_executor->get_op_param();
  for (size_t i = 0; i < op_time.size(); ++i) {
    LOG(INFO) << "name: " << exec_funcs[i].name
              << " op_type: " << exec_funcs[i].op_name
              << " op_param: " << op_param[i] << " time "
              << op_time[i] / divisor;
  }
  std::map<std::string, float> op_map;
  for (size_t i = 0; i < op_time.size(); ++i) {
    // operator[] value-initializes a missing entry to 0.0f, so one lookup
    // both inserts and accumulates.
    op_map[op_param[i]] += op_time[i];
  }
  for (const auto &entry : op_map) {
    LOG(INFO) << entry.first << "  " << entry.second / divisor << " ms";
  }
}
#endif

// Releases the executor. When op timing is compiled in, the collected
// per-op timings are reported first.
template <typename Target>
PaddleInferenceAnakinPredictor<Target>::~PaddleInferenceAnakinPredictor() {
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
  // A predictor that never built an executor has nothing to report, and
  // DisplayOpTimer would dereference a null executor.
  if (executor_p_ != nullptr) {
    DisplayOpTimer<Target>(executor_p_, max_batch_size_);
  }
#endif
  delete executor_p_;
  executor_p_ = nullptr;
}
Y
Yan Chunwei 已提交
257 258

}  // namespace paddle