diff --git a/lite/backends/fpga/KD/debugger.hpp b/lite/backends/fpga/KD/debugger.hpp
index 454e5db8c6ef259f206f37d53111a25c6476178c..f23c242fee848ff205dadf7bccc2d1c0899cbeda 100755
--- a/lite/backends/fpga/KD/debugger.hpp
+++ b/lite/backends/fpga/KD/debugger.hpp
@@ -19,12 +19,13 @@
 #include
 #include
+#include "lite/core/program.h"
 #include "lite/core/tensor.h"
 
 namespace paddle {
 namespace lite {
 
-// #define FPGA_PRINT_TENSOR
+#define FPGA_PRINT_TENSOR
 
 class Debugger {
  public:
@@ -35,7 +36,7 @@ class Debugger {
 
   void registerOutput(std::string op_type, zynqmp::Tensor* tensor) {
     if (op_config[op_type]) {
-      tensor->saveToFile(op_type, true);
+      // tensor->saveToFile(op_type, true);
     }
   }
diff --git a/lite/backends/fpga/KD/pes/conv_pe.hpp b/lite/backends/fpga/KD/pes/conv_pe.hpp
index 210c02fac5ff3fa135cfc626dbd2d00cbeac12ba..7d5970560b39c91f6168fbf6a120f3b4e2f64993 100644
--- a/lite/backends/fpga/KD/pes/conv_pe.hpp
+++ b/lite/backends/fpga/KD/pes/conv_pe.hpp
@@ -72,18 +72,110 @@ class ConvPE : public PE {
     }
     if (param_.filter->shape().width() == 1 &&
         param_.filter->shape().num() % 16 != 0) {
-      use_cpu_ = true;
+      // use_cpu_ = true;
     }
     if (!use_cpu_) {
       // param_.filter->releaseData();
     }
+  }
+
+  void cpu_conv_half_hwc() {
+    Tensor* input = param_.input;
+    Tensor* output = param_.output;
+
+    Shape& input_shape = input->shape();
+    Shape& out_shape = output->shape();
+
+    int image_height = input_shape.height();
+    int image_width = input_shape.width();
+    int image_channels = input_shape.channel();
+    int image_pad_h = param_.paddings[0];
+    int image_pad_w = param_.paddings[0];
+    int kernel_height = param_.filter->shape().height();
+    int kernel_width = param_.filter->shape().width();
+    int kernel_step_h = param_.strides[0];
+    int kernel_step_w = param_.strides[1];
+    int dilation_rate = 1;
+    int out_channel = out_shape.channel();
+    int pooled_height_ = out_shape.height();
+    int pooled_width_ = out_shape.width();
+    int filter_chw = image_channels * kernel_height * kernel_width;
+
+    int kernel_rw = kernel_width + (dilation_rate - 1) * (kernel_width - 1);
+    int kernel_rh = kernel_height + (dilation_rate - 1) * (kernel_height - 1);
+
+    float* weight = param_.filter->data<float>();
+
+    Tensor float_input;
+    Tensor float_output;
+    float* image_addr = float_input.mutableData<float>(FP32, input->shape());
+    float_input.copyFrom(input);
 
-    // exit(-1);
+    float* out = float_output.mutableData<float>(FP32, output->shape());
+
+    for (int ph = 0; ph < pooled_height_; ph++) {
+      for (int pw = 0; pw < pooled_width_; pw++) {
+        int hstart = ph * kernel_step_h - image_pad_h;
+        int wstart = pw * kernel_step_w - image_pad_w;
+        int hend = std::min(hstart + kernel_rh, (int)image_height);
+        int wend = std::min(wstart + kernel_rw, (int)image_width);
+
+        int hstart_plus =
+            dilation_rate * ceil(float(image_pad_h - ph * kernel_step_h) /
+                                 float(dilation_rate)) -
+            image_pad_h + ph * kernel_step_h;
+        int wstart_plus =
+            dilation_rate * ceil(float(image_pad_w - pw * kernel_step_w) /
+                                 float(dilation_rate)) -
+            image_pad_w + pw * kernel_step_w;
+
+        int hstart_ = hstart < 0 ? hstart_plus : hstart;
+        int wstart_ = wstart < 0 ? wstart_plus : wstart;
+
+        for (int oc = 0; oc < out_channel; oc++) {
+          float sum = 0.0f;
+          const int pool_index = (ph * pooled_width_ + pw) * out_channel + oc;
+          for (int c = 0; c < image_channels; c++) {
+            for (int h = hstart_; h < hend; h += dilation_rate) {
+              int hi = 0;
+              if (hstart < 0) {
+                hi = (kernel_rh - (hend - h)) / dilation_rate;
+              } else {
+                hi = (h - hstart_) / dilation_rate;
+              }
+
+              for (int w = wstart_; w < wend; w += dilation_rate) {
+                int wi = 0;
+                if (wstart < 0) {
+                  wi = (kernel_rw - (wend - w)) / dilation_rate;
+                } else {
+                  wi = (w - wstart_) / dilation_rate;
+                }
+
+                const int index = (h * image_width + w) * image_channels + c;
+                int weight_index = oc * filter_chw +
+                                   kernel_width * kernel_height * c +
+                                   kernel_width * hi + wi;
+                float value = image_addr[index] * weight[weight_index];
+                sum += value;
+              }
+            }
+          }
+          float s = param_.scale()->data<float>()[oc];
+          float b = param_.bias()->data<float>()[oc];
+          out[pool_index] = sum * s + b;
+        }
+      }
+    }
+    float_output.saveToFile("fo", true);
+    exit(-1);
   }
+
   void cpu_compute() {
     Tensor* input = param_.input;
     Tensor* output = param_.output;
-    input->syncToCPU();
+    // input->saveToFile("input", true);
+    // input->syncToCPU();
 
     Tensor float_input;
     Tensor float_output;
@@ -117,24 +209,39 @@
         for (int j = 0; j < in_channel; j++) {
           sum += mi[j];
         }
-        sum *= param_.scale()->data<float>()[i];
-        sum += param_.bias()->data<float>()[i];
-        out[i * wh + k] = sum;
-        max = std::max(max, std::abs(sum));
+        float fv = sum;
+        float s = param_.scale()->data<float>()[i];
+        float b = param_.bias()->data<float>()[i];
+
+        fv *= s;
+        fv += b;
+
+        // std::cout << "\n" << fv << " = " << sum << " x " << s << " + " << b
+        //           << std::endl;
+
+        out[i * wh + k] = fv;
+        max = std::max(max, std::abs(fv));
       }
     }
     delete[] mi;
+    param_.bias()->saveToFile("bias", true);
+
+    exit(-1);
+
     float_output.flush();
+    float_output.saveToFile("float_output", true);
     output->copyFrom(&float_output);
+    output->invalidate();
     output->scale()[0] = max / 127.0;
     output->scale()[1] = 127.0 / max;
     // output->saveToFile("cpu", true);
   }
 
   bool dispatch() {
-    fpga_reset();
+    // fpga_reset();
     if (use_cpu_) {
-      cpu_compute();
+      // cpu_compute();
+      cpu_conv_half_hwc();
       return true;
     }
diff --git a/lite/backends/fpga/KD/pes/softmax_pe.cpp b/lite/backends/fpga/KD/pes/softmax_pe.cpp
index 7a834169fb5755b6769163021c03779cb4c374ec..fd6ba986507084f03a313fb432d8adc126cac4c6 100755
--- a/lite/backends/fpga/KD/pes/softmax_pe.cpp
+++ b/lite/backends/fpga/KD/pes/softmax_pe.cpp
@@ -59,6 +59,7 @@ static void softmax(Tensor *X, Tensor *Y) {
   int batch_size = X->shape().num();
   int num_classes = dims[X->shape().dimSize() - 1];
   int channels = X->shape().numel() / batch_size / num_classes;
+
   float *x = X->data<float>();
   float *y = Y->mutableData<float>();
 
@@ -140,12 +141,23 @@ bool SoftmaxPE::init() {
 bool SoftmaxPE::dispatch() {
   Tensor *input = param_.input;
   Tensor *output = param_.output;
-  input->syncToCPU();
 
   Tensor float_input;
   Tensor float_output;
   float_input.mutableData<float>(DataType::FP32, input->shape());
-  float_input.copyFrom(input);
+  // input->saveToFile("in", true);
+  // input->syncToDevice();
+  // float_input.copyFrom(input);
+
+  input->syncToCPU();
+  float16 *in_data = input->data<float16>();
+  float *f_data = float_input.data<float>();
+  for (int i = 0; i < input->shape().channel(); i++) {
+    f_data[i] = half_to_float(in_data[i]);
+  }
+
+  // float_input.invalidate();
+  // float_input.saveToFile("fin", true);
 
   float *out_data =
       float_output.mutableData<float>(DataType::FP32, input->shape());
diff --git a/lite/backends/fpga/KD/pes/yolobox_pe.hpp b/lite/backends/fpga/KD/pes/yolobox_pe.hpp
index d997fe07833f63bf6843c226477a402d6aeb0357..0322aec6d89acf3bd1585f85409225acdbde2f59 100644
--- a/lite/backends/fpga/KD/pes/yolobox_pe.hpp
+++ b/lite/backends/fpga/KD/pes/yolobox_pe.hpp
@@ -20,51 +20,61 @@ limitations under the License. */
 namespace paddle {
 namespace zynqmp {
-
-float sigmoid(float x) {
-  return 1.0 / (1.0 + std::exp(-x));
-}
-
-inline void GetYoloBox(float* box, const float* x, const int* anchors, int w,
-                       int h, int an_idx, int grid_size,
-                       int input_size, int index,
-                       int img_height, int img_width) {
-  box[0] = (w + sigmoid(x[index])) * img_width * 1.0f/ grid_size;
+float sigmoid(float x) { return 1.0 / (1.0 + std::exp(-x)); }
+
+inline void GetYoloBox(float* box,
+                       const float* x,
+                       const int* anchors,
+                       int w,
+                       int h,
+                       int an_idx,
+                       int grid_size,
+                       int input_size,
+                       int index,
+                       int img_height,
+                       int img_width) {
+  box[0] = (w + sigmoid(x[index])) * img_width * 1.0f / grid_size;
   box[1] = (h + sigmoid(x[index + 1])) * img_height * 1.0f / grid_size;
-  box[2] = std::exp(x[index + 2 ]) * anchors[2 * an_idx] * img_width * 1.0f/
+  box[2] = std::exp(x[index + 2]) * anchors[2 * an_idx] * img_width * 1.0f /
            input_size;
-  box[3] = std::exp(x[index + 3]) * anchors[2 * an_idx + 1] *
-           img_height * 1.0f / input_size;
+  box[3] = std::exp(x[index + 3]) * anchors[2 * an_idx + 1] * img_height *
+           1.0f / input_size;
 }
 
-inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
-                         int an_num, int an_stride, int stride,
-                         int entry) {
+inline int GetEntryIndex(int batch,
+                         int an_idx,
+                         int hw_idx,
+                         int an_num,
+                         int an_stride,
+                         int stride,
+                         int entry) {
   return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
 }
 
-inline void CalcDetectionBox(float* boxes, float* box, const int box_idx,
-                             const int img_height,
-                             const int img_width) {
+inline void CalcDetectionBox(float* boxes,
+                             float* box,
+                             const int box_idx,
+                             const int img_height,
+                             const int img_width) {
   boxes[box_idx] = box[0] - box[2] / 2;
   boxes[box_idx + 1] = box[1] - box[3] / 2;
   boxes[box_idx + 2] = box[0] + box[2] / 2;
   boxes[box_idx + 3] = box[1] + box[3] / 2;
 
   boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : 0;
-  boxes[box_idx + 1] =
-      boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : 0;
-  boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
-                           ? boxes[box_idx + 2]
-                           : (img_width - 1);
-  boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
-                           ? boxes[box_idx + 3]
-                           : (img_height - 1);
+  boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : 0;
+  boxes[box_idx + 2] =
+      boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : (img_width - 1);
+  boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3]
+                                                           : (img_height - 1);
 }
 
-inline void CalcLabelScore(float* scores, const float* input,
-                           const int label_idx, const int score_idx,
-                           const int class_num, const float conf) {
+inline void CalcLabelScore(float* scores,
+                           const float* input,
+                           const int label_idx,
+                           const int score_idx,
+                           const int class_num,
+                           const float conf) {
   for (int i = 0; i < class_num; i++) {
     scores[score_idx + i] = conf * sigmoid(input[label_idx + i]);
     // std::cout << scores[score_idx + i] << " ";
@@ -72,7 +82,6 @@ inline void CalcLabelScore(float* scores, const float* input,
   }
   // std::cout << std::endl;
 }
 
-
 class YoloBoxPE : public PE {
  public:
   bool init() {
@@ -93,7 +102,6 @@ class YoloBoxPE : public PE {
     float conf_thresh = param_.confThresh;
     int downsample_ratio = param_.downsampleRatio;
 
-
     const int num = input->shape().num();
     const int height = input->shape().height();
     const int width = input->shape().width();
@@ -134,39 +142,42 @@
     imgsize->saveToFile("img_size", true);
     const int32_t* imgsize_data = imgsize->data<int32_t>();
-
+
     Tensor boxes_float;
     Tensor scores_float;
 
     boxes_float.setDataLocation(CPU);
-    float* boxes_float_data = boxes_float.mutableData<float>(FP32, boxes->shape());
+    float* boxes_float_data =
+        boxes_float.mutableData<float>(FP32, boxes->shape());
     memset(boxes_float_data, 0, boxes->shape().numel() * sizeof(float));
 
     scores_float.setDataLocation(CPU);
-    float* scores_float_data = scores_float.mutableData<float>(FP32, scores->shape());
+    float* scores_float_data =
+        scores_float.mutableData<float>(FP32, scores->shape());
     memset(scores_float_data, 0, scores->shape().numel() * sizeof(float));
 
     // float* boxes_data = boxes->mutableData<float>();
     // memset(boxes_data, 0, boxes->shape().numel() * sizeof(float));
-
+
     // float* scores_data = scores->mutableData<float>();
     // memset(scores_data, 0, scores->shape().numel() * sizeof(float));
 
     float box[4];
     // for (int n = 0; n < num; n++) {
-    //  int img_height = imgsize_data[2 * i];
-    //  int img_width = imgsize_data[2 * i + 1];
+    // int img_height = imgsize_data[2 * i];
+    // int img_width = imgsize_data[2 * i + 1];
    int img_height = imgsize_data[0];
    int img_width = imgsize_data[1];
 
-    std::cout << "YoloBoxPE imgsize:" << img_height << "," << img_width << std::endl;
+    std::cout << "YoloBoxPE imgsize:" << img_height << "," << img_width
+              << std::endl;
 
    int channel = input_float.shape().channel();
    int count = 0;
    for (int h = 0; h < height; h++) {
-      for (int w = 0; w < width ; w++) {
+      for (int w = 0; w < width; w++) {
        for (int n = 0; n < an_num; n++) {
-
-          int obj_idx = channel * width * h + channel * w + n * (5 + class_num) + 4;
+          int obj_idx =
+              channel * width * h + channel * w + n * (5 + class_num) + 4;
          // std::cout << obj_idx << " ";
          float conf = sigmoid(input_data[obj_idx]);
          if (conf < conf_thresh) {
@@ -174,16 +185,34 @@
            continue;
          }
 
-          int box_idx = channel * width * h + channel * w + n * (5 + class_num) + 0;
-          GetYoloBox(box, input_data, anchors_data, w, h, n, height, input_size,
-                     box_idx, img_height, img_width);
-
-          box_idx = h * an_num * 4 * width + an_num * 4 * w + n * 4;
-          CalcDetectionBox(boxes_float_data, box, box_idx, img_height,img_width);
-
-          int label_idx = channel * width * h + channel * w + n * (5 + class_num) + 5;
-          int score_idx = h * an_num * class_num * width + an_num * class_num * w + n * class_num;
-          CalcLabelScore(scores_float_data, input_data, label_idx, score_idx, class_num, conf);
+          int box_idx =
+              channel * width * h + channel * w + n * (5 + class_num) + 0;
+          GetYoloBox(box,
+                     input_data,
+                     anchors_data,
+                     w,
+                     h,
+                     n,
+                     height,
+                     input_size,
+                     box_idx,
+                     img_height,
+                     img_width);
+
+          box_idx = h * an_num * 4 * width + an_num * 4 * w + n * 4;
+          CalcDetectionBox(
+              boxes_float_data, box, box_idx, img_height, img_width);
+
+          int label_idx =
+              channel * width * h + channel * w + n * (5 + class_num) + 5;
+          int score_idx = h * an_num * class_num * width +
+                          an_num * class_num * w + n * class_num;
+          CalcLabelScore(scores_float_data,
+                         input_data,
+                         label_idx,
+                         score_idx,
+                         class_num,
+                         conf);
        }
      }
    }
@@ -195,11 +224,10 @@
 
   void apply(){};
 
-  YoloBoxParam& param() { return param_; }
+  YoloBoxParam& param() { return param_; }
 
  private:
   YoloBoxParam param_;
-
 };
 }  // namespace zynqmp
 }  // namespace paddle
diff --git a/lite/backends/fpga/KD/tensor.hpp b/lite/backends/fpga/KD/tensor.hpp
index 3a2996fed4965b019aeae386103e595ba4c6d3d9..6f8f6f198f1474d3eab805322ca8131f184b14a6 100644
--- a/lite/backends/fpga/KD/tensor.hpp
+++ b/lite/backends/fpga/KD/tensor.hpp
@@ -70,6 +70,7 @@
   explicit PlaceHolder(size_t size) {
     size_ = size;
     data_ = fpga_malloc(size_);
+    // memset(data_, 0, size);
   }
 
   void* data() { return data_; }
@@ -80,7 +81,7 @@
 
   ~PlaceHolder() { fpga_free(data_); }
 
-  float scale_[2];
+  float scale_[2] = {0};
 
  private:
   void* data_ = nullptr;
@@ -409,12 +410,14 @@
       if (i < 10) {
         std::cout << value << ",";
       }
+
       // if (i > 1000) {
       //   break;
       // }
       ofs << value << std::endl;
     }
-    usleep(30000);
+    std::cout << std::endl;
+    // usleep(30000);
     ofs.close();
   }
diff --git a/lite/backends/fpga/monitor.hpp b/lite/backends/fpga/monitor.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..61f2a35aed3cbf51aa74bd349201ebe2504ca609
--- /dev/null
+++ b/lite/backends/fpga/monitor.hpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "lite/core/program.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+class Monitor {
+ public:
+  static Monitor& get_instance() {
+    static Monitor s_instance;
+    return s_instance;
+  }
+
+  void inferStart() {}
+
+  void preRun(Instruction& inst) {
+    VLOG(4) << "Running op:" << const_cast<OpLite*>(inst.op())->Type();
+  }
+
+  void postRun(Instruction& inst) {}
+
+  void inferEnd() {}
+
+ private:
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc
index 87ebaeeb4b8522f8654c0a6e587ee4982e03390d..186a621cf35673ef28510965ba5702663d06e9ab 100644
--- a/lite/core/mir/type_precision_cast_pass.cc
+++ b/lite/core/mir/type_precision_cast_pass.cc
@@ -134,7 +134,6 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Start from inputs of the graph, those should have place set.
   std::list<Node*> nodes;
   for (auto& node : graph->StmtTopologicalOrder()) {
-
     // if (node->IsStmt()) {
    //   auto& s = node->AsStmt();
    //   std::cout << "type_precision type:" << s.op_type() << std::endl;
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 9a6790f65430007b490030c338db1232403fda8e..81d50a28baf84d7848f0e06237caa7cce54c12db 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -25,6 +25,10 @@
 #include "lite/core/profile/precision_profiler.h"
 #endif
 
+#ifdef LITE_WITH_FPGA
+#include "lite/backends/fpga/monitor.hpp"
+#endif
+
 namespace paddle {
 namespace lite {
 
@@ -151,23 +155,41 @@ void RuntimeProgram::Run() {
       inst_precision_profiler.GetSummaryHeader();
 #endif
 
+#ifdef LITE_WITH_FPGA
+  Monitor& monitor = Monitor::get_instance();
+  monitor.inferStart();
+#endif
+
   for (auto& inst : instructions_) {
+#ifdef LITE_WITH_FPGA
+    monitor.preRun(inst);
+#endif
+
 #ifndef LITE_WITH_FPGA
     if (inst.is_feed_fetch_op()) continue;
 #endif
+
 #ifdef LITE_WITH_CUDA
     if (inst.need_sync()) {
       inst.Sync();
     }
 #endif
     inst.Run();
+
+#ifdef LITE_WITH_FPGA
+    monitor.postRun(inst);
+#endif
+
 #ifdef LITE_WITH_PRECISION_PROFILE
-#ifndef LITE_WITH_FPGA
     precision_profiler_summary +=
         inst_precision_profiler.GetInstPrecision(&inst);
-#endif
 #endif  // LITE_WITH_PRECISION_PROFILE
   }
+
+#ifdef LITE_WITH_FPGA
+  monitor.inferEnd();
+#endif
+
 #ifdef LITE_WITH_PROFILE
   LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 1);
 #endif
diff --git a/lite/kernels/fpga/conv_compute.cc b/lite/kernels/fpga/conv_compute.cc
index 7754d70d822ebc22e3c1bc4db50504ea2d71e14c..e1b5b691fdd6fccecd19ca9ca93cd7f38e634c9d 100644
--- a/lite/kernels/fpga/conv_compute.cc
+++ b/lite/kernels/fpga/conv_compute.cc
@@ -25,12 +25,46 @@
 namespace kernels {
 namespace fpga {
 
 using float16 = zynqmp::float16;
+using lite_api::ActivationType;
 
 void ConvCompute::PrepareForRun() {
   auto& param = this->Param<param_t>();
   param.output->mutable_data<float16>();
   int pad_h = (*param.paddings)[0];
   int pad_w = (*param.paddings)[2];
+
+  zynqmp::ActiveType active_type = zynqmp::TYPE_NONE;
+  float leaky_relu_factor = 0;
+
+  switch (param.activation_param.active_type) {
+    case ActivationType::kIndentity:
+      active_type = zynqmp::TYPE_NONE;
+      break;
+    case ActivationType::kRelu:
+      active_type = zynqmp::TYPE_RELU;
+      break;
+    case ActivationType::kRelu6:
+      active_type = zynqmp::TYPE_RELU6;
+      break;
+    case ActivationType::kPRelu:
+    case ActivationType::kLeakyRelu:
+      active_type = zynqmp::TYPE_LEAKY_RELU;
+      leaky_relu_factor = param.activation_param.Leaky_relu_alpha;
+      break;
+    case ActivationType::kSigmoid:
+      active_type = zynqmp::TYPE_SIGMOID;
+      break;
+    case ActivationType::kTanh:
+    case ActivationType::kSwish:
+    case ActivationType::kExp:
+    case ActivationType::kAbs:
+    case ActivationType::kHardSwish:
+    case ActivationType::kReciprocal:
+    default:
+      throw("not supported activation");
+      break;
+  }
+  // ====================================================
   if (param.x->ZynqTensor()->shape().channel() != 1 &&
       param.groups == param.x->ZynqTensor()->shape().channel()) {
@@ -45,17 +79,12 @@ void ConvCompute::PrepareForRun() {
     conv_param.paddings = std::vector<int>({pad_h, pad_w});
     conv_param.dilations = *param.dilations;
     fill_scale_bias_const(&conv_param);
-    conv_param.bias()->copyFrom(param.bias->ZynqTensor());
-
-    if (param.fuse_relu) {
-      conv_param.activeParam.type = zynqmp::TYPE_RELU;
+    if (param.bias != nullptr) {
+      conv_param.bias()->copyFrom(param.bias->ZynqTensor());
     }
 
-    if (param.activation_param.Leaky_relu_alpha > 0.001) {
-      conv_param.activeParam.type = zynqmp::TYPE_LEAKY_RELU;
-      conv_param.activeParam.leaky_relu_factor =
-          param.activation_param.Leaky_relu_alpha;
-    }
+    conv_param.activeParam.type = active_type;
+    conv_param.activeParam.leaky_relu_factor = leaky_relu_factor;
 
     dw_conv_pe_.init();
     dw_conv_pe_.apply();
@@ -74,21 +103,12 @@ void ConvCompute::PrepareForRun() {
       conv_param.bias()->copyFrom(param.bias->ZynqTensor());
     }
 
-    if (param.fuse_relu) {
-      conv_param.activeParam.type = zynqmp::TYPE_RELU;
-    }
-
-    if (param.activation_param.Leaky_relu_alpha > 0.001) {
-      conv_param.activeParam.type = zynqmp::TYPE_LEAKY_RELU;
-      conv_param.activeParam.leaky_relu_factor =
-          param.activation_param.Leaky_relu_alpha;
-    }
+    conv_param.activeParam.type = active_type;
+    conv_param.activeParam.leaky_relu_factor = leaky_relu_factor;
 
     conv_pe_.init();
     conv_pe_.apply();
   }
-  // std::cout << "Leaky_relu_alpha:" << param.activation_param.Leaky_relu_alpha
-  //           << std::endl;
 }
 
 void ConvCompute::Run() {
diff --git a/lite/kernels/fpga/multiclass_nms_compute.cc b/lite/kernels/fpga/multiclass_nms_compute.cc
index 9e1e106223bae49e3fd23cdc2d67cf099ced0278..0222484897a44d3b70cfb3cebfd5fc1e5360febe 100644
--- a/lite/kernels/fpga/multiclass_nms_compute.cc
+++ b/lite/kernels/fpga/multiclass_nms_compute.cc
@@ -227,7 +227,7 @@ void MultiClassNMS(const operators::MulticlassNmsParam& param,
       SliceOneClass<T>(scores, c, &score_slice);
       SliceOneClass<T>(bboxes, c, &bbox_slice);
     }
-    NMSFast(bboxes,// TODO
+    NMSFast(bboxes,  // TODO
             score_slice,
             score_threshold,
             nms_threshold,
diff --git a/lite/kernels/fpga/reshape_compute.cc b/lite/kernels/fpga/reshape_compute.cc
index 3f088700f1ee9787d0d19eff00a809cb3daf9ddd..97adbe5eaa66377c05f32ea05216623f1d6ff6eb 100644
--- a/lite/kernels/fpga/reshape_compute.cc
+++ b/lite/kernels/fpga/reshape_compute.cc
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/kernels/fpga/reshape_compute.h" #include + +#include "lite/backends/fpga/KD/debugger.hpp" +#include "lite/kernels/fpga/reshape_compute.h" #include "lite/operators/reshape_op.h" namespace paddle { @@ -48,21 +50,31 @@ void FlattenCompute::Run() { #endif } -void ReshapeCompute::Run() { +void ReshapeCompute::PrepareForRun() { auto& param = Param(); auto x = param.x; auto output = param.output; auto output_dims = output->dims(); - x->ZynqTensor()->unalignImage(); - - // x->ZynqTensor()->saveToFile("ri", true); - output->Resize(output_dims); output->mutable_data(); +} + +void ReshapeCompute::Run() { + auto& param = Param(); + auto x = param.x; + auto output = param.output; + // auto output_dims = output->dims(); + + // x->ZynqTensor()->invalidate();// TODO + x->ZynqTensor()->unalignImage(); + x->ZynqTensor()->flush(); + + // output->Resize(output_dims); + // output->mutable_data(); if (param.inplace) { - output->ShareDataWith(*x); + // output->ShareDataWith(*x); } else { // output->CopyDataFrom(*x); } @@ -70,7 +82,7 @@ void ReshapeCompute::Run() { output->ZynqTensor()->copyFrom(x->ZynqTensor()); // output->ZynqTensor()->saveToFile("ro", true); output->ZynqTensor()->flush(); - output->ZynqTensor()->setAligned(x->ZynqTensor()->aligned()); +// output->ZynqTensor()->setAligned(x->ZynqTensor()->aligned()); #ifdef FPGA_PRINT_TENSOR Debugger::get_instance().registerOutput("reshape", output->ZynqTensor()); diff --git a/lite/kernels/fpga/reshape_compute.h b/lite/kernels/fpga/reshape_compute.h index 8a3b3c266ed4b6fdf73c6304a166d8600bee4dc8..d157054bbca5c4ff3e97ec9511ecd1da5e8234e7 100755 --- a/lite/kernels/fpga/reshape_compute.h +++ b/lite/kernels/fpga/reshape_compute.h @@ -25,6 +25,7 @@ namespace fpga { class ReshapeCompute : public KernelLite { public: + void PrepareForRun() override; void Run() override; virtual ~ReshapeCompute() = default; @@ -41,6 +42,7 @@ class FlattenCompute class ReshapeComputeFpgaToHost : public KernelLite { public: + void PrepareForRun() override; void Run() override; virtual ~ReshapeComputeFpgaToHost() = default; diff --git a/lite/kernels/fpga/softmax_compute.cc b/lite/kernels/fpga/softmax_compute.cc index 8e51a716b092489920f0cfbc729f0cb22e5c71c8..415f500c6fbbe1037668f16bb1b08d4ed06690dd 100755 --- a/lite/kernels/fpga/softmax_compute.cc +++ b/lite/kernels/fpga/softmax_compute.cc @@ -14,6 +14,7 @@ #include "lite/kernels/fpga/softmax_compute.h" #include "lite/backends/arm/math/funcs.h" +#include "lite/backends/fpga/KD/debugger.hpp" namespace paddle { namespace lite { @@ -36,11 +37,10 @@ void SoftmaxCompute::PrepareForRun() { void SoftmaxCompute::Run() { zynqmp::SoftmaxParam& softmax_param = pe_.param(); - // softmax_param.input->saveToFile("softmax_in", true); pe_.dispatch(); - softmax_param.output->flush(); -// softmax_param.output->saveToFile("softmax", true); +// softmax_param.output->flush(); +// // softmax_param.output->saveToFile("softmax", true); #ifdef FPGA_PRINT_TENSOR Debugger::get_instance().registerOutput("softmax", softmax_param.output); #endif diff --git a/lite/kernels/fpga/yolo_box_compute.cc b/lite/kernels/fpga/yolo_box_compute.cc index 1e90cf30c4b0d6b5bf93d85811ef8fbe7324ba34..4821e740cc58b460653f78e4a87fc31aabe74450 100644 --- a/lite/kernels/fpga/yolo_box_compute.cc +++ b/lite/kernels/fpga/yolo_box_compute.cc @@ -28,7 +28,6 @@ void YoloBoxCompute::PrepareForRun() { lite::Tensor* ImgSize = param.ImgSize; lite::Tensor* Boxes = param.Boxes; lite::Tensor* Scores = param.Scores; - Boxes->mutable_data(); Scores->mutable_data(); @@ -45,16 +44,14 @@ void 
YoloBoxCompute::PrepareForRun() { pe_.init(); pe_.apply(); - } void YoloBoxCompute::Run() { - pe_.dispatch(); zynqmp::YoloBoxParam& yolobox_param = pe_.param(); yolobox_param.imgSize->saveToFile("img_size", true); -// exit(-1); + // exit(-1); yolobox_param.outputBoxes->saveToFile("yolo_boxes", true); yolobox_param.outputScores->saveToFile("yolo_scores", true); } diff --git a/lite/kernels/fpga/yolo_box_compute.h b/lite/kernels/fpga/yolo_box_compute.h index e4c573cf6719ea1b49fb83b431182ff57c8f4796..18e9446a4383b91d9302487eed707261d246942c 100644 --- a/lite/kernels/fpga/yolo_box_compute.h +++ b/lite/kernels/fpga/yolo_box_compute.h @@ -27,13 +27,13 @@ namespace fpga { using float16 = zynqmp::float16; -class YoloBoxCompute +class YoloBoxCompute : public KernelLite { public: void PrepareForRun() override; void Run() override; - virtual ~YoloBoxCompute() { + virtual ~YoloBoxCompute(){ };