api_anakin_engine.cc 8.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

L
Luo Tao 已提交
15
#include "paddle/fluid/inference/api/api_anakin_engine.h"
T
Tao Luo 已提交
16 17

#ifdef PADDLE_WITH_CUDA
Y
Yan Chunwei 已提交
18
#include <cuda.h>
T
Tao Luo 已提交
19 20 21 22 23 24 25
#endif

#include <mkl_service.h>
#include <omp.h>
#include <map>
#include <string>
#include <utility>
L
Luo Tao 已提交
26
#include <vector>
Y
Yan Chunwei 已提交
27

T
Tao Luo 已提交
28 29 30 31
#include "framework/core/net/net.h"
#include "framework/operators/ops.h"
#include "saber/funcs/timer.h"

Y
Yan Chunwei 已提交
32 33
namespace paddle {

Y
Yan Chunwei 已提交
34 35
using paddle::contrib::AnakinConfig;

C
cuichaowen 已提交
36 37
// Builds a predictor for `Target` from `config`.
// Initialization failure is fatal: CHECK aborts the process when Init()
// returns false, so a constructed object is always initialized.
template <typename Target>
PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
    const contrib::AnakinConfig &config) {
  CHECK(Init(config));
}
T
Tao Luo 已提交
41 42
template <>
PaddleInferenceAnakinPredictor<anakin::X86>::PaddleInferenceAnakinPredictor(
Y
Yan Chunwei 已提交
43
    const contrib::AnakinConfig &config) {
T
Tao Luo 已提交
44 45 46 47 48
  omp_set_dynamic(0);
  omp_set_num_threads(1);
  mkl_set_num_threads(1);
  CHECK(Init(config));
}
C
cuichaowen 已提交
49
template <typename Target>
Y
Yan Chunwei 已提交
50 51
bool PaddleInferenceAnakinPredictor<Target>::Init(
    const contrib::AnakinConfig &config) {
C
cuichaowen 已提交
52
  if (!(graph_.load(config.model_file))) {
T
Tao Luo 已提交
53
    VLOG(3) << "fail to load graph from " << config.model_file;
C
cuichaowen 已提交
54 55
    return false;
  }
C
cuichaowen 已提交
56 57 58
  auto inputs = graph_.get_ins();
  for (auto &input_str : inputs) {
    graph_.ResetBatchSize(input_str, config.max_batch_size);
T
Tao Luo 已提交
59
    max_batch_size_ = config.max_batch_size;
C
cuichaowen 已提交
60
  }
C
cuichaowen 已提交
61 62 63 64 65
  // optimization for graph
  if (!(graph_.Optimize())) {
    return false;
  }
  // construct executer
C
cuichaowen 已提交
66 67 68 69
  if (executor_p_ == nullptr) {
    executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
                                  anakin::Precision::FP32>(graph_, true);
  }
Y
Yan Chunwei 已提交
70 71 72
  return true;
}

C
cuichaowen 已提交
73 74
// Runs one inference pass.
//
// Copies every tensor in `inputs` into the executor input of the same name,
// runs the executor, then fills each tensor in `*output_data` (looked up by
// its `name`) with the corresponding executor output, resizing its buffer
// when too small.
//
// Returns false if `output_data` is null or empty, an input is not FLOAT32,
// an input's rank differs from the net's, an input carries more than one LoD
// level, an input buffer is smaller than its shape requires, or a device
// copy fails.
// NOTE(review): `batch_size` is not used by this implementation.
template <typename Target>
bool PaddleInferenceAnakinPredictor<Target>::Run(
    const std::vector<PaddleTensor> &inputs,
    std::vector<PaddleTensor> *output_data, int batch_size) {
  // Check the outputs up front so an unusable request does not pay for a
  // full prediction first.
  if (output_data == nullptr || output_data->empty()) {
    VLOG(3) << "At least one output should be set with tensors' names.";
    return false;
  }

  // Pass 1: validate every input and grow the graph when an input holds more
  // elements than the executor's current tensor. Growing deletes and
  // rebuilds the executor, which discards the shape, LoD and data of all of
  // its tensors. So every rebuild must happen before any input is filled in
  // (pass 2). Otherwise a rebuild triggered by a later input would drop what
  // was already copied for earlier inputs.
  for (const auto &input : inputs) {
    if (input.dtype != PaddleDType::FLOAT32) {
      VLOG(3) << "Only support float type inputs. " << input.name
              << "'s type is not float";
      return false;
    }
    auto d_tensor_in_p = executor_p_->get_in(input.name);
    auto net_shape = d_tensor_in_p->shape();
    if (net_shape.size() != input.shape.size()) {
      VLOG(3) << " input  " << input.name
              << "'s shape size should be equal to that of net";
      return false;
    }
    if (input.lod.size() > 1) {
      VLOG(3) << " input lod first dim should <=1, but you set "
              << input.lod.size();
      return false;
    }
    int sum = 1;
    for (int n : input.shape) {
      sum *= n;
    }
    if (sum > net_shape.count()) {
      graph_.Reshape(input.name, input.shape);
      delete executor_p_;
      executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
                                    anakin::Precision::FP32>(graph_, true);
    }
  }

  // Pass 2: the executor is final now. Reshape each input tensor, attach its
  // sequence offsets and copy its data in.
  for (const auto &input : inputs) {
    auto d_tensor_in_p = executor_p_->get_in(input.name);

    anakin::saber::Shape tmp_shape;
    for (auto s : input.shape) {
      tmp_shape.push_back(s);
    }
    d_tensor_in_p->reshape(tmp_shape);

    if (!input.lod.empty()) {
      std::vector<int> offset(input.lod[0].begin(), input.lod[0].end());
      d_tensor_in_p->set_seq_offset(offset);
      VLOG(3) << "offset.size(): " << offset.size();
      for (int v : offset) {
        VLOG(3) << v;
      }
    }

    // Never read past the end of the caller's buffer.
    const size_t bytes = d_tensor_in_p->valid_size() * sizeof(float);
    if (input.data.length() < bytes) {
      VLOG(3) << " input " << input.name
              << "'s data buffer is smaller than its shape requires";
      return false;
    }

    float *d_data_p = d_tensor_in_p->mutable_data();
#ifdef PADDLE_WITH_CUDA
    if (std::is_same<anakin::NV, Target>::value) {
      if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()), bytes,
                     cudaMemcpyHostToDevice) != 0) {
        VLOG(3) << "copy data from CPU to GPU error";
        return false;
      }
    }
#endif
    if (std::is_same<anakin::X86, Target>::value) {
      memcpy(d_data_p, static_cast<float *>(input.data.data()), bytes);
    }
  }

  // Prediction must run in every build. Previously it was compiled only when
  // PADDLE_WITH_CUDA was defined, so CPU-only builds never predicted. Device
  // synchronization is only meaningful for the NV target.
#ifdef PADDLE_WITH_CUDA
  if (std::is_same<anakin::NV, Target>::value) {
    cudaDeviceSynchronize();
  }
#endif
  executor_p_->prediction();
#ifdef PADDLE_WITH_CUDA
  if (std::is_same<anakin::NV, Target>::value) {
    cudaDeviceSynchronize();
  }
#endif

  for (auto &output : *output_data) {
    auto *tensor = executor_p_->get_out(output.name);
    output.shape = tensor->valid_shape();
    if (output.data.length() < tensor->valid_size() * sizeof(float)) {
      output.data.Resize(tensor->valid_size() * sizeof(float));
    }

#ifdef PADDLE_WITH_CUDA
    if (std::is_same<anakin::NV, Target>::value) {
      // Copy data from GPU -> CPU
      if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
                     tensor->valid_size() * sizeof(float),
                     cudaMemcpyDeviceToHost) != 0) {
        VLOG(3) << "copy data from GPU to CPU error";
        return false;
      }
    }
#endif
    if (std::is_same<anakin::X86, Target>::value) {
      memcpy(output.data.data(), tensor->mutable_data(),
             tensor->valid_size() * sizeof(float));
    }
  }
  return true;
}

C
cuichaowen 已提交
173 174 175 176
// Exposes the executor owned by this predictor by reference.
// NOTE(review): executor_p_ is dereferenced without a null check; callers
// must make sure an executor has been constructed.
template <typename Target>
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32> &
PaddleInferenceAnakinPredictor<Target>::get_executer() {
  auto *executor = this->executor_p_;
  return *executor;
}

// the cloned new Predictor of anakin share the same net weights from original
// Predictor
C
cuichaowen 已提交
181 182 183
template <typename Target>
std::unique_ptr<PaddlePredictor>
PaddleInferenceAnakinPredictor<Target>::Clone() {
C
cuichaowen 已提交
184
  VLOG(3) << "Anakin Predictor::clone";
C
cuichaowen 已提交
185 186
  std::unique_ptr<PaddlePredictor> cls(
      new PaddleInferenceAnakinPredictor<Target>());
C
cuichaowen 已提交
187 188
  // construct executer from other graph
  auto anakin_predictor_p =
C
cuichaowen 已提交
189
      dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
C
cuichaowen 已提交
190
  if (!anakin_predictor_p) {
T
Tao Luo 已提交
191
    VLOG(3) << "fail to call Init";
C
cuichaowen 已提交
192 193 194 195 196
    return nullptr;
  }
  anakin_predictor_p->get_executer().init(graph_);

  return std::move(cls);
Y
Yan Chunwei 已提交
197 198
}

199
#ifdef PADDLE_WITH_CUDA
C
cuichaowen 已提交
200
template class PaddleInferenceAnakinPredictor<anakin::NV>;
201
#endif
C
cuichaowen 已提交
202 203
template class PaddleInferenceAnakinPredictor<anakin::X86>;

Y
Yan Chunwei 已提交
204 205
// A factory to help create difference predictor.
template <>
Y
Yan Chunwei 已提交
206 207 208
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
    const contrib::AnakinConfig &config) {
C
cuichaowen 已提交
209
  VLOG(3) << "Anakin Predictor create.";
Y
Yan Chunwei 已提交
210
  if (config.target_type == contrib::AnakinConfig::NVGPU) {
211
#ifdef PADDLE_WITH_CUDA
C
cuichaowen 已提交
212 213 214 215
    VLOG(3) << "Anakin Predictor create on [ NVIDIA GPU ].";
    std::unique_ptr<PaddlePredictor> x(
        new PaddleInferenceAnakinPredictor<anakin::NV>(config));
    return x;
216 217 218 219
#else
    LOG(ERROR) << "AnakinConfig::NVGPU could not used in ONLY-CPU environment";
    return nullptr;
#endif
Y
Yan Chunwei 已提交
220
  } else if (config.target_type == contrib::AnakinConfig::X86) {
C
cuichaowen 已提交
221 222 223 224 225 226 227 228
    VLOG(3) << "Anakin Predictor create on [ Intel X86 ].";
    std::unique_ptr<PaddlePredictor> x(
        new PaddleInferenceAnakinPredictor<anakin::X86>(config));
    return x;
  } else {
    VLOG(3) << "Anakin Predictor create on unknown platform.";
    return nullptr;
  }
T
Tao Luo 已提交
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
}

#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
template <typename Target>
using executor_t =
    anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>;

// Logs the per-op timings recorded by `net_executor`, then the timings summed
// per distinct op_param string. Every reported time is divided by `epoch`.
// NOTE(review): assumes get_exec_funcs() and get_op_param() return at least
// as many entries as get_op_time(); TODO confirm.
template <typename Target>
void DisplayOpTimer(executor_t<Target> *net_executor, int epoch) {
  std::vector<float> op_time = net_executor->get_op_time();
  auto exec_funcs = net_executor->get_exec_funcs();
  auto op_param = net_executor->get_op_param();
  std::map<std::string, float> op_map;
  for (size_t i = 0; i < op_time.size(); ++i) {
    LOG(INFO) << "name: " << exec_funcs[i].name
              << " op_type: " << exec_funcs[i].op_name
              << " op_param: " << op_param[i] << " time " << op_time[i] / epoch;
    // operator[] value-initializes a missing entry to 0, so a single
    // accumulation replaces the find-then-insert double lookup.
    op_map[op_param[i]] += op_time[i];
  }
  for (const auto &entry : op_map) {
    LOG(INFO) << entry.first << "  " << entry.second / epoch << " ms";
  }
}
#endif

// Releases the executor this predictor allocated. When op timing is compiled
// in, the collected timings are reported first, while the executor is still
// alive.
template <typename Target>
PaddleInferenceAnakinPredictor<Target>::~PaddleInferenceAnakinPredictor() {
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
  DisplayOpTimer<Target>(executor_p_, max_batch_size_);
#endif
  auto *owned_executor = executor_p_;
  executor_p_ = nullptr;
  delete owned_executor;
}
Y
Yan Chunwei 已提交
268 269

}  // namespace paddle