未验证 提交 decda738 编写于 作者: T Tao Luo 提交者: GitHub

fea/anakin compile with demo (#12772)

* anakin support x86

* fix code style

* add anakin ditu cnn demo

* add timer

* add rnn

* fix inference_anakin_cnn/rnn_test compile error

* make anakin_rnn_tester run

* add anakin_enable_op_time option

* update api/CMakeLists.txt

* enlarge the max_batch_size in anakin.config

* update with comments
上级 bcaa1d5e
......@@ -2,6 +2,11 @@ if (NOT WITH_ANAKIN)
return()
endif()
option(ANAKIN_ENABLE_OP_TIMER "Get more detailed information with Anakin op time" OFF)
if(ANAKIN_ENABLE_OP_TIMER)
add_definitions(-DPADDLE_ANAKIN_ENABLE_OP_TIMER)
endif()
INCLUDE(ExternalProject)
set(ANAKIN_SOURCE_DIR ${THIRD_PARTY_PATH}/anakin)
# the anakin install dir is only default one now
......@@ -11,23 +16,34 @@ set(ANAKIN_LIBRARY ${ANAKIN_INSTALL_DIR})
set(ANAKIN_SHARED_LIB ${ANAKIN_LIBRARY}/libanakin.so)
set(ANAKIN_SABER_LIB ${ANAKIN_LIBRARY}/libanakin_saber_common.so)
# TODO(luotao): ANAKIN_MODLE_URL will move to demo ci later.
set(ANAKIN_MODLE_URL "http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2.anakin.bin")
# TODO(luotao): ANAKIN_MODLE_URL etc will move to demo ci later.
set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com")
set(ANAKIN_MODLE_URL "${INFERENCE_URL}/mobilenet_v2.anakin.bin")
set(ANAKIN_RNN_MODLE_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn.anakin2.model.bin")
set(ANAKIN_RNN_DATA_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn_data.txt")
execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_SOURCE_DIR}")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL}")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL} -N")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_MODLE_URL} -N")
execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_DATA_URL} -N")
include_directories(${ANAKIN_INCLUDE})
include_directories(${ANAKIN_INCLUDE}/saber/)
include_directories(${ANAKIN_INCLUDE}/saber/core/)
include_directories(${ANAKIN_INCLUDE}/saber/funcs/impl/x86/)
include_directories(${ANAKIN_INCLUDE}/saber/funcs/impl/cuda/base/cuda_c/)
set(ANAKIN_COMPILE_EXTRA_FLAGS
-Wno-error=unused-but-set-variable -Wno-unused-but-set-variable
-Wno-error=unused-variable -Wno-unused-variable
-Wno-error=format-extra-args -Wno-format-extra-args
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=maybe-uninitialized -Wno-maybe-uninitialized
-Wno-error=switch -Wno-switch
-Wno-error=return-type -Wno-return-type
-Wno-error=non-virtual-dtor -Wno-non-virtual-dtor
-Wno-error=ignored-qualifiers
-Wno-ignored-qualifiers
-Wno-sign-compare
-Wno-reorder
-Wno-error=cpp)
......@@ -38,7 +54,7 @@ ExternalProject_Add(
DEPENDS ${MKLML_PROJECT}
# Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code.
GIT_REPOSITORY "https://github.com/luotao1/Anakin"
GIT_TAG "bcf17aabe7921ceb7bce591244b4f9dce7dba5c8"
GIT_TAG "211d1fc5d813d70c0c14072f9083cf25f40940ea"
PREFIX ${ANAKIN_SOURCE_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DUSE_GPU_PLACE=YES
......@@ -48,6 +64,7 @@ ExternalProject_Add(
-DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml
-DCUDNN_ROOT=${CUDNN_ROOT}
-DCUDNN_INCLUDE_DIR=${CUDNN_INCLUDE_DIR}
-DENABLE_OP_TIMER=${ANAKIN_ENABLE_OP_TIMER}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR}
)
......
......@@ -65,7 +65,7 @@ endif()
if (WITH_ANAKIN AND WITH_GPU) # only needed in CI
# compile the libinference_anakin_api.a and anakin.so.
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber)
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber mklml)
cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber)
function(anakin_target target_name)
target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
......@@ -73,9 +73,12 @@ if (WITH_ANAKIN AND WITH_GPU) # only needed in CI
anakin_target(inference_anakin_api)
anakin_target(inference_anakin_api_shared)
if (WITH_TESTING)
cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc
cc_test(api_anakin_engine_tester SRCS api_anakin_engine_tester.cc
ARGS --model=${ANAKIN_SOURCE_DIR}/mobilenet_v2.anakin.bin
DEPS inference_anakin_api dynload_cuda SERIAL)
target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
DEPS inference_anakin_api_shared dynload_cuda SERIAL)
cc_test(api_anakin_engine_rnn_tester SRCS api_anakin_engine_rnn_tester.cc
ARGS --model=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin
--datapath=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn_data.txt
DEPS inference_anakin_api_shared dynload_cuda SERIAL)
endif(WITH_TESTING)
endif()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -13,9 +13,22 @@
// limitations under the License.
#include "paddle/fluid/inference/api/api_anakin_engine.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif
#include <mkl_service.h>
#include <omp.h>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "framework/core/net/net.h"
#include "framework/operators/ops.h"
#include "saber/funcs/timer.h"
namespace paddle {
template <typename Target>
......@@ -23,16 +36,24 @@ PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
const AnakinConfig &config) {
CHECK(Init(config));
}
template <>
PaddleInferenceAnakinPredictor<anakin::X86>::PaddleInferenceAnakinPredictor(
const AnakinConfig &config) {
omp_set_dynamic(0);
omp_set_num_threads(1);
mkl_set_num_threads(1);
CHECK(Init(config));
}
template <typename Target>
bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
if (!(graph_.load(config.model_file))) {
LOG(FATAL) << "fail to load graph from " << config.model_file;
VLOG(3) << "fail to load graph from " << config.model_file;
return false;
}
auto inputs = graph_.get_ins();
for (auto &input_str : inputs) {
graph_.ResetBatchSize(input_str, config.max_batch_size);
max_batch_size_ = config.max_batch_size;
}
// optimization for graph
if (!(graph_.Optimize())) {
......@@ -52,15 +73,15 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
std::vector<PaddleTensor> *output_data, int batch_size) {
for (const auto &input : inputs) {
if (input.dtype != PaddleDType::FLOAT32) {
LOG(ERROR) << "Only support float type inputs. " << input.name
<< "'s type is not float";
VLOG(3) << "Only support float type inputs. " << input.name
<< "'s type is not float";
return false;
}
auto d_tensor_in_p = executor_p_->get_in(input.name);
auto net_shape = d_tensor_in_p->valid_shape();
auto net_shape = d_tensor_in_p->shape();
if (net_shape.size() != input.shape.size()) {
LOG(ERROR) << " input " << input.name
<< "'s shape size should be equal to that of net";
VLOG(3) << " input " << input.name
<< "'s shape size should be equal to that of net";
return false;
}
int sum = 1;
......@@ -79,21 +100,45 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
}
d_tensor_in_p->reshape(tmp_shape);
if (input.lod.size() > 0) {
if (input.lod.size() > 1) {
VLOG(3) << " input lod first dim should <=1, but you set "
<< input.lod.size();
return false;
}
std::vector<int> offset(input.lod[0].begin(), input.lod[0].end());
d_tensor_in_p->set_seq_offset(offset);
VLOG(3) << "offset.size(): " << offset.size();
for (int i = 0; i < offset.size(); i++) {
VLOG(3) << offset[i];
}
}
float *d_data_p = d_tensor_in_p->mutable_data();
if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
d_tensor_in_p->valid_size() * sizeof(float),
cudaMemcpyHostToDevice) != 0) {
LOG(ERROR) << "copy data from CPU to GPU error";
return false;
#ifdef PADDLE_WITH_CUDA
if (std::is_same<anakin::NV, Target>::value) {
if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
d_tensor_in_p->valid_size() * sizeof(float),
cudaMemcpyHostToDevice) != 0) {
VLOG(3) << "copy data from CPU to GPU error";
return false;
}
}
#endif
if (std::is_same<anakin::X86, Target>::value) {
memcpy(d_data_p, static_cast<float *>(input.data.data()),
d_tensor_in_p->valid_size() * sizeof(float));
}
cudaStreamSynchronize(NULL);
}
#ifdef PADDLE_WITH_CUDA
cudaDeviceSynchronize();
executor_p_->prediction();
cudaDeviceSynchronize();
#endif
if (output_data->empty()) {
LOG(ERROR) << "At least one output should be set with tensors' names.";
VLOG(3) << "At least one output should be set with tensors' names.";
return false;
}
for (auto &output : *output_data) {
......@@ -102,14 +147,22 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
if (output.data.length() < tensor->valid_size() * sizeof(float)) {
output.data.Resize(tensor->valid_size() * sizeof(float));
}
// Copy data from GPU -> CPU
if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
tensor->valid_size() * sizeof(float),
cudaMemcpyDeviceToHost) != 0) {
LOG(ERROR) << "copy data from GPU to CPU error";
return false;
#if PADDLE_WITH_CUDA
if (std::is_same<anakin::NV, Target>::value) {
// Copy data from GPU -> CPU
if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
tensor->valid_size() * sizeof(float),
cudaMemcpyDeviceToHost) != 0) {
VLOG(3) << "copy data from GPU to CPU error";
return false;
}
}
#endif
if (std::is_same<anakin::X86, Target>::value) {
memcpy(output.data.data(), tensor->mutable_data(),
tensor->valid_size() * sizeof(float));
}
cudaStreamSynchronize(NULL);
}
return true;
}
......@@ -132,7 +185,7 @@ PaddleInferenceAnakinPredictor<Target>::Clone() {
auto anakin_predictor_p =
dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
if (!anakin_predictor_p) {
LOG(ERROR) << "fail to call Init";
VLOG(3) << "fail to call Init";
return nullptr;
}
anakin_predictor_p->get_executer().init(graph_);
......@@ -162,6 +215,44 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
VLOG(3) << "Anakin Predictor create on unknown platform.";
return nullptr;
}
};
}
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
template <typename Target>
using executor_t =
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>;
template <typename Target>
void DisplayOpTimer(executor_t<Target> *net_executor, int epoch) {
std::vector<float> op_time = net_executor->get_op_time();
auto exec_funcs = net_executor->get_exec_funcs();
auto op_param = net_executor->get_op_param();
for (int i = 0; i < op_time.size(); i++) {
LOG(INFO) << "name: " << exec_funcs[i].name
<< " op_type: " << exec_funcs[i].op_name
<< " op_param: " << op_param[i] << " time " << op_time[i] / epoch;
}
std::map<std::string, float> op_map;
for (int i = 0; i < op_time.size(); i++) {
auto it = op_map.find(op_param[i]);
if (it != op_map.end())
op_map[op_param[i]] += op_time[i];
else
op_map.insert(std::pair<std::string, float>(op_param[i], op_time[i]));
}
for (auto it = op_map.begin(); it != op_map.end(); ++it) {
LOG(INFO) << it->first << " " << (it->second) / epoch << " ms";
}
}
#endif
template <typename Target>
PaddleInferenceAnakinPredictor<Target>::~PaddleInferenceAnakinPredictor() {
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
DisplayOpTimer<Target>(executor_p_, max_batch_size_);
#endif
delete executor_p_;
executor_p_ = nullptr;
}
} // namespace paddle
......@@ -47,10 +47,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>&
get_executer();
~PaddleInferenceAnakinPredictor() override {
delete executor_p_;
executor_p_ = nullptr;
};
~PaddleInferenceAnakinPredictor() override;
private:
bool Init(const AnakinConfig& config);
......@@ -60,6 +57,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>*
executor_p_{nullptr};
AnakinConfig config_;
int max_batch_size_{0};
};
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <thread> // NOLINT
#include <vector>
#include "framework/core/net/net.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
DEFINE_string(model, "", "Directory of the inference model.");
DEFINE_string(datapath, "", "Path of the dataset.");
DEFINE_int32(batch_size, 1, "batch size.");
DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
// Timer for timer
class Timer {
public:
double start;
double startu;
void tic() {
struct timeval tp;
gettimeofday(&tp, NULL);
start = tp.tv_sec;
startu = tp.tv_usec;
}
double toc() {
struct timeval tp;
gettimeofday(&tp, NULL);
double used_time_ms =
(tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0;
return used_time_ms;
}
};
std::vector<std::string> string_split(std::string in_str,
std::string delimiter) {
std::vector<std::string> seq;
int found = in_str.find(delimiter);
int pre_found = -1;
while (found != std::string::npos) {
if (pre_found == -1) {
seq.push_back(in_str.substr(0, found));
} else {
seq.push_back(in_str.substr(pre_found + delimiter.length(),
found - delimiter.length() - pre_found));
}
pre_found = found;
found = in_str.find(delimiter, pre_found + delimiter.length());
}
seq.push_back(
in_str.substr(pre_found + 1, in_str.length() - (pre_found + 1)));
return seq;
}
std::vector<std::string> string_split(
std::string in_str, std::vector<std::string>& delimiter) { // NOLINT
std::vector<std::string> in;
std::vector<std::string> out;
out.push_back(in_str);
for (auto del : delimiter) {
in = out;
out.clear();
for (auto s : in) {
auto out_s = string_split(s, del);
for (auto o : out_s) {
out.push_back(o);
}
}
}
return out;
}
class Data {
public:
Data(std::string file_name, int batch_size)
: _batch_size(batch_size), _total_length(0) {
_file.open(file_name);
_file.seekg(_file.end);
_total_length = _file.tellg();
_file.seekg(_file.beg);
}
void get_batch_data(std::vector<std::vector<float>>& fea, // NOLINT
std::vector<std::vector<float>>& week_fea, // NOLINT
std::vector<std::vector<float>>& time_fea, // NOLINT
std::vector<long unsigned int>& seq_offset); // NOLINT
private:
std::fstream _file;
int _total_length;
int _batch_size;
};
void Data::get_batch_data(
std::vector<std::vector<float>>& fea, // NOLINT
std::vector<std::vector<float>>& week_fea, // NOLINT
std::vector<std::vector<float>>& time_fea, // NOLINT
std::vector<long unsigned int>& seq_offset) { // NOLINT
int seq_num = 0;
long unsigned int cum = 0; // NOLINT
char buf[10000];
seq_offset.clear();
seq_offset.push_back(0);
fea.clear();
week_fea.clear();
time_fea.clear();
while (_file.getline(buf, 10000)) {
std::string s = buf;
std::vector<std::string> deli_vec = {":"};
std::vector<std::string> data_vec = string_split(s, deli_vec);
std::vector<std::string> seq;
seq = string_split(data_vec[0], {"|"});
for (auto link : seq) {
std::vector<std::string> data = string_split(link, ",");
std::vector<float> vec;
for (int i = 0; i < data.size(); i++) {
vec.push_back(atof(data[i].c_str()));
}
fea.push_back(vec);
}
std::vector<std::string> week_data;
std::vector<std::string> time_data;
week_data = string_split(data_vec[2], ",");
std::vector<float> vec_w;
for (int i = 0; i < week_data.size(); i++) {
vec_w.push_back(atof(week_data[i].c_str()));
}
week_fea.push_back(vec_w);
time_data = string_split(data_vec[1], ",");
std::vector<float> vec_t;
for (int i = 0; i < time_data.size(); i++) {
vec_t.push_back(atof(time_data[i].c_str()));
}
time_fea.push_back(vec_t);
cum += seq.size();
seq_offset.push_back(cum);
seq_num++;
if (seq_num >= _batch_size) {
break;
}
}
}
namespace paddle {
AnakinConfig GetConfig() {
AnakinConfig config;
// using AnakinConfig::X86 if you need to use cpu to do inference
config.target_type = AnakinConfig::X86;
config.model_file = FLAGS_model;
config.device = 0;
config.max_batch_size = 1000; // the max number of token
return config;
}
void set_tensor(std::string name, std::vector<int> shape,
std::vector<PaddleTensor>& vec) { // NOLINT
int sum = 1;
std::for_each(shape.begin(), shape.end(), [&](int n) { sum *= n; });
float* data = new float[sum];
PaddleTensor tensor;
tensor.name = name;
tensor.shape = shape;
tensor.data = PaddleBuf(data, sum);
tensor.dtype = PaddleDType::FLOAT32;
vec.push_back(tensor);
}
void single_test() {
AnakinConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(config);
int max_batch_size = 1000;
std::string feature_file = FLAGS_datapath;
Data map_data(feature_file, FLAGS_batch_size);
std::vector<std::vector<float>> fea;
std::vector<std::vector<float>> week_fea;
std::vector<std::vector<float>> time_fea;
std::vector<long unsigned int> seq_offset; // NOLINT
paddle::PaddleTensor tensor_0, tensor_1, tensor_2;
tensor_0.name = "input_0";
tensor_1.name = "input_4";
tensor_2.name = "input_5";
PaddleTensor tensor_out;
tensor_out.name = "final_output.tmp_1_gout";
tensor_out.shape = std::vector<int>({});
tensor_out.data = PaddleBuf();
tensor_out.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> inputs;
std::vector<PaddleTensor> outputs(1, tensor_out);
int data_0_dim = 38;
int data_1_dim = 10;
int data_2_dim = 10;
float data_0[max_batch_size * data_0_dim]; // NOLINT
float data_1[max_batch_size * data_1_dim]; // NOLINT
float data_2[max_batch_size * data_2_dim]; // NOLINT
int count = 0;
while (true) {
if (count++ > 0) break; // only run the first batch in ci.
seq_offset.clear();
map_data.get_batch_data(fea, week_fea, time_fea, seq_offset);
if (seq_offset.size() <= 1) {
LOG(FATAL) << "seq_offset.size() <= 1, exit.";
break;
}
std::vector<std::vector<long unsigned int>> seq_offset_vec; // NOLINT
seq_offset_vec.push_back(seq_offset);
tensor_0.lod = seq_offset_vec;
int p_shape_0[] = {(int)fea.size(), 1, 1, data_0_dim}; // NOLINT
int p_shape_1[] = {(int)week_fea.size(), data_1_dim, 1, 1}; // NOLINT
int p_shape_2[] = {(int)time_fea.size(), data_2_dim, 1, 1}; // NOLINT
std::vector<int> shape_0(p_shape_0, p_shape_0 + 4);
std::vector<int> shape_1(p_shape_1, p_shape_1 + 4);
std::vector<int> shape_2(p_shape_2, p_shape_2 + 4);
tensor_0.shape = shape_0;
tensor_1.shape = shape_1;
tensor_2.shape = shape_2;
for (int i = 0; i < fea.size(); i++) {
memcpy(data_0 + i * data_0_dim, &fea[i][0], sizeof(float) * data_0_dim);
}
for (int i = 0; i < week_fea.size(); i++) {
memcpy(data_1 + i * data_1_dim, &week_fea[i][0],
sizeof(float) * data_1_dim);
}
for (int i = 0; i < time_fea.size(); i++) {
memcpy(data_2 + i * data_2_dim, &time_fea[i][0],
sizeof(float) * data_2_dim);
}
tensor_0.data =
paddle::PaddleBuf(data_0, fea.size() * sizeof(float) * data_0_dim);
tensor_1.data =
paddle::PaddleBuf(data_1, week_fea.size() * sizeof(float) * data_1_dim);
tensor_2.data =
paddle::PaddleBuf(data_2, time_fea.size() * sizeof(float) * data_2_dim);
tensor_0.dtype = paddle::PaddleDType::FLOAT32;
tensor_1.dtype = paddle::PaddleDType::FLOAT32;
tensor_2.dtype = paddle::PaddleDType::FLOAT32;
inputs.clear();
inputs.push_back(tensor_1);
inputs.push_back(tensor_2);
inputs.push_back(tensor_0);
Timer timer;
timer.tic();
for (int i = 0; i < FLAGS_repeat; i++) predictor->Run(inputs, &outputs);
LOG(INFO) << "batch_size = " << FLAGS_batch_size
<< ", repeat = " << FLAGS_repeat
<< ", sequence_length = " << seq_offset[seq_offset.size() - 1]
<< ", latency: " << timer.toc() / FLAGS_repeat << "ms";
float* data_o = static_cast<float*>(outputs[0].data.data());
VLOG(3) << "outputs[0].data.length() = " << outputs[0].data.length();
for (size_t j = 0; j < outputs[0].data.length(); ++j) {
VLOG(3) << "output[" << j << "]: " << data_o[j];
}
}
}
} // namespace paddle
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
logger::init(argv[0]);
paddle::single_test();
/* multi-threads
std::vector<std::thread> threads;
int num = 1;
for (int i = 0; i < num; i++) {
LOG(INFO) << " thread id : " << i;
threads.emplace_back(paddle::single_test);
}
for (int i = 0; i < num; i++) {
threads[i].join();
}
threads.clear();
*/
return 0;
}
......@@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
explicit PaddleBuf(size_t length)
PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
......@@ -70,7 +70,7 @@ struct PaddleTensor {
std::vector<int> shape;
PaddleBuf data; // blob of data.
PaddleDType dtype;
std::vector<std::vector<uint64_t>> lod; // lod data
std::vector<std::vector<size_t>> lod; // Tensor+LoD equals LoDTensor
};
enum class PaddleEngineKind {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册