未验证 提交 7d17f2ef 编写于 作者: R Ray Liu 提交者: GitHub

Merge branch 'develop' into backup

......@@ -70,10 +70,11 @@ void format_fp16_ofm(framework::Tensor *ofm_tensor) {
DLOG << "Wrong ofm dimension";
}
auto p = fpga_malloc(memory_size);
memset(p, 0, memory_size);
// memset(p, 0, memory_size);
ofm_tensor->reset_data_ptr(p);
ofm_tensor->set_type(typeid(half));
ofm_tensor->fpga_data_num = memory_size / sizeof(half);
fpga::fpga_flush(p, memory_size);
}
void format_fp16_ofm(framework::Tensor *ofm_tensor, framework::DDim dims) {
......@@ -89,10 +90,11 @@ void format_fp16_ofm(framework::Tensor *ofm_tensor, framework::DDim dims) {
DLOG << "Wrong ofm dimension";
}
auto p = fpga_malloc(memory_size);
memset(p, 0, memory_size);
// memset(p, 0, memory_size);
ofm_tensor->reset_data_ptr(p);
ofm_tensor->set_type(typeid(half));
ofm_tensor->fpga_data_num = memory_size / sizeof(half);
fpga::fpga_flush(p, memory_size);
}
void format_fp32_ofm(framework::Tensor *ofm_tensor) {
......@@ -108,10 +110,11 @@ void format_fp32_ofm(framework::Tensor *ofm_tensor) {
DLOG << "Wrong ofm dimension";
}
auto p = fpga_malloc(memory_size);
memset(p, 0, memory_size);
// memset(p, 0, memory_size);
ofm_tensor->reset_data_ptr(p);
ofm_tensor->set_type(typeid(float));
ofm_tensor->fpga_data_num = memory_size / sizeof(float);
fpga::fpga_flush(p, memory_size);
}
float filter_find_max(framework::Tensor *filter_tensor) {
......@@ -463,9 +466,24 @@ void expand_EW_arg(EWAddArgs *arg) {
uint64_t image_amount_per_row =
align_to_x((uint64_t)args.image0.width * (uint64_t)args.image0.channels,
IMAGE_ALIGNMENT);
uint64_t image_image_pixel = ((uint64_t)args.image0.channels << 32) |
((uint64_t)args.image0.width << 16) |
(uint64_t)args.image0.height;
//////////////////////////////////////////////////////////
// temporary modify for EW and DMA problem
uint64_t image_image_pixel = 0;
if ((args.image0.width * args.image0.channels) >= 24576) {
if ((args.image0.width * args.image0.channels) % 32 != 0) {
DLOG << "EW parameter can not be support";
} else {
image_amount_per_row = image_amount_per_row / 2;
image_image_pixel = ((uint64_t)args.image0.channels << 32) |
((uint64_t)(args.image0.width / 2) << 16) |
(uint64_t)(args.image0.height * 2);
}
} else {
image_image_pixel = ((uint64_t)args.image0.channels << 32) |
((uint64_t)args.image0.width << 16) |
(uint64_t)args.image0.height;
}
//////////////////////////////////////////////////////////
(*arg).driver.image0_address_phy = image0_address_phy;
(*arg).driver.image1_address_phy = image1_address_phy;
......@@ -560,6 +578,18 @@ void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
reinterpret_cast<char *>(arg->conv_arg[i].filter_address), deleter));
memcpy(arg->conv_arg[i].filter_address, filter_head, filter_size);
fpga_flush(arg->conv_arg[i].filter_address, filter_size);
// for test
// {
// static int cnt = 0;
// if(cnt == 4){
// int8_t result = 0;
// std::string str = "fc_filter";
// fpga::savefile<int8_t>(str, arg->conv_arg[i].filter_address,
// filter_size, result);
//
// }
// cnt++;
//}
size_t bs_size = 2 *
align_to_x(arg->conv_arg[i].filter_num, BS_NUM_ALIGNMENT) *
......@@ -570,6 +600,18 @@ void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
reinterpret_cast<char *>(arg->conv_arg[i].sb_address), deleter));
memcpy(arg->conv_arg[i].sb_address, bs_head, bs_size);
fpga_flush(arg->conv_arg[i].sb_address, bs_size);
// for test
/*{
static int cnt = 0;
if(cnt == 4){
float result = 0;
std::string str = "fc_bs";
fpga::savefile<float>(str, arg->conv_arg[i].sb_address, bs_size/4,
result);
}
cnt++;
}*/
if (n > 1) {
arg->conv_arg[i].output.scale_address =
......
......@@ -268,6 +268,7 @@ void format_fc_filter(float **data_in, int num, int channel, int height,
quantize(data_in, data_size, max);
char **quantize_data = (char **)data_in; // NOLINT
convert_fc_filter(quantize_data, num, chw);
convert_to_hwc(quantize_data, num, channel, height, width);
align_element(quantize_data, num, chw);
if (num_after_alignment != num) {
align_num(quantize_data, num_per_div_before_alignment, num, chw);
......@@ -316,7 +317,7 @@ void align_element_n(int16_t **data_in, int num, int height, int width) {
}
*data_in = data_tmp;
free(tmp);
fpga_free(tmp);
}
}
void quantize_to_fp16(float **data_in, int num, int height, int width,
......
......@@ -90,11 +90,6 @@ Executor<Device, T>::Executor(const Program<Device> &program,
InitMemory();
}
#ifdef PADDLE_MOBILE_FPGA
program_.scope->EraseVars({"feed", "fetch"});
program_.scope->print_vars();
#endif
int count = 0;
for (auto &op_handler : ops_of_block0_) {
DLOG << "Initialize op[" << count++ << "]: " << op_handler->Type();
......@@ -514,6 +509,32 @@ PMStatus Executor<Device, T>::Predict() {
return PMSuccess;
}
template <typename Device, typename T>
void Executor<Device, T>::FeedTensorData(const vector<framework::Tensor> &v) {
auto input_size = v.size();
auto *feed_var = program_.scope->Var("feed");
PADDLE_MOBILE_ENFORCE(input_size == feed_indices_.size(),
"input data number not correct");
for (int i = 0; i < input_size; i++) {
framework::LoDTensor &target =
feed_var->template GetMutable<framework::LoDTensorArray>()->at(i);
target.ShareDataWith(v[input_size - i - 1]);
}
}
template <typename Device, typename T>
void Executor<Device, T>::GetTensorResults(
std::vector<framework::Tensor *> *v) {
auto *fetch_var = program_.scope->Var("fetch");
auto output_size = fetch_indices_.size();
for (int i = 0; i < output_size; i++) {
framework::LoDTensor &target =
fetch_var->template GetMutable<framework::LoDTensorArray>()->at(i);
v->push_back(&target);
}
}
#ifdef PADDLE_MOBILE_FPGA
template <typename Device, typename T>
void Executor<Device, T>::InjectVariable(const Tensor &t,
......@@ -559,19 +580,6 @@ void Executor<Device, T>::GetResults(std::vector<void *> *v) {
}
}
template <typename Device, typename T>
void Executor<Device, T>::GetTensorResults(
std::vector<framework::Tensor *> *v) {
int index = 0;
auto vars = program_.scope->VarContain("fetch", &index);
auto output_size = vars.size();
for (int i = 0; i < output_size; i++) {
auto var = program_.scope->Var("fetch", i + index);
auto fetch_tensor = var->template GetMutable<LoDTensor>();
v->push_back(fetch_tensor);
}
}
template <typename Device, typename T>
framework::Tensor *Executor<Device, T>::GetTensorByName(
const std::string &name) {
......
......@@ -51,15 +51,15 @@ class Executor {
std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name);
void FeedTensorData(const std::vector<framework::Tensor> &v);
void GetTensorResults(std::vector<framework::Tensor *> *v);
#ifdef PADDLE_MOBILE_FPGA
void InjectVariable(const Tensor &t, std::string var_name);
void FeedData(const Tensor &t);
void FeedData(const std::vector<void *> &v);
void GetResults(std::vector<void *> *v);
void GetTensorResults(std::vector<framework::Tensor *> *v);
framework::Tensor *GetTensorByName(const std::string &name);
std::shared_ptr<Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
......
......@@ -50,9 +50,6 @@ OperatorBase<Dtype>::OperatorBase(const std::string &type,
attrs_(attrs),
scope_(scope) {
CheckAllInputOutputSet();
#ifdef PADDLE_MOBILE_FPGA
InsertTensors();
#endif
}
template <typename Dtype>
......@@ -72,6 +69,9 @@ void OperatorBase<Dtype>::Run() {
var->template IsType<framework::LoDTensor>()) {
const Tensor *tensor = var->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
#ifdef PADDLE_MOBILE_FPGA
DLOG << var_vec_in[i];
#endif
}
}
}
......@@ -83,6 +83,9 @@ void OperatorBase<Dtype>::Run() {
var->template IsType<framework::LoDTensor>()) {
const Tensor *tensor = var->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
#ifdef PADDLE_MOBILE_FPGA
DLOG << var_vec_out[i];
#endif
}
}
}
......
......@@ -146,7 +146,7 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
tensors[i].init(typeid(float));
ConvertPaddleTensors(inputs[i], &tensors[i]);
}
// paddle_mobile_->FeedTensorData(tensors);
paddle_mobile_->FeedTensorData(tensors);
}
template <typename Device, typename T>
......
......@@ -236,6 +236,11 @@ template <typename Device, typename T>
void PaddleMobile<Device, T>::FeedData(const std::vector<void *> &v) {
executor_->FeedData(v);
}
template <typename Device, typename T>
void PaddleMobile<Device, T>::FeedTensorData(
const std::vector<framework::Tensor> &v) {
executor_->FeedTensorData(v);
}
template <typename Device, typename T>
void PaddleMobile<Device, T>::GetResults(std::vector<void *> *v) {
......
......@@ -91,6 +91,7 @@ class PaddleMobile {
void InjectVariable(const framework::Tensor &t, std::string var_name);
void FeedData(const framework::Tensor &t);
void FeedData(const std::vector<void *> &v);
void FeedTensorData(const std::vector<framework::Tensor> &v);
void GetResults(std::vector<void *> *v);
void GetTensorResults(std::vector<framework::Tensor *> *v);
......
......@@ -21,6 +21,7 @@ template <>
bool FeedKernel<FPGA, float>::Init(FeedParam<FPGA> *param) {
auto output = param->Out();
int col = param->Col();
DLOG << "col = " << col;
auto input = const_cast<LoDTensor *>(&param->InputX()->at(col));
input->init(typeid(float));
input->Resize(output->dims());
......
......@@ -19,6 +19,7 @@ template <>
bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) {
auto input = const_cast<LoDTensor *>(param->InputX());
int col = param->Col();
DLOG << "col = " << col;
auto output = &(param->Out()->at(col));
if (input->type() == typeid(float)) {
return true;
......@@ -59,7 +60,11 @@ template <>
void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
auto input = const_cast<LoDTensor *>(param.InputX());
int col = param.Col();
LoDTensor *out = &param.Out()->at(col);
auto output = &param.Out()->at(col);
if (input->type() == typeid(float)) {
output->ShareDataWith(*input);
return;
}
fpga::BypassArgs args = param.fpga_bypass_args;
auto input_address = (input->data<half>());
......@@ -67,7 +72,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
float *outdata_ptr =
reinterpret_cast<float *>(param.fpga_bypass_args.output.address);
const int num_th = 32;
if ((out->fpga_data_num) < num_th) {
if (output->fpga_data_num < num_th) {
fpga::fpga_invalidate(input_address, (input->fpga_data_num) * sizeof(half));
for (int idx = 0; idx < product(input->dims()); ++idx) {
......@@ -77,14 +82,14 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
}
fpga::PerformBypass(args);
auto outC = out->dims()[1];
auto outH = out->dims()[2];
auto outW = out->dims()[3];
auto outC = output->dims()[1];
auto outH = output->dims()[2];
auto outW = output->dims()[3];
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
out->fpga_data_num * sizeof(float));
output->fpga_data_num * sizeof(float));
if (out->fpga_data_num != product(input->dims())) {
if (output->fpga_data_num != product(input->dims())) {
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
......@@ -92,7 +97,6 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
free(data_tmp);
}
}
template class FetchKernel<FPGA, float>;
} // namespace operators
......
......@@ -68,23 +68,38 @@ endif ()
list(FIND NET "FPGA_NET_V1" CON)
if (CON GREATER -1)
ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet50 paddle-mobile)
#ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-resnet50 paddle-mobile)
ADD_EXECUTABLE(test-densebox fpga/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-densebox paddle-mobile)
#ADD_EXECUTABLE(test-densebox fpga/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-densebox paddle-mobile)
ADD_EXECUTABLE(test-rfcn fpga/test_rfcn.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-rfcn paddle-mobile)
#ADD_EXECUTABLE(test-rfcn fpga/test_rfcn.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-rfcn paddle-mobile)
ADD_EXECUTABLE(test-marker fpga/test_marker.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-marker paddle-mobile)
#ADD_EXECUTABLE(test-marker fpga/test_marker.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-marker paddle-mobile)
ADD_EXECUTABLE(test-rfcn-api fpga/test_rfcn_api.cpp)
target_link_libraries(test-rfcn-api paddle-mobile)
ADD_EXECUTABLE(test-mobilenet-api fpga/test_mobilenet_api.cpp)
target_link_libraries(test-mobilenet-api paddle-mobile)
ADD_EXECUTABLE(test-yolo-api fpga/test_yolo_api.cpp)
target_link_libraries(test-yolo-api paddle-mobile)
ADD_EXECUTABLE(test-marker-api fpga/test_marker_api.cpp)
target_link_libraries(test-marker-api paddle-mobile)
ADD_EXECUTABLE(test-marker2 fpga/test_marker2.cpp test_helper.h test_include.h executor_for_test.h )
target_link_libraries(test-marker2 paddle-mobile)
#ADD_EXECUTABLE(test-marker2 fpga/test_marker2.cpp test_helper.h test_include.h executor_for_test.h )
#target_link_libraries(test-marker2 paddle-mobile)
#ADD_EXECUTABLE(test-mobilenet fpga/test_mobilenet_beijing.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-mobilenet paddle-mobile)
#ADD_EXECUTABLE(test-yolo fpga/test_yolo_combine.cpp test_helper.h test_include.h executor_for_test.h)
#target_link_libraries(test-yolo paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include <fstream>
#include <iostream>
#include "../../src/io/paddle_inference_api.h"
using namespace paddle_mobile;
using namespace paddle_mobile::fpga;
static const char *g_image = "../models/marker/model/image.bin";
static const char *g_model = "../models/marker/model/model";
static const char *g_param = "../models/marker/model/params";
static const char *g_image1 = "../models/marker2/model/marker.bin";
static const char *g_model1 = "../models/marker2/model/model";
static const char *g_param1 = "../models/marker2/model/params";
void readStream(std::string filename, char *buf) {
std::ifstream in;
in.open(filename, std::ios::in | std::ios::binary);
if (!in.is_open()) {
std::cout << "open File Failed." << std::endl;
return;
}
in.seekg(0, std::ios::end); // go to the end
auto length = in.tellg(); // report location (this is the length)
in.seekg(0, std::ios::beg); // go back to the beginning
in.read(buf, length);
in.close();
}
signed char float_to_int8(float fdata) {
if (fdata < 0.0) {
fdata -= 0.5;
} else {
fdata += 0.5;
}
return (signed char)fdata;
}
void quantize(float **data_in, int data_size) {
float *tmp = *data_in;
signed char *tmp_data =
(signed char *)paddle_mobile::fpga::fpga_malloc(data_size * sizeof(char));
for (int i = 0; i < data_size; i++) {
tmp_data[i] = float_to_int8((*data_in)[i] + 128);
}
*data_in = (float *)tmp_data; // NOLINT
paddle_mobile::fpga::fpga_free(tmp);
}
void convert_to_chw(float **data_in, int channel, int height, int width,
float *data_tmp) {
int64_t amount_per_side = width * height;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + c * amount_per_side + width * h + w) = *((*data_in)++);
}
}
}
}
void dump_stride_float(std::string filename,
paddle_mobile::PaddleTensor input_tensor) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.data.data());
int c = (input_tensor.shape)[1];
int h = (input_tensor.shape)[2];
int w = (input_tensor.shape)[3];
int n = (input_tensor.shape)[0];
float *data_tmp =
reinterpret_cast<float *>(malloc(c * h * w * sizeof(float)));
// convert_to_chw(&data_ptr, c, h, w, data_tmp);
std::ofstream out(filename.c_str());
float result = 0;
int datasize = abs(c * h * w * n);
if (datasize == 0) {
std::cout << "wrong dump data size" << std::endl;
return;
}
for (int i = 0; i < datasize; i++) {
result = data_ptr[i];
out << result << std::endl;
}
out.close();
}
void dump_stride(std::string filename,
paddle_mobile::PaddleTensor input_tensor) {
if (input_tensor.dtypeid == typeid(float)) {
dump_stride_float(filename, input_tensor);
} else {
std::cout << "only support dumping float data" << std::endl;
}
}
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model;
config.param_file = g_param;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
PaddleMobileConfig GetConfig1() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model1;
config.param_file = g_param1;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
int main() {
open_device();
PaddleMobileConfig config1 = GetConfig1();
auto predictor1 =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config1);
std::cout << "Finishing loading model" << std::endl;
for (int i = 0; i < 1; ++i) {
int img_length1 = 144 * 14 * 14;
auto img1 =
reinterpret_cast<float *>(fpga_malloc(img_length1 * sizeof(float)));
readStream(g_image1, reinterpret_cast<char *>(img1));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img1;
t_img1.dtypeid = typeid(float);
t_img1.layout = LAYOUT_HWC;
t_img1.shape = std::vector<int>({1, 14, 14, 144});
t_img1.name = "Image information";
t_img1.data.Reset(img1, img_length1 * sizeof(float));
predictor1->FeedPaddleTensors({t_img1});
std::cout << "Finishing feeding data " << std::endl;
predictor1->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<paddle_mobile::PaddleTensor> v1; // No need to initialize v
predictor1->FetchPaddleTensors(&v1); // Old data in v will be cleared
std::cout << "Output number is " << v1.size() << std::endl;
for (int fetchNum = 0; fetchNum < v1.size(); fetchNum++) {
std::string dumpName = "marker2_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v1[fetchNum]);
}
}
/////////////////////////////////////
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
std::cout << "Finishing loading model" << std::endl;
float img_info[3] = {432, 1280, 1.0f};
int img_length = 432 * 1280 * 3;
auto img = reinterpret_cast<float *>(fpga_malloc(img_length * sizeof(float)));
readStream(g_image, reinterpret_cast<char *>(img));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img_info, t_img;
t_img_info.dtypeid = typeid(float);
t_img_info.layout = LAYOUT_HWC;
t_img_info.shape = std::vector<int>({1, 3});
t_img_info.name = "Image information";
t_img_info.data.Reset(img_info, 3 * sizeof(float));
t_img.dtypeid = typeid(float);
// quantize(&img, img_length);
// t_img.dtypeid = typeid(int8_t);
t_img.layout = LAYOUT_HWC;
t_img.shape = std::vector<int>({1, 432, 1280, 3});
t_img.name = "Image information";
t_img.data.Reset(img, img_length * sizeof(float));
// t_img.data.Reset(img, img_length * sizeof(int8_t));
// for(int i = 0; i < 100; ++i){
predictor->FeedPaddleTensors({t_img_info, t_img});
std::cout << "Finishing feeding data " << std::endl;
predictor->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<paddle_mobile::PaddleTensor> v; // No need to initialize v
predictor->FetchPaddleTensors(&v); // Old data in v will be cleared
std::cout << "Output number is " << v.size() << std::endl;
for (int fetchNum = 0; fetchNum < v.size(); fetchNum++) {
std::string dumpName = "marker_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v[fetchNum]);
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include <fstream>
#include <iostream>
#include "../../src/io/paddle_inference_api.h"
using namespace paddle_mobile; // NOLINT
using namespace paddle_mobile::fpga; // NOLINT
static const char *g_image = "../images/mobilenet_txtdata/1.txt";
static const char *g_model = "../models/keycurve_l2_regular4_model/__model__";
static const char *g_param =
"../models/keycurve_l2_regular4_model/model.params";
void readStream(std::string filename, float *buf) {
std::ifstream in;
in.open(filename, std::ios::in);
if (!in.is_open()) {
std::cout << "open File Failed." << std::endl;
return;
}
int i = 0;
while (!in.eof()) {
in >> buf[i];
i++;
}
in.close();
}
signed char float_to_int8(float fdata) {
if (fdata < 0.0) {
fdata -= 0.5;
} else {
fdata += 0.5;
}
return (signed char)fdata;
}
void quantize(float **data_in, int data_size) {
float *tmp = *data_in;
signed char *tmp_data = (signed char *)fpga_malloc(data_size * sizeof(char));
for (int i = 0; i < data_size; i++) {
tmp_data[i] = float_to_int8((*data_in)[i] + 128);
}
*data_in = (float *)tmp_data; // NOLINT
fpga_free(tmp);
}
void convert_to_chw(float **data_in, int channel, int height, int width,
float *data_tmp) {
int64_t amount_per_side = width * height;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + c * amount_per_side + width * h + w) = *((*data_in)++);
}
}
}
}
void dump_stride_float(std::string filename, PaddleTensor input_tensor) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.data.data());
int c = (input_tensor.shape)[1];
int h = (input_tensor.shape)[2];
int w = (input_tensor.shape)[3];
int n = (input_tensor.shape)[0];
float *data_tmp =
reinterpret_cast<float *>(malloc(c * h * w * sizeof(float)));
convert_to_chw(&data_ptr, c, h, w, data_tmp);
std::ofstream out(filename.c_str());
float result = 0;
int datasize = abs(c * h * w * n);
if (datasize == 0) {
std::cout << "wrong dump data size" << std::endl;
return;
}
for (int i = 0; i < datasize; i++) {
result = data_tmp[i];
out << result << std::endl;
}
out.close();
}
void dump_stride(std::string filename, PaddleTensor input_tensor) {
if (input_tensor.dtypeid == typeid(float)) {
dump_stride_float(filename, input_tensor);
} else {
std::cout << "only support dumping float data" << std::endl;
}
}
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model;
config.param_file = g_param;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
int main() {
open_device();
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<paddle_mobile::PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
std::cout << "Finishing loading model" << std::endl;
int img_length = 256 * 416 * 3;
auto img = reinterpret_cast<float *>(fpga_malloc(img_length * sizeof(float)));
readStream(g_image, img);
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img;
t_img.dtype = FLOAT32;
t_img.dtypeid = typeid(float);
// quantize(&img, img_length);
// t_img.dtype = INT8;
// t_img.dtypeid = typeid(int8_t);
t_img.layout = LAYOUT_HWC;
t_img.shape = std::vector<int>({1, 256, 416, 3});
t_img.name = "Image information";
t_img.data.Reset(img, img_length * sizeof(float));
// t_img.data.Reset(img, img_length * sizeof(int8_t));
predictor->FeedPaddleTensors({t_img});
std::cout << "Finishing feeding data " << std::endl;
predictor->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<PaddleTensor> v; // No need to initialize v
predictor->FetchPaddleTensors(&v); // Old data in v will be cleared
std::cout << "Output number is " << v.size() << std::endl;
for (int fetchNum = 0; fetchNum < v.size(); fetchNum++) {
std::string dumpName = "mobilenet_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v[fetchNum]);
}
return 0;
}
......@@ -12,18 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include <fstream>
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#include "../../src/io/paddle_inference_api.h"
#ifdef PADDLE_MOBILE_FPGA_V1
#include "fpga/V1/api.h"
#endif
#ifdef PADDLE_MOBILE_FPGA_V2
#include "fpga/V2/api.h"
#endif
using namespace paddle_mobile;
using namespace paddle_mobile::fpga;
#include <string>
static const char *g_image = "../models/rfcn/data.bin";
static const char *g_model = "../models/rfcn/model";
static const char *g_param = "../models/rfcn/params";
void readStream(std::string filename, char *buf) {
std::ifstream in;
......@@ -37,116 +38,128 @@ void readStream(std::string filename, char *buf) {
auto length = in.tellg(); // report location (this is the length)
in.seekg(0, std::ios::beg); // go back to the beginning
in.read(buf, length);
DLOG << length;
in.close();
}
void convert_to_chw(int16_t **data_in, int channel, int height, int width,
int num, int16_t *data_tmp) {
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + n * amount_per_side * channel + c * amount_per_side +
width * h + w) = *((*data_in)++);
}
}
}
}
}
void dump_stride_half(std::string filename, Tensor input_tensor,
const int dumpnum, bool use_chw) {
// bool use_chw = true;
if (input_tensor.dims().size() != 4) return;
int c = (input_tensor.dims())[1];
int h = (input_tensor.dims())[2];
int w = (input_tensor.dims())[3];
int n = (input_tensor.dims())[0];
auto data_ptr = input_tensor.get_data();
auto *data_ptr_16 = reinterpret_cast<half *>(data_ptr);
auto data_tmp = data_ptr_16;
if (use_chw) {
data_tmp =
reinterpret_cast<half *>(malloc(n * c * h * w * sizeof(int16_t)));
convert_to_chw(&data_ptr_16, c, h, w, n, data_tmp);
}
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
out << result << std::endl;
}
out.close();
if (data_tmp != data_ptr_16) {
free(data_tmp);
}
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model;
config.param_file = g_param;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
void dump_stride_float(std::string filename, Tensor input_tensor,
const int dumpnum) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.get_data());
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = data_ptr[i];
out << result << std::endl;
}
out.close();
PaddleMobileConfig GetConfig1() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.model_dir = "../models/resnet50";
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.quantification = false;
return config;
}
void dump_stride(std::string filename, Tensor input_tensor, const int dumpnum,
bool use_chw) {
static int i = 0;
if (input_tensor.numel() == 0) {
return;
int main() {
open_device();
PaddleMobileConfig config1 = GetConfig1();
auto predictor1 =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config1);
std::cout << "Finishing loading model" << std::endl;
int img_length1 = 224 * 224 * 3;
auto img1 =
reinterpret_cast<float *>(fpga_malloc(img_length1 * sizeof(float)));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img1;
t_img1.dtypeid = typeid(float);
t_img1.layout = LAYOUT_HWC;
t_img1.shape = std::vector<int>({1, 224, 224, 3});
t_img1.name = "Image information";
t_img1.data.Reset(img1, img_length1 * sizeof(float));
predictor1->FeedPaddleTensors({t_img1});
predictor1->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<PaddleTensor> v1; // No need to initialize v
predictor1->FetchPaddleTensors(&v1); // Old data in v will be cleared
std::cout << "Output number is " << v1.size() << std::endl;
std::cout << "out[0] length " << v1[0].data.length() << std::endl;
////////////////////////////
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
std::cout << "Finishing loading model" << std::endl;
float img_info[3] = {432, 1280, 1.0f};
int img_length = 432 * 1280 * 3;
auto img = reinterpret_cast<float *>(fpga_malloc(img_length * sizeof(float)));
readStream(g_image, reinterpret_cast<char *>(img));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img_info, t_img;
t_img.dtypeid = typeid(float);
t_img_info.layout = LAYOUT_HWC;
t_img_info.shape = std::vector<int>({1, 3});
t_img_info.name = "Image information";
t_img_info.data.Reset(img_info, 3 * sizeof(float));
t_img.dtypeid = typeid(float);
t_img.layout = LAYOUT_HWC;
t_img.shape = std::vector<int>({1, 432, 1280, 3});
t_img.name = "Image information";
t_img.data.Reset(img, img_length * sizeof(float));
predictor->FeedPaddleTensors({t_img_info, t_img});
std::cout << "Finishing feeding data " << std::endl;
predictor->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<PaddleTensor> v; // No need to initialize v
predictor->FetchPaddleTensors(&v); // Old data in v will be cleared
std::cout << "Output number is " << v.size() << std::endl;
std::cout << "out[0] length " << v[0].data.length() << std::endl;
std::cout << "out[1] length " << v[1].data.length() << std::endl;
std::cout << "out[2] length " << v[2].data.length() << std::endl;
auto post_nms = v[0].data.length() / sizeof(float) / 8;
for (int num = 0; num < post_nms; num++) {
for (int i = 0; i < 8; i++) {
auto p = reinterpret_cast<float *>(v[0].data.data());
std::cout << p[num * 8 + i] << std::endl;
}
}
if (input_tensor.type() == typeid(float)) {
DLOG << "op: " << i++ << ", float data " << input_tensor.numel();
dump_stride_float(filename, input_tensor, dumpnum);
} else {
DLOG << "op: " << i++ << ", half data " << input_tensor.numel();
dump_stride_half(filename, input_tensor, dumpnum, use_chw);
for (int num = 0; num < post_nms; num++) {
for (int i = 0; i < 8; i++) {
auto p = reinterpret_cast<float *>(v[1].data.data());
std::cout << p[num * 8 + i] << std::endl;
}
}
DLOG << "dump input address: " << input_tensor.get_data();
}
static const char *g_rfcn_combine = "../models/rfcn";
static const char *g_image_src_float = "../models/rfcn/data.bin";
int main() {
paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
if (paddle_mobile.Load(std::string(g_rfcn_combine) + "/model",
std::string(g_rfcn_combine) + "/params", true, false,
1, true)) {
float img_info[3] = {768, 1536, 768.0f / 960.0f};
auto img = reinterpret_cast<float *>(
fpga::fpga_malloc(768 * 1536 * 3 * sizeof(float)));
readStream(g_image_src_float, reinterpret_cast<char *>(img));
std::vector<void *> v(3, nullptr);
paddle_mobile.FeedData(std::vector<void *>({img_info, img}));
paddle_mobile.Predict_To(-1);
for (int i = 65; i < 69; i++) {
auto tensor_ptr = paddle_mobile.FetchResult(i);
std::string saveName = "rfcn_" + std::to_string(i);
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).get_data(),
tensor_ptr->numel() * sizeof(float));
dump_stride(saveName, (*tensor_ptr), tensor_ptr->numel(), true);
for (int num = 0; num < post_nms; num++) {
for (int i = 0; i < 4; i++) {
auto p = reinterpret_cast<float *>(v[2].data.data());
std::cout << p[num * 4 + i] << std::endl;
}
// paddle_mobile.GetResults(&v);
DLOG << "Computation done";
fpga::fpga_free(img);
}
std::cout << "Finish getting vector values" << std::endl;
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include <fstream>
#include <iostream>
#include "../../src/io/paddle_inference_api.h"
using namespace paddle_mobile; // NOLINT
using namespace paddle_mobile::fpga; // NOLINT
static const char *g_image = "../images/yolo_test_txtimg/1.txt";
static const char *g_model = "../models/yolo_bn_l2_model/__model__";
static const char *g_param = "../models/yolo_bn_l2_model/model.params";
void readStream(std::string filename, float *buf) {
std::ifstream in;
in.open(filename, std::ios::in);
if (!in.is_open()) {
std::cout << "open File Failed." << std::endl;
return;
}
int i = 0;
while (!in.eof()) {
in >> buf[i];
i++;
}
in.close();
}
signed char float_to_int8(float fdata) {
if (fdata < 0.0) {
fdata -= 0.5;
} else {
fdata += 0.5;
}
return (signed char)fdata;
}
void quantize(float **data_in, int data_size) {
float *tmp = *data_in;
signed char *tmp_data = (signed char *)fpga_malloc(data_size * sizeof(char));
for (int i = 0; i < data_size; i++) {
tmp_data[i] = float_to_int8((*data_in)[i] + 128);
}
*data_in = (float *)tmp_data; // NOLINT
fpga_free(tmp);
}
void convert_to_chw(float **data_in, int channel, int height, int width,
float *data_tmp) {
int64_t amount_per_side = width * height;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + c * amount_per_side + width * h + w) = *((*data_in)++);
}
}
}
}
void dump_stride_float(std::string filename, PaddleTensor input_tensor) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.data.data());
int c = (input_tensor.shape)[1];
int h = (input_tensor.shape)[2];
int w = (input_tensor.shape)[3];
int n = (input_tensor.shape)[0];
float *data_tmp =
reinterpret_cast<float *>(malloc(c * h * w * sizeof(float)));
convert_to_chw(&data_ptr, c, h, w, data_tmp);
std::ofstream out(filename.c_str());
float result = 0;
int datasize = abs(c * h * w * n);
if (datasize == 0) {
std::cout << "wrong dump data size" << std::endl;
return;
}
for (int i = 0; i < datasize; i++) {
result = data_tmp[i];
out << result << std::endl;
}
out.close();
}
void dump_stride(std::string filename, PaddleTensor input_tensor) {
if (input_tensor.dtypeid == typeid(float)) {
dump_stride_float(filename, input_tensor);
} else {
std::cout << "only support dumping float data" << std::endl;
}
}
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model;
config.param_file = g_param;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
int main() {
open_device();
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
std::cout << "Finishing loading model" << std::endl;
int img_length = 256 * 416 * 3;
auto img = reinterpret_cast<float *>(fpga_malloc(img_length * sizeof(float)));
readStream(g_image, img);
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img;
// t_img.dtype = FLOAT32;
// t_img.dtypeid = typeid(float);
quantize(&img, img_length);
t_img.dtype = INT8;
t_img.dtypeid = typeid(int8_t);
t_img.layout = LAYOUT_HWC;
t_img.shape = std::vector<int>({1, 256, 416, 3});
t_img.name = "Image information";
// t_img.data.Reset(img, img_length * sizeof(float));
t_img.data.Reset(img, img_length * sizeof(int8_t));
predictor->FeedPaddleTensors({t_img});
std::cout << "Finishing feeding data " << std::endl;
predictor->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<PaddleTensor> v; // No need to initialize v
predictor->FetchPaddleTensors(&v); // Old data in v will be cleared
std::cout << "Output number is " << v.size() << std::endl;
for (int fetchNum = 0; fetchNum < v.size(); fetchNum++) {
std::string dumpName = "yolo_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v[fetchNum]);
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册