提交 264f1162 编写于 作者: J jameswu2014

fix NMS no use bug

上级 e79949e9
......@@ -300,7 +300,7 @@ static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
template <class T>
static inline Tensor NMS(Tensor *bbox, Tensor *scores, T nms_threshold,
float eta) {
float eta, int post_nms_num = 100) {
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
......@@ -314,7 +314,7 @@ static inline Tensor NMS(Tensor *bbox, Tensor *scores, T nms_threshold,
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
while (sorted_indices.size() != 0) {
while ((sorted_indices.size() != 0) && (selected_num < post_nms_num)) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
......@@ -397,17 +397,19 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
return std::make_pair(bbox_sel, scores_filter);
}
Tensor keep_nms = NMS<T>(&bbox_sel, &scores_filter, nms_thresh, eta);
// Tensor keep_nms = NMS<T>(&bbox_sel, &scores_filter, nms_thresh, eta);
Tensor keep_nms =
NMS<T>(&bbox_sel, &scores_filter, nms_thresh, eta, post_nms_top_n);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
// proposals.mutable_data<T>({keep_nms.numel(), 4});//original
// scores_sel.mutable_data<T>({keep_nms.numel(), 1});//original
proposals.mutable_data<T>({keep_nms.numel(), 4}); // original
scores_sel.mutable_data<T>({keep_nms.numel(), 1}); // original
proposals.mutable_data<T>({post_nms_top_n, 4}); // wong
scores_sel.mutable_data<T>({post_nms_top_n, 1}); // wong
// proposals.mutable_data<T>({post_nms_top_n, 4}); // wong
// scores_sel.mutable_data<T>({post_nms_top_n, 1}); // wong
CPUGather<T>(bbox_sel, keep_nms, &proposals);
CPUGather<T>(scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel);
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#ifdef PSROI_POOL_OP
#include <cmath>
#include <memory>
#include <vector>
#include "operators/kernel/detection_kernel.h"
......@@ -72,16 +71,72 @@ bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
return true;
}
/*
template <typename Dtype>
void PSROIPoolingForward(
const Dtype* bottom_data,
const int height, const int width, const int input_channel,
Dtype* top_data,
const int pooled_height, const int pooled_width, const int output_channel,
const Dtype* bottom_rois,
const Dtype Bin_size_h, const Dtype Bin_size_w, const Dtype roi_start_h,
const Dtype roi_start_w, const int pw, const int ph, const int roi_batch_ind)
{
int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
int wstart = floor(static_cast<Dtype>(pw)* Bin_size_w + roi_start_w);
int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
int wend = ceil(static_cast<Dtype>(pw + 1) * Bin_size_w + roi_start_w);
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
float32x4_t sum_pixels_low_c= vdupq_n_f32(0);
float32x4_t sum_pixels_high_c= vdupq_n_f32(0);
if(!is_empty){
Dtype bin_area = (hend - hstart) * (wend - wstart);
float rev_bin_area = 1 / bin_area;
float32x4_t q_bin_area = vdupq_n_f32(rev_bin_area);
//static_cast<float>(bin_area) float pixels_c[output_channel];
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int pixel_offset = (h * width + w) * input_channel;
for(int output_c = 0; output_c < output_channel; output_c++){
int input_channel_offset = output_c * pooled_height *
pooled_width; int input_bias = pixel_offset + input_channel_offset + ph *
pooled_width + pw; pixels_c[output_c] = bottom_data[input_bias];
}
float32x4_t pixel_low_c = vld1q_f32(pixels_c);
float32x4_t pixel_high_c = vld1q_f32(pixels_c + 4);
sum_pixels_low_c = vaddq_f32(sum_pixels_low_c, pixel_low_c);
sum_pixels_high_c = vaddq_f32(sum_pixels_high_c, pixel_high_c);
}
}
sum_pixels_low_c = vmulq_f32(sum_pixels_low_c, q_bin_area);
sum_pixels_high_c = vmulq_f32(sum_pixels_high_c, q_bin_area);
}
int output_index_base = (ph * pooled_width + pw) * output_channel;
top_data += output_index_base;
vst1q_f32(top_data, sum_pixels_low_c);
top_data += 4;
vst1q_f32(top_data, sum_pixels_high_c);
}*/
template <typename Dtype>
void PSROIPooling(const Dtype* bottom_data, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const Dtype* bottom_rois,
const int output_dim, const int group_size, Dtype* top_data,
int index, int nid, const Dtype Bin_size_h,
const Dtype Bin_size_w, const Dtype roi_start_h,
const Dtype roi_start_w, const int ctop, const int ph,
const int roi_batch_ind) {
int pw = index;
void PSROIPoolingForward(const Dtype* bottom_data, const int height,
const int width, const int input_channel,
Dtype* top_data, const int pooled_height,
const int pooled_width, const int output_channel,
const Dtype* bottom_rois, const Dtype Bin_size_h,
const Dtype Bin_size_w, const Dtype roi_start_h,
const Dtype roi_start_w, const int pw, const int ph,
const int roi_batch_ind) {
int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
int wstart = floor(static_cast<Dtype>(pw) * Bin_size_w + roi_start_w);
int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
......@@ -94,60 +149,35 @@ void PSROIPooling(const Dtype* bottom_data, const int channels,
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c = (ctop * group_size + ph) * group_size + pw;
Dtype bin_area = (hend - hstart) * (wend - wstart);
bottom_data += (roi_batch_ind * channels + c) * height * width;
Dtype out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int bottom_index = h * width + w;
out_sum += bottom_data[bottom_index];
}
}
top_data[nid + index] = is_empty ? 0. : out_sum / bin_area;
}
void convert_to_chw(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(channel * height * width * sizeof(float))); // NOLINT
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + n * height * width * channel + c * amount_per_side +
width * h + w) = *((*data_in)++);
float sum_pixels_c[output_channel] = {0};
float pixels_c[output_channel] = {0};
if (!is_empty) {
Dtype bin_area = (hend - hstart) * (wend - wstart);
float rec_bin_area = 1 / bin_area;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int pixel_offset = (h * width + w) * input_channel;
for (int output_c = 0; output_c < output_channel; output_c++) {
int input_channel_offset = output_c * pooled_height * pooled_width;
int input_bias =
pixel_offset + input_channel_offset + ph * pooled_width + pw;
pixels_c[output_c] = bottom_data[input_bias];
}
}
}
}
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
}
void convert_to_hwc(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(num * channel * height * width * sizeof(float)));
int64_t amount_per_row = width * channel;
for (int n = 0; n < num; n++) {
for (int c = 0; c < channel; c++) {
for (int h = 0; h < height; h++) {
int64_t offset_height = h * amount_per_row;
for (int w = 0; w < width; w++) {
*(data_tmp + n * channel * height * width + offset_height +
w * channel + c) = *((*data_in)++);
for (int output_c = 0; output_c < output_channel; output_c++) {
sum_pixels_c[output_c] += pixels_c[output_c];
}
}
}
for (int output_c = 0; output_c < output_channel; output_c++) {
sum_pixels_c[output_c] *= rec_bin_area;
}
}
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
int output_index_base = (ph * pooled_width + pw) * output_channel;
top_data += output_index_base;
memcpy(top_data, sum_pixels_c, output_channel * 4);
}
template <>
......@@ -174,14 +204,15 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
int rois_num = rois->dims()[0];
auto data_nhwc = in->mutable_data<float>();
fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width, 1);
// fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width);
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
(param.output_)->dims()[3]});
(param.output_)->Resize(dims_out_new);
float* input_data = data_nhwc; // in->data<float>();
// shared_ptr<float> input_data(data_nhwc);
const float* input_data = data_nhwc; // in->data<float>();
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();
......@@ -203,18 +234,19 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
// for (int n = 0; n < rois_batch_size; ++n) {
// for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
// rois_batch_id_data[i] = n;
// }
//}
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
auto output_data = out->mutable_data<float>();
auto input_rois = rois->data<float>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) {
// [start, end) interval for spatial sampling
auto offset_input_rois = input_rois + n * 4;
auto offset_output_data =
output_data + pooled_height * pooled_width * output_channels * n;
auto roi_start_w =
static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
auto roi_start_h =
......@@ -232,27 +264,18 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
auto bin_size_h = roi_height / static_cast<float>(pooled_height);
auto bin_size_w = roi_width / static_cast<float>(pooled_width);
int roi_batch_ind = 0; // rois_batch_id_data[n];
// std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < pooled_height; ph++) {
int index = pooled_width;
int nid = n * output_channels * pooled_height * pooled_width +
c * pooled_width * pooled_height + ph * pooled_width;
for (int idx = 0; idx < index; idx++) {
PSROIPooling<float>(input_data, input_channels, height, width,
pooled_height, pooled_width, input_rois,
output_channels, pooled_height, output_data, idx,
nid, bin_size_h, bin_size_w, roi_start_h,
roi_start_w, c, ph, roi_batch_ind);
}
int roi_batch_ind = rois_batch_id_data[n];
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
PSROIPoolingForward<float>(input_data, height, width, input_channels,
offset_output_data, pooled_height,
pooled_width, output_channels, input_rois,
bin_size_h, bin_size_w, roi_start_h,
roi_start_w, pw, ph, roi_batch_ind);
}
}
}
fpga::fpga_free(input_data);
fpga::image::convert_to_hwc(&output_data, output_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(output_data);
}
} // namespace operators
......
......@@ -12,17 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include "../test_helper.h"
#include "../test_include.h"
#ifdef PADDLE_MOBILE_FPGA_V1
#include "fpga/V1/api.h"
#endif
#ifdef PADDLE_MOBILE_FPGA_V2
#include "fpga/V2/api.h"
#endif
#include <string>
#include <fstream>
#include <iostream>
#include "../../src/io/paddle_inference_api.h"
using namespace paddle_mobile; // NOLINT
using namespace paddle_mobile::fpga; // NOLINT
static const char *g_image = "../models/marker/marker1/image.bin";
static const char *g_model = "../models/marker/marker1/model";
static const char *g_param = "../models/marker/marker1/params";
void readStream(std::string filename, char *buf) {
std::ifstream in;
......@@ -36,132 +48,78 @@ void readStream(std::string filename, char *buf) {
auto length = in.tellg(); // report location (this is the length)
in.seekg(0, std::ios::beg); // go back to the beginning
in.read(buf, length);
DLOG << length;
in.close();
}
void convert_to_chw(int16_t **data_in, int channel, int height, int width,
int num, int16_t *data_tmp) {
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + n * amount_per_side * channel + c * amount_per_side +
width * h + w) = *((*data_in)++);
}
}
}
}
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kFPGA;
config.prog_file = g_model;
config.param_file = g_param;
config.thread_num = 1;
config.batch_size = 1;
config.optimize = true;
config.lod_mode = true;
config.quantification = false;
return config;
}
void dump_stride_half(std::string filename, Tensor input_tensor,
const int dumpnum, bool use_chw) {
// bool use_chw = true;
if (input_tensor.dims().size() != 4) return;
int c = (input_tensor.dims())[1];
int h = (input_tensor.dims())[2];
int w = (input_tensor.dims())[3];
int n = (input_tensor.dims())[0];
auto data_ptr = input_tensor.get_data();
auto *data_ptr_16 = reinterpret_cast<half *>(data_ptr);
auto data_tmp = data_ptr_16;
if (use_chw) {
data_tmp =
reinterpret_cast<half *>(malloc(n * c * h * w * sizeof(int16_t)));
convert_to_chw(&data_ptr_16, c, h, w, n, data_tmp);
}
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
out << result << std::endl;
}
out.close();
if (data_tmp != data_ptr_16) {
free(data_tmp);
int main() {
open_device();
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
std::cout << "Finishing loading model" << std::endl;
float img_info[3] = {432, 1280, 1.0f};
int img_length = 432 * 1280 * 3;
auto img = reinterpret_cast<float *>(fpga_malloc(img_length * sizeof(float)));
readStream(g_image, reinterpret_cast<char *>(img));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img_info, t_img;
t_img.dtypeid = typeid(float);
t_img_info.layout = LAYOUT_HWC;
t_img_info.shape = std::vector<int>({1, 3});
t_img_info.name = "Image information";
t_img_info.data.Reset(img_info, 3 * sizeof(float));
t_img.dtypeid = typeid(float);
t_img.layout = LAYOUT_HWC;
t_img.shape = std::vector<int>({1, 432, 1280, 3});
t_img.name = "Image information";
t_img.data.Reset(img, img_length * sizeof(float));
predictor->FeedPaddleTensors({t_img_info, t_img});
std::cout << "Finishing feeding data " << std::endl;
predictor->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<PaddleTensor> v; // No need to initialize v
predictor->FetchPaddleTensors(&v); // Old data in v will be cleared
for (int i = 0; i < v.size(); ++i) {
auto p = reinterpret_cast<float *>(v[i].data.data());
int len = v[i].data.length();
float result = 0.0f;
std::string str = "fetch" + std::to_string(i);
fpga::savefile<float>(str, p, len, result);
}
}
void dump_stride_float(std::string filename, Tensor input_tensor,
const int dumpnum) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.get_data());
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = data_ptr[i];
out << result << std::endl;
}
out.close();
}
std::cout << "Finish getting vector values" << std::endl;
void dump_stride(std::string filename, Tensor input_tensor, const int dumpnum,
bool use_chw) {
static int i = 0;
if (input_tensor.numel() == 0) {
return;
}
if (input_tensor.type() == typeid(float)) {
DLOG << "op: " << i++ << ", float data " << input_tensor.numel();
dump_stride_float(filename, input_tensor, dumpnum);
} else {
DLOG << "op: " << i++ << ", half data " << input_tensor.numel();
dump_stride_half(filename, input_tensor, dumpnum, use_chw);
}
DLOG << "dump input address: " << input_tensor.get_data();
}
////////////////////////////////////////////////////
static const char *g_marker_combine = "../models/marker/model";
static const char *g_image_src_float = "../models/marker/model/input_0.bin";
int main() {
paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
// if (paddle_mobile.Load(std::string(g_rfcn_combine) + "/model",
// std::string(g_rfcn_combine) + "/params", true, false,
// 1, true)) {
if (paddle_mobile.Load(std::string(g_marker_combine), true)) {
float img_info[3] = {720, 1280, 800.0f / 960.0f};
auto img = reinterpret_cast<float *>(
fpga::fpga_malloc(720 * 1280 * 3 * sizeof(float)));
readStream(g_image_src_float, reinterpret_cast<char *>(img));
std::vector<void *> v(3, nullptr);
paddle_mobile.FeedData({img});
paddle_mobile.Predict_To(-1);
for (int i = 47; i < 52; i++) {
auto tensor_ptr = paddle_mobile.FetchResult(i);
std::string saveName = "marker_" + std::to_string(i);
// if(i != 58)
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).get_data(),
tensor_ptr->numel() * sizeof(float));
// tensor_ptr->numel() * sizeof(float));
dump_stride(saveName, (*tensor_ptr), tensor_ptr->numel(),
true); // 20);//tensor_ptr->numel());
/* float result = 0;
std::string str = "softmax_input_data";
float* data =
static_cast<float*>(fpga::fpga_malloc(tensor_ptr->numel() *
sizeof(float))); str = "softmax_output_data"; auto output_ptr =
static_cast<half*>((*tensor_ptr).get_data()); for (int idx = 0; idx <
tensor_ptr->numel(); ++idx)
{
data[idx] = fpga::fp16_2_fp32(output_ptr[idx]);
}
fpga::savefile<float>(str,data, tensor_ptr->numel(), result ); */
}
// paddle_mobile.GetResults(&v);
DLOG << "Computation done";
fpga::fpga_free(img);
}
// PaddleTensor tensor;
// predictor->GetPaddleTensor("fetch2", &tensor);
// for (int i = 0; i < post_nms; i++) {
// auto p = reinterpret_cast<float *>(tensor.data.data());
// std::cout << p[+i] << std::endl;
// }
return 0;
}
......@@ -15,12 +15,15 @@ limitations under the License. */
#ifndef PADDLE_MOBILE_FPGA
#define PADDLE_MOBILE_FPGA
#endif
#include <sys/time.h>
#include <time.h>
#include <fstream>
#include <iomanip>
#include <iostream>
#include "../../src/io/paddle_inference_api.h"
using namespace paddle_mobile;
using namespace paddle_mobile::fpga;
using namespace paddle_mobile; // NOLINT
using namespace paddle_mobile::fpga; // NOLINT
static const char *g_image = "../models/marker/model/image.bin";
static const char *g_model = "../models/marker/model/model";
......@@ -136,44 +139,6 @@ PaddleMobileConfig GetConfig1() {
int main() {
open_device();
PaddleMobileConfig config1 = GetConfig1();
auto predictor1 =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config1);
std::cout << "Finishing loading model" << std::endl;
for (int i = 0; i < 1; ++i) {
int img_length1 = 144 * 14 * 14;
auto img1 =
reinterpret_cast<float *>(fpga_malloc(img_length1 * sizeof(float)));
readStream(g_image1, reinterpret_cast<char *>(img1));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img1;
t_img1.dtypeid = typeid(float);
t_img1.layout = LAYOUT_HWC;
t_img1.shape = std::vector<int>({1, 14, 14, 144});
t_img1.name = "Image information";
t_img1.data.Reset(img1, img_length1 * sizeof(float));
predictor1->FeedPaddleTensors({t_img1});
std::cout << "Finishing feeding data " << std::endl;
predictor1->Predict_From_To(0, -1);
std::cout << "Finishing predicting " << std::endl;
std::vector<paddle_mobile::PaddleTensor> v1; // No need to initialize v
predictor1->FetchPaddleTensors(&v1); // Old data in v will be cleared
std::cout << "Output number is " << v1.size() << std::endl;
for (int fetchNum = 0; fetchNum < v1.size(); fetchNum++) {
std::string dumpName = "marker2_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v1[fetchNum]);
}
}
/////////////////////////////////////
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
......@@ -207,7 +172,16 @@ int main() {
std::cout << "Finishing feeding data " << std::endl;
timeval start11, end11;
long dif_sec, dif_usec; // NOLINT
gettimeofday(&start11, NULL);
predictor->Predict_From_To(0, -1);
gettimeofday(&end11, NULL);
dif_sec = end11.tv_sec - start11.tv_sec;
dif_usec = end11.tv_usec - start11.tv_usec;
std::cout << "marker1 total"
<< " cost time: " << (dif_sec * 1000000 + dif_usec) << " us"
<< std::endl;
std::cout << "Finishing predicting " << std::endl;
std::vector<paddle_mobile::PaddleTensor> v; // No need to initialize v
......@@ -217,5 +191,48 @@ int main() {
std::string dumpName = "marker_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v[fetchNum]);
}
PaddleMobileConfig config1 = GetConfig1();
auto predictor1 =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config1);
std::cout << "Finishing loading model" << std::endl;
for (int i = 0; i < 1; ++i) {
int img_length1 = 144 * 14 * 14;
auto img1 =
reinterpret_cast<float *>(fpga_malloc(img_length1 * sizeof(float)));
readStream(g_image1, reinterpret_cast<char *>(img1));
std::cout << "Finishing initializing data" << std::endl;
struct PaddleTensor t_img1;
t_img1.dtypeid = typeid(float);
t_img1.layout = LAYOUT_HWC;
t_img1.shape = std::vector<int>({1, 14, 14, 144});
t_img1.name = "Image information";
t_img1.data.Reset(img1, img_length1 * sizeof(float));
predictor1->FeedPaddleTensors({t_img1});
std::cout << "Finishing feeding data " << std::endl;
gettimeofday(&start11, NULL);
predictor1->Predict_From_To(0, -1);
gettimeofday(&end11, NULL);
dif_sec = end11.tv_sec - start11.tv_sec;
dif_usec = end11.tv_usec - start11.tv_usec;
std::cout << "marker2 total"
<< " cost time: " << (dif_sec * 1000000 + dif_usec) << " us"
<< std::endl;
std::cout << "Finishing predicting " << std::endl;
std::vector<paddle_mobile::PaddleTensor> v1; // No need to initialize v
predictor1->FetchPaddleTensors(&v1); // Old data in v will be cleared
std::cout << "Output number is " << v1.size() << std::endl;
for (int fetchNum = 0; fetchNum < v1.size(); fetchNum++) {
std::string dumpName = "marker2_api_fetch_" + std::to_string(fetchNum);
dump_stride(dumpName, v1[fetchNum]);
}
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册