// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "io/api_paddle_mobile.h"
#include <cstring>
#include <vector>
#include "common/enforce.h"
#include "framework/tensor.h"

namespace paddle_mobile {

// Builds a predictor by loading the model described by `config` and then
// remembers the configuration in config_.
//
// Init() is called outside of PADDLE_MOBILE_ENFORCE on purpose: it has side
// effects (it creates and loads the engine), and side-effecting expressions
// should not live inside an assertion-style macro whose expansion may change
// with build configuration. The macro only checks the resulting flag.
template <typename Device, typename T>
PaddleMobilePredictor<Device, T>::PaddleMobilePredictor(
    const PaddleMobileConfig &config) {
  const bool initialized = Init(config);
  PADDLE_MOBILE_ENFORCE(initialized, "paddle mobile predictor init failed!");
  config_ = config;
}

// Creates a fresh PaddleMobile engine and loads a model into it.
//
// Model sources are considered in this priority order:
//   1. in-memory buffers (config.memory_pack, when from_memory is set),
//   2. a model directory (config.model_dir, when non-empty),
//   3. a program/params file pair (config.prog_file and config.param_file,
//      when both are non-empty).
// Returns false (after logging an error) when none of these is configured.
// Otherwise it forwards config.thread_num to the engine and returns true.
// The results of the engine's load calls are not inspected here.
template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) {
  paddle_mobile_.reset(new PaddleMobile<Device, T>());
#ifdef PADDLE_MOBILE_CL
  paddle_mobile_->SetCLPath(config.cl_path);
#endif

  const auto &pack = config.memory_pack;
  const bool load_from_memory = pack.from_memory;
  const bool load_from_dir = !config.model_dir.empty();
  const bool load_from_files =
      !config.prog_file.empty() && !config.param_file.empty();

  // Guard clause: nothing to load from.
  if (!load_from_memory && !load_from_dir && !load_from_files) {
    LOG(kLOG_ERROR) << "fail to load inference model!";
    return false;
  }

  if (load_from_memory) {
    DLOG << "load from memory!";
    paddle_mobile_->LoadCombinedMemory(pack.model_size, pack.model_buf,
                                       pack.combined_params_size,
                                       pack.combined_params_buf);
  } else if (load_from_dir) {
    paddle_mobile_->Load(config.model_dir, config.optimize,
                         config.quantification, config.batch_size);
  } else {
    paddle_mobile_->Load(config.prog_file, config.param_file, config.optimize,
                         config.quantification, config.batch_size);
  }

  // Forward the requested thread count to the engine (the original note says
  // this is meant for builds with OpenMP enabled).
  paddle_mobile_->SetThreadNum(config.thread_num);
  return true;
}
// Runs inference on inputs[0] and writes the first fetched output into
// (*output_data)[0].
//
// Requirements:
//   - `inputs` is non-empty; only inputs[0] is consumed, and it must be 4-D.
//   - inputs[0].data must hold at least product(shape) elements of type T.
//   - `output_data` is non-null and non-empty; (*output_data)[0] receives
//     the output shape (replacing any shape left there) and its data buffer
//     is grown when it is too small for the result.
// `batch_size` is currently unused.
//
// Returns false (after logging an error) when any requirement is violated.
// All validation happens before inference runs.
template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Run(
    const std::vector<PaddleTensor> &inputs,
    std::vector<PaddleTensor> *output_data, int batch_size) {
  if (inputs.empty()) {
    LOG(kLOG_ERROR) << "At least one input should be set.";
    return false;
  }
  // Check the destination before doing any work, and never dereference a
  // null output container.
  if (output_data == nullptr || output_data->empty()) {
    LOG(kLOG_ERROR) << "At least one output should be set with tensors' names.";
    return false;
  }

  // Bind by reference: copying a PaddleTensor may also copy its data buffer.
  const auto &input = inputs[0];
  if (input.shape.size() != 4) {
    LOG(kLOG_ERROR) << "input shape not equal to 4!";
    return false;
  }

  std::vector<int64_t> dims;
  dims.reserve(input.shape.size());
  for (auto d : input.shape) {
    dims.push_back(static_cast<int64_t>(d));
  }
  framework::DDim ddim =
      framework::make_ddim({dims[0], dims[1], dims[2], dims[3]});

  // Keep the element count 64-bit and make sure the caller's buffer actually
  // holds that many elements; otherwise memcpy below would read out of bounds.
  const int64_t input_length = framework::product(ddim);
  const size_t input_bytes = static_cast<size_t>(input_length) * sizeof(T);
  if (input_length <= 0 || input.data.length() < input_bytes) {
    LOG(kLOG_ERROR) << "input data does not match input shape!";
    return false;
  }

  framework::Tensor input_tensor;
  input_tensor.Resize(ddim);
  auto input_ptr = input_tensor.mutable_data<T>();
  std::memcpy(input_ptr, input.data.data(), input_bytes);

  paddle_mobile_->Predict(input_tensor);
  auto output_tensor = paddle_mobile_->Fetch();

  auto &output = (*output_data)[0];
  const int64_t output_length = output_tensor->numel();
  const size_t output_bytes = static_cast<size_t>(output_length) * sizeof(T);

  // Replace (not append to) any shape left over from a previous Run() on a
  // reused output tensor.
  output.shape.clear();
  for (auto d : framework::vectorize(output_tensor->dims())) {
    output.shape.push_back(static_cast<int>(d));
  }

  if (output.data.length() < output_bytes) {
    output.data.Resize(output_bytes);
  }
  std::memcpy(output.data.data(), output_tensor->template data<T>(),
              output_bytes);

  return true;
}

// Releases the underlying engine's resources through Clear().
// The guard means the destructor never dereferences an empty engine pointer.
template <typename Device, typename T>
PaddleMobilePredictor<Device, T>::~PaddleMobilePredictor() {
  if (paddle_mobile_) {
    paddle_mobile_->Clear();
  }
}

// A factory that creates the PaddleMobilePredictor matching config.device.
// Only FP32 precision is supported. Supported devices are CPU, FPGA,
// GPU_MALI and GPU_CL. For any other precision or device it logs an error
// and returns nullptr.
template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<PaddleMobileConfig, PaddleEngineKind::kPaddleMobile>(
    const PaddleMobileConfig &config) {
  if (config.precision != PaddleMobileConfig::FP32) {
    LOG(kLOG_ERROR) << "unsupport precision type!";
    return nullptr;
  }

  std::unique_ptr<PaddlePredictor> x;
  if (config.device == PaddleMobileConfig::kCPU) {
    x.reset(new PaddleMobilePredictor<CPU, float>(config));
  } else if (config.device == PaddleMobileConfig::kFPGA) {
    x.reset(new PaddleMobilePredictor<FPGA, float>(config));
  } else if (config.device == PaddleMobileConfig::kGPU_MALI) {
    x.reset(new PaddleMobilePredictor<GPU_MALI, float>(config));
  } else if (config.device == PaddleMobileConfig::kGPU_CL) {
    x.reset(new PaddleMobilePredictor<GPU_CL, float>(config));
  } else {
    LOG(kLOG_ERROR) << "unsupport device type!";
    return nullptr;
  }
  // `x` has exactly the return type, so a plain return moves it implicitly;
  // wrapping it in std::move would only inhibit copy elision.
  return x;
}

}  // namespace paddle_mobile