提交 00aa83e5 编写于 作者: C chonwhite

modified nms

上级 f28625a4
......@@ -56,24 +56,25 @@ class Debugger {
std::unordered_map<std::string, bool> op_config;
std::unordered_map<std::string, float> tick_tock_map;
Debugger() {
// op_config["concat"] = true;
// op_config["pooling"] = true;
// op_config["conv"] = true;
// op_config["dropout"] = true;
// op_config["dwconv"] = true;
// op_config["ew_add"] = true;
// op_config["ew_mul"] = true;
// op_config["crop"] = true;
// op_config["feed"] = true;
// op_config["fc"] = true;
// op_config["mul"] = true;
// op_config["fetch"] = true;
// op_config["boxes"] = true;
// op_config["scores"] = true;
// op_config["nms"] = true;
// op_config["pb_boxes"] = true;
// op_config["pb_variances"] = true;
// op_config["softmax"] = true;
op_config["concat"] = true;
op_config["pooling"] = true;
op_config["conv"] = true;
op_config["dropout"] = true;
op_config["dwconv"] = true;
op_config["ew_add"] = true;
op_config["ew_mul"] = true;
op_config["crop"] = true;
op_config["feed"] = true;
op_config["fetch"] = true;
op_config["fc"] = true;
op_config["mul"] = true;
op_config["boxes"] = true;
op_config["scores"] = true;
op_config["nms"] = true;
op_config["pb_boxes"] = true;
op_config["pb_variances"] = true;
op_config["softmax"] = true;
op_config["split"] = true;
}
};
......
......@@ -61,7 +61,6 @@ void reset_device() {
// memory management;
void *fpga_malloc(size_t size) {
#ifdef PADDLE_MOBILE_OS_LINUX
void *ptr = reinterpret_cast<void *>(
......@@ -205,7 +204,7 @@ int get_device_info(const struct DeviceInfo &args) {
int perform_bypass(const struct BypassArgs &args) {
int ret = -1;
int size = args.image.channels * args.image.width * args.image.height;
int max_size = 1 << 22;
int max_size = 1 << 20;
float times = 1.0 * size / max_size;
int count = static_cast<int>(times);
......
......@@ -241,10 +241,13 @@ void PriorBoxPE::compute_prior_box() {
}
boxes.flush();
boxes.syncToCPU();
// boxes.syncToCPU();
variances.flush();
output_boxes->copyFrom(&boxes);
output_variances->copyFrom(&variances);
output_boxes->invalidate();
output_variances->invalidate();
}
void PriorBoxPE::apply() {}
......@@ -253,8 +256,9 @@ bool PriorBoxPE::dispatch() {
if (cachedBoxes_ == nullptr) {
cachedBoxes_ = new Tensor();
cachedVariances_ = new Tensor();
cachedBoxes_->mutableData<float>(FP32, param_.outputBoxes->shape());
cachedVariances_->mutableData<float>(FP32, param_.outputVariances->shape());
cachedBoxes_->mutableData<float16>(FP16, param_.outputBoxes->shape());
cachedVariances_->mutableData<float16>(FP16,
param_.outputVariances->shape());
cachedBoxes_->setDataLocation(CPU);
cachedVariances_->setDataLocation(CPU);
compute_prior_box();
......
......@@ -389,11 +389,17 @@ class Tensor {
float value = 0;
if (dataType_ == FP32) {
value = data<float>()[i];
} else if (dataType_ == FP16) {
}
if (dataType_ == FP16) {
value = half_to_float(data<float16>()[i]);
} else {
}
if (dataType_ == INT8) {
value = data<int8_t>()[i];
}
if (dataType_ == INT32) {
value = data<int32_t>()[i];
}
ofs << value << std::endl;
}
ofs.close();
......
......@@ -81,8 +81,7 @@ class DDimLite {
return !(a == b);
}
~DDimLite() {
}
~DDimLite() {}
private:
std::vector<value_type> data_;
......@@ -112,9 +111,7 @@ class TensorLite {
return zynq_tensor_->data<R>() + offset_;
}
void Resize(const DDimLite &ddim) {
dims_ = ddim;
}
void Resize(const DDimLite &ddim) { dims_ = ddim; }
void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
const DDimLite &dims() const { return dims_; }
......@@ -212,6 +209,28 @@ class TensorLite {
void mutable_data_internal();
};
template <typename T>
zynqmp::DataType get_date_type() {
zynqmp::DataType data_type = zynqmp::FP32;
if (typeid(T) == typeid(float)) {
data_type = zynqmp::FP32;
}
if (typeid(T) == typeid(zynqmp::float16)) {
data_type = zynqmp::FP16;
}
if (typeid(T) == typeid(int)) {
data_type = zynqmp::INT32;
}
if (typeid(T) == typeid(int32_t)) {
data_type = zynqmp::INT32;
}
if (typeid(T) == typeid(int8_t)) {
data_type = zynqmp::INT8;
}
return data_type;
}
template <typename T, typename R>
R *TensorLite::mutable_data() {
std::vector<int> v;
......@@ -237,13 +256,8 @@ R *TensorLite::mutable_data() {
break;
}
zynqmp::Shape input_shape(layout_type, v);
zynqmp::DataType data_type = zynqmp::FP32;
if (typeid(T) == typeid(float)) {
data_type = zynqmp::FP32;
}
if (typeid(T) == typeid(zynqmp::float16)) {
data_type = zynqmp::FP16;
}
zynqmp::DataType data_type = get_date_type<T>();
if (zynq_tensor_.get() == nullptr) {
zynq_tensor_.reset(new zynqmp::Tensor());
}
......
......@@ -25,7 +25,7 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"};
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kCUDA)) {
if (place.target == TARGET(kCUDA) || place.target == TARGET(kFPGA)) {
act_types.push_back("leaky_relu");
break;
}
......
......@@ -8,7 +8,7 @@ set(fpga_deps fpga_target_wrapper kernel_fpga)
add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
# add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps})
# add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps})
# add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps})
......@@ -28,8 +28,9 @@ add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fp
add_kernel(reshape_compute_fpga FPGA basic SRCS reshape_compute.cc DEPS ${fpga_deps} reshape_op)
# add_kernel(sequence_pool_compute_fpga FPGA basic SRCS sequence_pool_compute.cc DEPS ${fpga_deps})
add_kernel(scale_compute_fpga FPGA basic SRCS scale_compute.cc DEPS ${fpga_deps})
# add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps})
# add_kernel(transpose_compute_fpga FPGA basic SRCS transpose_compute.cc DEPS ${fpga_deps})
add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps})
add_kernel(split_compute_fpga FPGA basic SRCS split_compute.cc DEPS ${fpga_deps})
add_kernel(transpose_compute_fpga FPGA basic SRCS transpose_compute.cc DEPS ${fpga_deps})
add_kernel(io_copy_compute_fpga FPGA basic SRCS io_copy_compute.cc DEPS ${fpga_deps})
add_kernel(calib_compute_fpga FPGA basic SRCS calib_compute.cc DEPS ${fpga_deps})
......
......@@ -45,21 +45,32 @@ class IoCopyHostToFpgaCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kFPGA));
param.y->mutable_data<float16>();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) {
zynqmp::Tensor tempTensor;
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.x->ZynqTensor()->flush();
if (param.x->ZynqTensor()->dataType() == zynqmp::INT32) {
param.y->mutable_data<int>();
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
return;
}
param.y->ZynqTensor()->invalidate();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
if (param.x->ZynqTensor()->dataType() == zynqmp::FP32) {
param.y->mutable_data<float16>();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) {
zynqmp::Tensor tempTensor;
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
param.y->ZynqTensor()->invalidate();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
}
auto out_lod = param.y->mutable_lod();
*out_lod = param.x->lod();
}
......
......@@ -318,14 +318,29 @@ void MultiClassOutput(const Tensor& scores,
void MulticlassNmsCompute::Run() {
auto& param = Param<operators::MulticlassNmsParam>();
auto* boxes = param.bboxes;
auto* scores = param.scores;
auto* boxes_in = param.bboxes;
auto* scores_in = param.scores;
auto* outs = param.out;
outs->mutable_data<float>();
auto score_dims = scores->dims();
auto score_dims = boxes_in->dims();
auto score_size = score_dims.size();
Tensor boxes_float;
Tensor scores_float;
boxes_float.Resize(boxes_in->dims());
scores_float.Resize(scores_in->dims());
boxes_float.mutable_data<float>();
scores_float.mutable_data<float>();
boxes_float.ZynqTensor()->copyFrom(boxes_in->ZynqTensor());
scores_float.ZynqTensor()->copyFrom(scores_in->ZynqTensor());
Tensor* boxes = &boxes_float;
Tensor* scores = &scores_float;
auto box_dims = boxes->dims();
int64_t box_dim = boxes->dims()[2];
......@@ -383,6 +398,7 @@ void MulticlassNmsCompute::Run() {
MultiClassOutput<float>(
scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
outs->ZynqTensor()->copyFrom(out.ZynqTensor());
out.ZynqTensor()->saveToFile("nms_oo", true);
}
outs->Resize({static_cast<int64_t>(e - s), out_dim});
}
......@@ -402,16 +418,16 @@ void MulticlassNmsCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::MulticlassNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// REGISTER_LITE_KERNEL(multiclass_nms,
// kFPGA,
// kFP16,
// kNHWC,
// paddle::lite::kernels::fpga::MulticlassNmsCompute,
// def)
// .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
// .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
// .Finalize();
REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA,
......@@ -427,5 +443,8 @@ REGISTER_LITE_KERNEL(multiclass_nms,
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
......@@ -131,3 +131,27 @@ REGISTER_LITE_KERNEL(prior_box,
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// REGISTER_LITE_KERNEL(prior_box,
// kFPGA,
// kFP16,
// kNHWC,
// paddle::lite::kernels::fpga::PriorBoxCompute,
// def)
// .BindInput("Input",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindInput("Image",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindOutput("Boxes",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindOutput("Variances",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/split_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
void SplitCompute::PrepareForRun() {
auto& param = Param<operators::SplitParam>();
zynqmp::SplitParam& split_param = pe_.param();
split_param.input = param.x->ZynqTensor();
auto& dout = param.output;
for (int i = 0; i < dout.size(); i++) {
dout[i]->mutable_data<zynqmp::float16>();
split_param.outputs.push_back(dout[i]->ZynqTensor());
}
pe_.init();
pe_.apply();
}
void SplitCompute::Run() {
zynqmp::SplitParam& split_param = pe_.param();
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
auto& dout = param.output;
for (int i = 0; i < dout.size(); i++) {
Debugger::get_instance().registerOutput("split", split_param.outputs[0]);
}
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
split, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::SplitCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/split_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class SplitCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~SplitCompute() = default;
private:
zynqmp::SplitPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -81,7 +81,17 @@ void transposeCompute(operators::TransposeParam param) {
}
// Transpose
void TransposeCompute::Run() { auto& param = this->Param<param_t>(); }
void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
param.output->mutable_data<zynqmp::float16>();
param.x->ZynqTensor()->invalidate();
param.x->ZynqTensor()->unalignImage();
if (param.x->dims().size() != 4) {
transposeCompute(param);
} else {
param.output->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
}
// Transpose2
void Transpose2Compute::Run() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册