提交 1893489c 编写于 作者: M MyPandaShaoxiang

feature: update fpga kernel patch

test=develop
上级 a39f92ea
......@@ -4,29 +4,42 @@ endif()
set(fpga_deps fpga_target_wrapper kernel_fpga)
add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_acivation_fpga SRCS activation_compute_test.cc DEPS ${lite_kernel_deps} activation_compute_fpga ${fpga_deps})
# add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
# add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps})
# add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_conv_fpga SRCS conv_compute_test.cc DEPS ${lite_kernel_deps} conv_compute_fpga ${fpga_deps})
# add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps})
add_kernel(dropout_compute_fpga FPGA basic SRCS dropout_compute.cc DEPS ${fpga_deps})
add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_elementwise_fpga SRCS elementwise_compute_test.cc DEPS ${lite_kernel_deps} elementwise_compute_fpga ${fpga_deps})
# add_kernel(feed_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
add_kernel(gru_compute_fpga FPGA extra SRCS gru_compute.cc DEPS ${fpga_deps})
# add_kernel(mul_compute_fpga FPGA basic SRCS mul_compute.cc DEPS ${fpga_deps})
add_kernel(multiclass_nms_compute_fpga FPGA basic SRCS multiclass_nms_compute.cc DEPS ${fpga_deps})
add_kernel(norm_compute_fpga FPGA basic SRCS norm_compute.cc DEPS ${fpga_deps})
# add_kernel(im2sequence_compute_fpga FPGA basic SRCS im2sequence_compute.cc DEPS ${fpga_deps})
add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_pooling_compute_fpga SRCS pooling_compute_test.cc DEPS ${lite_kernel_deps} pooling_compute_fpga ${fpga_deps})
add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fpga_deps})
# add_kernel(reshape_compute_fpga FPGA basic SRCS reshape_compute.cc DEPS ${fpga_deps} reshape_op)
# add_kernel(sequence_pool_compute_fpga FPGA basic SRCS sequence_pool_compute.cc DEPS ${fpga_deps})
add_kernel(scale_compute_fpga FPGA basic SRCS scale_compute.cc DEPS ${fpga_deps})
add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_softmax_compute_fpga SRCS softmax_compute_test.cc DEPS ${lite_kernel_deps} softmax_compute_fpga ${fpga_deps})
add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
lite_cc_test(test_fc_compute_fpga SRCS fc_compute_test.cc DEPS ${lite_kernel_deps} fc_compute_fpga ${fpga_deps})
# add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps})
# add_kernel(transpose_compute_fpga FPGA basic SRCS transpose_compute.cc DEPS ${fpga_deps})
add_kernel(io_copy_compute_fpga FPGA basic SRCS io_copy_compute.cc DEPS ${fpga_deps})
add_kernel(calib_compute_fpga FPGA basic SRCS calib_compute.cc DEPS ${fpga_deps})
add_kernel(layout_compute_fpga FPGA basic SRCS layout_compute.cc DEPS ${fpga_deps})
add_kernel(feed_compute_fpga FPGA basic SRCS feed_compute.cc DEPS ${fpga_deps})
add_kernel(fetch_compute_fpga FPGA basic SRCS fetch_compute.cc DEPS ${fpga_deps})
# add_kernel(while_compute_fpga FPGA extra SRCS while_compute.cc DEPS ${fpga_deps})
# add_kernel(write_to_array_compute_fpga FPGA extra SRCS write_to_array_compute.cc DEPS ${fpga_deps})
# lite_cc_test(test_acivation_fpga SRCS activation_compute_test.cc DEPS ${lite_kernel_deps} activation_compute_fpga ${fpga_deps})
lite_cc_test(test_conv_fpga SRCS conv_compute_test.cc DEPS ${lite_kernel_deps} conv_compute_fpga ${fpga_deps})
lite_cc_test(test_elementwise_fpga SRCS elementwise_compute_test.cc DEPS ${lite_kernel_deps} elementwise_compute_fpga ${fpga_deps})
lite_cc_test(test_fc_compute_fpga SRCS fc_compute_test.cc DEPS ${lite_kernel_deps} fc_compute_fpga ${fpga_deps})
lite_cc_test(test_pooling_compute_fpga SRCS pooling_compute_test.cc DEPS ${lite_kernel_deps} pooling_compute_fpga ${fpga_deps})
# lite_cc_test(test_softmax_compute_fpga SRCS softmax_compute_test.cc DEPS ${lite_kernel_deps} softmax_compute_fpga ${fpga_deps})
......@@ -23,24 +23,24 @@ namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void CalibComputeFp32ToFP16::Run() {
auto& param = this->Param<operators::CalibParam>();
const auto* din = param.input->data<float>();
auto* dout = param.output->mutable_data<float16>(TARGET(kFPGA));
for (int i = 0; i < param.input->numel(); ++i) {
dout[i] = zynqmp::float_to_half(din[i]);
}
param.output->mutable_data<float16>();
param.output->ZynqTensor()->copyFrom(param.input->ZynqTensor());
auto out_lod = param.output->mutable_lod();
*out_lod = param.input->lod();
return;
}
void CalibComputeFP16ToFp32::Run() {
auto& param = this->Param<operators::CalibParam>();
const auto* din = param.input->data<float16>();
auto* dout = param.output->mutable_data<float>(TARGET(kFPGA));
for (int i = 0; i < param.input->numel(); ++i) {
dout[i] = zynqmp::half_to_float(din[i]);
}
auto* dout = param.output->mutable_data<float>();
param.output->ZynqTensor()->copyFrom(param.input->ZynqTensor());
auto out_lod = param.output->mutable_lod();
*out_lod = param.input->lod();
return;
}
......
文件模式从 100644 更改为 100755
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/concat_compute.h"
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void ConcatCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
param.output->mutable_data<float16>();
// ====================================================
zynqmp::ConcatParam& concat_param = pe_.param();
for (auto t : param.x) {
concat_param.inputs.push_back(t->ZynqTensor());
}
concat_param.output = param.output->ZynqTensor();
concat_param.axis = param.axis;
pe_.init();
pe_.apply();
}
void ConcatCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ConcatParam& concat_param = pe_.param();
Debugger::get_instance().registerOutput("concat", concat_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(concat,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ConcatCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/concat_op.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/concat_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class ConcatCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ConcatParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConcatCompute() = default;
private:
zynqmp::ConcatPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -13,9 +13,12 @@
// limitations under the License.
#include "lite/kernels/fpga/conv_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -25,37 +28,61 @@ using float16 = zynqmp::float16;
void ConvCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
// ====================================================
zynqmp::ConvParam& conv_param = pe_.param();
param.output->mutable_data<float16>();
int pad_h = (*param.paddings)[0];
int pad_w = (*param.paddings)[2];
// ====================================================
if (param.x->ZynqTensor()->shape().channel() != 1 &&
param.groups == param.x->ZynqTensor()->shape().channel()) {
zynqmp::DepthwiseConvParam& conv_param = dw_conv_pe_.param();
// filter_.setDataType(zynqmp::FP32);
conv_param.input = param.x->ZynqTensor();
conv_param.output = param.output->ZynqTensor();
conv_param.filter = param.filter->ZynqTensor();
conv_param.groups = param.groups;
conv_param.strides = param.strides;
auto paddings = *param.paddings;
conv_param.paddings = param.paddings;
conv_param.dilations = param.dilations;
bool pad_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
if (!pad_equal) {
LOG(FATA) << "This pad not support ! " << paddings[0] << ", " << paddings[1]
<< ", " << paddings[2] << ", " << paddings[3];
conv_param.input = param.x->ZynqTensor();
conv_param.output = param.output->ZynqTensor();
conv_param.filter = param.filter->ZynqTensor();
conv_param.filter->setDataType(zynqmp::FP32);
conv_param.groups = param.groups;
conv_param.strides = param.strides;
conv_param.paddings = std::vector<int>({pad_h, pad_w});
conv_param.dilations = *param.dilations;
fill_scale_bias_const(&conv_param);
conv_param.bias()->copyFrom(param.bias->ZynqTensor());
conv_param.relu.enabled = param.fuse_relu;
dw_conv_pe_.init();
dw_conv_pe_.apply();
} else {
zynqmp::ConvParam& conv_param = conv_pe_.param();
conv_param.input = param.x->ZynqTensor();
conv_param.output = param.output->ZynqTensor();
conv_param.filter = param.filter->ZynqTensor();
conv_param.filter->setDataType(zynqmp::FP32);
conv_param.groups = param.groups;
conv_param.strides = param.strides;
conv_param.paddings = std::vector<int>({pad_h, pad_w});
conv_param.dilations = *param.dilations;
fill_scale_bias_const(&conv_param);
if (param.bias != nullptr) {
conv_param.bias()->copyFrom(param.bias->ZynqTensor());
}
conv_param.relu.enabled = param.fuse_relu;
conv_pe_.init();
conv_pe_.apply();
}
fill_scale_bias_const(&conv_param);
conv_param.bias()->copyFrom(param.bias->ZynqTensor());
conv_param.relu.enabled = param.fuse_relu;
pe_.init();
pe_.apply();
}
void ConvCompute::Run() {
auto& param = this->Param<param_t>();
zynqmp::ConvParam& conv_param = pe_.param();
pe_.dispatch();
if (param.x->ZynqTensor()->shape().channel() != 1 &&
param.groups == param.x->ZynqTensor()->shape().channel()) {
dw_conv_pe_.dispatch();
} else {
conv_pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ConvParam& conv_param = conv_pe_.param();
Debugger::get_instance().registerOutput("conv", conv_param.output);
#endif
}
}
} // namespace fpga
......
......@@ -14,11 +14,13 @@
#pragma once
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/conv_pe.hpp"
#include "lite/core/kernel.h"
#include "lite/operators/conv_op.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/conv_pe.hpp"
#include "lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -36,7 +38,8 @@ class ConvCompute
~ConvCompute() {}
private:
zynqmp::ConvPE pe_;
zynqmp::ConvPE conv_pe_;
zynqmp::DepthwiseConvPE dw_conv_pe_;
};
} // namespace fpga
......
......@@ -141,15 +141,13 @@ void conv_compute_ref(const operators::ConvParam& param) {
int group = param.groups;
int kernel_w = param.filter->dims()[2];
int kernel_h = param.filter->dims()[3];
auto paddings = *param.paddings;
auto dilations = *para.dilations;
int stride_w = param.strides[0];
int stride_h = param.strides[1];
int dila_w = dilations[0];
int dila_h = dilations[1];
int pad_w = paddings[2];
int pad_h = paddings[0];
int dila_w = (*param.dilations)[0];
int dila_h = (*param.dilations)[1];
int pad_w = (*param.paddings)[2];
int pad_h = (*param.paddings)[0];
bool flag_bias = (param.bias != nullptr);
bool flag_relu = param.fuse_relu;
......@@ -279,14 +277,11 @@ TEST(conv_fpga, compute) {
param.bias = &bias;
}
param.fuse_relu = flag_relu;
std::vector<int> paddings = {
padding, padding, padding, padding};
*param.paddings = std::vector<int>(
{padding, padding, padding, padding});
param.strides = std::vector<int>({stride, stride});
std::vector<int> dilations = {dilation, dilation};
param.paddings =
std::make_shared<std::vector<int>>(paddings);
param.dilations =
std::make_shared<std::vector<int>>(dilations);
*param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Launch();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/dropout_compute.h"
#include <string>
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/backends/fpga/KD/float16.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
void DropoutCompute::PrepareForRun() {
auto& param = Param<operators::DropoutParam>();
param.output->mutable_data<float16>();
zynqmp::ScaleParam& scale_param = pe_.param();
scale_param.input = param.x->ZynqTensor();
scale_param.output = param.output->ZynqTensor();
int channel = scale_param.input->shape().channel();
zynqmp::Tensor* scale = new zynqmp::Tensor();
zynqmp::Tensor* bias = new zynqmp::Tensor();
zynqmp::Shape shape(zynqmp::N, {channel});
float* scale_data = scale->mutableData<float>(zynqmp::FP32, shape);
float* bias_data = bias->mutableData<float>(zynqmp::FP32, shape);
float scale_value = 1 - param.dropout_prob;
for (int i = 0; i < channel; ++i) {
scale_data[i] = scale_value;
bias_data[i] = 0.0f;
}
scale->flush();
bias->flush();
scale_param.bias = bias;
scale_param.scale = scale;
pe_.init();
pe_.apply();
}
void DropoutCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ScaleParam& scale_param = pe_.param();
Debugger::get_instance().registerOutput("dropout", scale_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(dropout,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::DropoutCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/pes/scale_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
class DropoutCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~DropoutCompute() = default;
private:
zynqmp::ScalePE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -15,6 +15,7 @@
#include "lite/kernels/fpga/elementwise_compute.h"
#include <string>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
......@@ -37,7 +38,13 @@ void ElementwiseAddCompute::PrepareForRun() {
pe_.init();
pe_.apply();
}
void ElementwiseAddCompute::Run() { pe_.dispatch(); }
void ElementwiseAddCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ElementwiseAddParam& ew_param = pe_.param();
Debugger::get_instance().registerOutput("ew_add", ew_param.output);
#endif
}
void ElementwiseAddActivationCompute::PrepareForRun() {
zynqmp::ElementwiseAddParam& ew_param = pe_.param();
......@@ -53,7 +60,54 @@ void ElementwiseAddActivationCompute::PrepareForRun() {
pe_.init();
pe_.apply();
}
void ElementwiseAddActivationCompute::Run() { pe_.dispatch(); }
void ElementwiseAddActivationCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ElementwiseAddParam& ew_param = pe_.param();
Debugger::get_instance().registerOutput("ew_add", ew_param.output);
#endif
}
void ElementwiseMulCompute::PrepareForRun() {
zynqmp::ScaleParam& scale_param = pe_.param();
auto& param = Param<operators::ElementwiseParam>();
param.Out->mutable_data<float16>();
scale_param.input = param.X->ZynqTensor();
scale_param.output = param.Out->ZynqTensor();
scale_param.relu.enabled = false;
int channel = scale_param.input->shape().channel();
zynqmp::Tensor* scale = new zynqmp::Tensor();
zynqmp::Tensor* bias = new zynqmp::Tensor();
scale_param.scale = scale;
scale_param.bias = bias;
zynqmp::Shape shape(zynqmp::N, {channel});
float* scale_data = scale->mutableData<float>(zynqmp::FP32, shape);
float* bias_data = bias->mutableData<float>(zynqmp::FP32, shape);
float scale_value = param.Y->data<float>()[0];
for (int i = 0; i < channel; ++i) {
if (param.Y->dims().production() != 1) {
scale_value = param.Y->ZynqTensor()->data<float>()[i];
}
scale_data[i] = scale_value;
bias_data[i] = 0;
}
pe_.init();
pe_.apply();
}
void ElementwiseMulCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ScaleParam& scale_param = pe_.param();
Debugger::get_instance().registerOutput("ew_mul_in", scale_param.input);
Debugger::get_instance().registerOutput("ew_mul", scale_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
......@@ -70,10 +124,7 @@ REGISTER_LITE_KERNEL(elementwise_add,
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
......@@ -100,3 +151,20 @@ REGISTER_LITE_KERNEL(
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ElementwiseMulCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
......@@ -16,6 +16,7 @@
#include <algorithm>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/elementwise_add_pe.hpp"
#include "lite/backends/fpga/KD/pes/scale_pe.hpp"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
......@@ -50,6 +51,18 @@ class ElementwiseAddActivationCompute
zynqmp::ElementwiseAddPE pe_;
};
class ElementwiseMulCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~ElementwiseMulCompute() = default;
private:
zynqmp::ScalePE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/fpga/fc_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
......@@ -30,7 +31,6 @@ void FcCompute::PrepareForRun() {
zynqmp::FullyConnectedParam& fc_param = pe_.param();
param.output->mutable_data<float16>();
fc_param.input = param.input->ZynqTensor();
fc_param.output = param.output->ZynqTensor();
fc_param.filter = param.w->ZynqTensor();
......@@ -41,8 +41,11 @@ void FcCompute::PrepareForRun() {
}
void FcCompute::Run() {
auto& param = this->Param<param_t>();
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::FullyConnectedParam& fc_param = pe_.param();
Debugger::get_instance().registerOutput("fc", fc_param.output);
#endif
}
} // namespace fpga
......
......@@ -37,10 +37,6 @@ class FcCompute
private:
zynqmp::FullyConnectedPE pe_;
zynqmp::Tensor input_;
zynqmp::Tensor output_;
zynqmp::Tensor filter_;
zynqmp::Tensor bias_;
};
} // namespace fpga
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/fpga/feed_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
......@@ -25,21 +26,29 @@ using float16 = zynqmp::float16;
void FeedCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
// ====================================================
zynqmp::InputParam& conv_param = pe_.param();
Tensor& x = param.feed_list->at(param.col);
param.out->Resize(x.dims());
param.out->mutable_data<float16>();
conv_param.input = x.ZynqTensor();
conv_param.output = param.out->ZynqTensor();
// ====================================================
zynqmp::InputParam& feed_param = pe_.param();
feed_param.input = x.ZynqTensor();
feed_param.output = param.out->ZynqTensor();
pe_.init();
pe_.apply();
}
void FeedCompute::Run() {
auto& param = this->Param<param_t>();
Tensor& x = param.feed_list->at(param.col);
pe_.dispatch();
auto out_lod = param.out->mutable_lod();
*out_lod = x.lod();
#ifdef FPGA_PRINT_TENSOR
zynqmp::InputParam& feed_param = pe_.param();
Debugger::get_instance().registerOutput("feed", feed_param.output);
#endif
}
} // namespace fpga
......@@ -50,7 +59,7 @@ void FeedCompute::Run() {
REGISTER_LITE_KERNEL(
feed, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::FeedCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
......
......@@ -32,8 +32,6 @@ class FeedCompute
private:
zynqmp::InputPE pe_;
zynqmp::Tensor input_;
zynqmp::Tensor output_;
};
} // namespace fpga
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/fetch_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
......@@ -25,35 +26,65 @@ using float16 = zynqmp::float16;
void FetchCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
// ====================================================
zynqmp::OutputParam& conv_param = pe_.param();
zynqmp::OutputParam& fetch_param = pe_.param();
auto fetch_list = param.fetch_list;
if (fetch_list->size() <= static_cast<size_t>(param.col)) {
fetch_list->resize(param.col + 1);
}
Tensor& out = param.fetch_list->at(param.col);
out.Resize(param.input->dims());
out.mutable_data<float16>();
out.mutable_data<float>();
conv_param.input = param.input->ZynqTensor();
conv_param.output = out.ZynqTensor();
fetch_param.input = param.input->ZynqTensor();
fetch_param.output = out.ZynqTensor();
pe_.init();
pe_.apply();
}
void FetchCompute::Run() { pe_.dispatch(); }
void FetchCompute::Run() {
pe_.dispatch();
auto& param = this->Param<param_t>();
#ifdef FPGA_PRINT_TENSOR
zynqmp::OutputParam& fetch_param = pe_.param();
Debugger::get_instance().registerOutput("fetch", fetch_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
fetch, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::FetchCompute, def)
REGISTER_LITE_KERNEL(fetch,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::FetchCompute,
fpga_host)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(fetch,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::FetchCompute,
host_host)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
......@@ -31,8 +31,6 @@ class FetchCompute
private:
zynqmp::OutputPE pe_;
zynqmp::Tensor input_;
zynqmp::Tensor output_;
};
} // namespace fpga
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
// #include <chrono>
#include <iostream>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/gru_utils.h"
#include "lite/backends/arm/math/sequence2batch.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/kernels/fpga/gru_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/backends/fpga/KD/pes/gru_util.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
inline lite_api::ActivationType get_gru_act_type(const std::string& type) {
if (type == "sigmoid") {
return lite_api::ActivationType::kSigmoid;
} else if (type == "tanh") {
return lite_api::ActivationType::kTanh;
} else if (type == "relu") {
return lite_api::ActivationType::kRelu;
} else if (type == "identity") {
return lite_api::ActivationType::kIndentity;
} else {
LOG(FATAL) << "unsupported activation type: " << type;
}
}
void GRUCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
param.hidden->mutable_data<float>();
auto input = param.input;
auto h0 = param.h0;
auto weight = param.weight;
auto bias = param.bias;
zynqmp::GRUParam& gru_param = pe_.param();
gru_param.input = input->ZynqTensor();
if (h0 != nullptr) {
gru_param.h0 = h0->ZynqTensor();
}
gru_param.weight = weight->ZynqTensor();
gru_param.bias = bias->ZynqTensor();
gru_param.batch_gate = param.batch_gate->ZynqTensor();
gru_param.batch_reset_hidden_prev =
param.batch_reset_hidden_prev->ZynqTensor();
gru_param.batch_hidden = param.batch_hidden->ZynqTensor();
gru_param.hidden = param.hidden->ZynqTensor();
gru_param.gate_activation = param.gate_activation;
gru_param.activation = param.activation;
pe_.init();
pe_.apply();
}
void GRUCompute::Run() {
auto& param = this->Param<param_t>();
param.hidden->mutable_data<float>();
// auto& ctx = this->ctx_->template As<ARMContext>();
// inputs
auto input = param.input;
auto h0 = param.h0;
auto weight = param.weight;
auto bias = param.bias;
// outputs
auto batch_gate = param.batch_gate;
auto batch_reset_hidden_prev = param.batch_reset_hidden_prev;
auto batch_hidden = param.batch_hidden;
auto hidden = param.hidden;
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
auto batch_size = input->dims()[0];
const float* weight_data = weight->data<float>();
float* batch_gate_data = batch_gate->mutable_data<float>();
lite::arm::math::LoDTensor2BatchFunctor<float> to_batch;
to_batch(*input, batch_gate, true, param.is_reverse); // 1.
save_tensor(batch_gate, "_batch_gate.txt");
if (bias) {
auto bias_data = bias->data<float>(); // 2.
lite::arm::math::gru_add_with_bias(batch_gate_data,
bias_data,
batch_gate_data,
batch_size,
frame_size * 3);
// save_tensor(const_cast<Tensor*>(bias), "_bias.txt");
save_tensor(batch_gate, "_after_bias.txt");
std::cout << "================= bias =================\n";
}
zynqmp::GRUTensors gru_tensors;
lite::arm::math::GRUMetaValue<float> gru_value;
gru_value.gate_weight = const_cast<float*>(weight_data);
gru_value.state_weight =
const_cast<float*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
std::vector<uint64_t> order(batch_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
// lite::arm::math::ReorderInitState<float>(*h0, order, &ordered_h0, true);
// //3.
gru_value.prev_out_value = ordered_h0.mutable_data<float>();
gru_tensors.pre_output = ordered_h0.ZynqTensor();
std::cout << "================= h0 =================\n";
} else {
gru_value.prev_out_value = nullptr;
gru_tensors.pre_output = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node = get_gru_act_type(param.activation);
auto active_gate = get_gru_act_type(param.gate_activation);
save_float(gru_value.gate_weight, "_gate_weight.txt", weight->numel());
batch_gate->ZynqTensor()->saveToFile("batch_gate.txt");
zynqmp::Tensor float_input;
zynqmp::Tensor hidden_out;
std::cout << "seq_len::" << seq_len << std::endl;
// exit(-1);
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
gru_value.output_value =
batch_hidden->mutable_data<float>() + bstart * batch_hidden->dims()[1];
gru_value.gate_value =
batch_gate->mutable_data<float>() + bstart * batch_gate->dims()[1];
gru_value.reset_output_value =
batch_reset_hidden_prev->mutable_data<float>() +
bstart * batch_reset_hidden_prev->dims()[1];
zynqmp::Shape float_input_shape(zynqmp::NC,
{cur_batch_size, batch_gate->dims()[1]});
float* float_data =
float_input.mutableData<float>(zynqmp::FP32, float_input_shape);
memcpy(float_data,
gru_value.gate_value,
batch_gate->dims()[1] * sizeof(float));
float_input.flush();
float* hidden_data =
hidden_out.mutableData<float>(zynqmp::FP32, float_input_shape);
// memcpy(hidden_prev_data, )
// zynqmp::Tensor* gate = pe_.gate();
gru_tensors.gate = &float_input;
gru_tensors.output = &hidden_out;
pe_.GRUCOmpute(gru_tensors,
frame_size,
cur_batch_size,
active_node,
active_gate,
param.origin_mode);
// TODO(chonwhite): copy data back to original tensor;
gru_tensors.pre_output = gru_tensors.output;
// gru_value.prev_out_value = gru_value.output_value;
}
lite::arm::math::Batch2LoDTensorFunctor<float> to_seq; // 5.
*(batch_hidden->mutable_lod()) = batch_gate->lod();
batch_hidden->mutable_data<float>();
to_seq(*batch_hidden, hidden);
save_tensor(const_cast<Tensor*>(input), "_input.txt");
save_tensor(hidden, "_gru.txt");
exit(-1);
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
gru, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::GRUCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/backends/fpga/KD/pes/elementwise_add_pe.hpp"
#include "lite/backends/fpga/KD/pes/fully_connected_pe.hpp"
#include "lite/backends/fpga/KD/pes/gru_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class GRUCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::GRUParam;
GRUCompute() = default;
void PrepareForRun() override;
void Run() override;
virtual ~GRUCompute() = default;
private:
zynqmp::Tensor pre_output_;
zynqmp::Tensor pre_bias_;
zynqmp::Tensor weight_;
zynqmp::ElementwiseAddPE bias_ew_pe_;
zynqmp::FullyConnectedPE pre_out_pe_;
zynqmp::FullyConnectedPE reset_out_pe_;
// zynqmp::Tensor input_;
zynqmp::GRUPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/fpga/im2sequence_compute.h"
#include "lite/backends/fpga/KD/float16.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void im2sequence(const float16* input,
const int input_c,
const int input_h,
const int input_w,
const int kernel_h,
const int kernel_w,
const int pad_top,
const int pad_bottom,
const int pad_left,
const int pad_right,
const int stride_h,
const int stride_w,
const int out_h,
const int out_w,
float16* out) {
int window_size = kernel_h * kernel_w;
int out_rows = out_h * out_w;
int out_cols = input_c * window_size;
int H_pad = input_h + pad_top + pad_bottom;
int W_pad = input_w + pad_left + pad_right;
float16 zero = zynqmp::float_to_half(0.0f);
for (int h_id = 0; h_id < out_h; h_id++) {
for (int w_id = 0; w_id < out_w; w_id++) {
// consider dilation.
int start_h = h_id * stride_h - pad_top;
int start_w = w_id * stride_w - pad_left;
for (int c_id = 0; c_id < input_c; c_id++) {
for (int k_h_id = 0; k_h_id < kernel_h; k_h_id++) {
int in_h_id = start_h + k_h_id;
bool exceed_flag = (in_h_id < 0) || (in_h_id >= H_pad);
int out_start_id =
(h_id * out_w + w_id) * out_cols + c_id * window_size;
for (int k_w_id = 0; k_w_id < kernel_w; k_w_id++) {
int in_w_id = start_w + k_w_id;
exceed_flag = exceed_flag || (in_w_id < 0) || (in_w_id >= W_pad);
int input_id = (c_id * input_h + in_h_id) * input_w + in_w_id;
int out_id = out_start_id + k_h_id * kernel_w + k_w_id;
out[out_id] = exceed_flag ? zero : input[input_id];
}
}
}
}
}
}
template <typename T>
void hwc_to_chw(T* chw_data,
const T* hwc_data,
int num,
int channel,
int height,
int width) {
int chw = channel * height * width;
int wc = width * channel;
int wh = width * height;
int index = 0;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
chw_data[n * chw + c * wh + h * width + w] = hwc_data[index];
index++;
}
}
}
}
}
void Im2SequenceCompute::PrepareForRun() {}
void Im2SequenceCompute::Run() {
auto& param = this->Param<operators::Im2SequenceParam>();
auto kernels = param.kernels;
auto strides = param.strides;
auto paddings = param.paddings;
const auto* x_data = param.X->data<float16>();
float16* o_data =
reinterpret_cast<float16*>(param.Out->mutable_data<float16>());
float16* o2 = o_data;
auto input_dims = param.X->dims();
int im_num = input_dims[0];
int im_size = param.X->numel() / im_num;
param.X->ZynqTensor()->syncToCPU();
float16* chw_data = new float16[param.X->numel()];
hwc_to_chw<float16>(chw_data,
x_data,
param.X->dims()[0],
param.X->dims()[1],
param.X->dims()[2],
param.X->dims()[3]);
const float16* in = chw_data;
int out_cols = input_dims[1] * kernels[0] * kernels[1];
int total_rows = 0;
std::vector<uint64_t> im_offset;
im_offset.push_back(total_rows);
if (param.Y) {
const auto* y_data = param.Y->data<int>();
auto out_strides = param.out_strides;
std::vector<int> im_real_h;
std::vector<int> im_real_w;
std::vector<int> out_h_vec;
std::vector<int> out_w_vec;
for (int im_id = 0; im_id < im_num; im_id++) {
int real_h = y_data[im_id * 2 + 0];
int real_w = y_data[im_id * 2 + 1];
int tmp_real_h = (real_h + out_strides[0] - 1) / out_strides[0];
int tmp_real_w = (real_w + out_strides[1] - 1) / out_strides[1];
im_real_h.push_back(tmp_real_h);
im_real_w.push_back(tmp_real_w);
int out_h =
(tmp_real_h + paddings[0] + paddings[1] - kernels[0]) / strides[0] +
1;
int out_w =
(tmp_real_w + paddings[2] + paddings[3] - kernels[1]) / strides[1] +
1;
out_h_vec.push_back(out_h);
out_w_vec.push_back(out_w);
total_rows += out_h * out_w;
im_offset.push_back(total_rows);
}
auto out_dims = param.Out->dims();
out_dims[0] = total_rows;
param.Out->Resize(out_dims);
int out_offset = 0;
for (int im_id = 0; im_id < im_num; im_id++) {
im2sequence(in + im_id * im_size,
input_dims[1],
input_dims[2],
input_dims[3],
param.kernels[0],
param.kernels[1],
param.paddings[0],
param.paddings[1],
param.paddings[2],
param.paddings[3],
param.strides[0],
param.strides[1],
out_h_vec[im_id],
out_w_vec[im_id],
o2 + im_offset[im_id] * out_cols);
}
} else {
int out_h =
(input_dims[2] + paddings[0] + paddings[1] - kernels[0]) / strides[0] +
1;
int out_w =
(input_dims[3] + paddings[2] + paddings[3] - kernels[1]) / strides[1] +
1;
for (int im_id = 0; im_id < im_num; im_id++) {
int out_size_per_im = out_h * out_w * out_cols;
im2sequence(in + im_id * im_size,
input_dims[1],
input_dims[2],
input_dims[3],
param.kernels[0],
param.kernels[1],
param.paddings[0],
param.paddings[1],
param.paddings[2],
param.paddings[3],
param.strides[0],
param.strides[1],
out_h,
out_w,
o2 + im_id * out_size_per_im);
im_offset.push_back(uint64_t((im_id + 1) * out_h * out_w));
}
auto lod = param.Out->mutable_lod();
lod->resize(1);
(*lod)[0] = im_offset;
}
delete[] chw_data;
param.Out->ZynqTensor()->flush();
param.Out->ZynqTensor()->copyScaleFrom(param.X->ZynqTensor());
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(im2sequence,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::Im2SequenceCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
// #include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/operators/im2sequence_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class Im2SequenceCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::Im2SequenceParam;
void PrepareForRun() override;
void Run() override;
~Im2SequenceCompute() {}
private:
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -45,7 +45,23 @@ class IoCopyHostToFpgaCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kFPGA));
param.y->CopyDataFrom(*param.x);
param.y->mutable_data<float16>();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) {
zynqmp::Tensor tempTensor;
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
param.y->ZynqTensor()->invalidate();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
auto out_lod = param.y->mutable_lod();
*out_lod = param.x->lod();
}
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
......@@ -81,7 +97,30 @@ class IoCopyFpgaToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kFPGA));
param.y->CopyDataFrom(*param.x);
param.y->mutable_data<float>();
param.y->ZynqTensor()->setDataType(zynqmp::FP32);
param.x->ZynqTensor()->syncToDevice();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) {
zynqmp::Tensor tempTensor;
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
param.y->ZynqTensor()->flush();
auto out_lod = param.y->mutable_lod();
*out_lod = param.x->lod();
// param.x->ZynqTensor()->saveToFile("io_x", true);
// param.y->ZynqTensor()->saveToFile("io_y", true);
}
std::string doc() const override { return "Copy IO from FPGA to HOST"; }
......@@ -100,12 +139,27 @@ REGISTER_LITE_KERNEL(io_copy,
host_to_device)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy,
kFPGA,
kAny,
kAny,
paddle::lite::kernels::fpga::IoCopyHostToFpgaCompute,
host_to_device_any_any)
.BindInput("Input",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy,
......@@ -119,9 +173,25 @@ REGISTER_LITE_KERNEL(io_copy,
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy,
kFPGA,
kAny,
kAny,
paddle::lite::kernels::fpga::IoCopyFpgaToHostCompute,
device_to_host_22)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy_once,
......@@ -132,12 +202,12 @@ REGISTER_LITE_KERNEL(io_copy_once,
host_to_device_once)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy_once,
......@@ -148,8 +218,8 @@ REGISTER_LITE_KERNEL(io_copy_once,
device_to_host_once)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
......
......@@ -26,16 +26,95 @@ namespace fpga {
using float16 = zynqmp::float16;
void TransHwcToChw(Tensor* dest, const Tensor* src) {}
void TransChwToHwc(Tensor* dest, const Tensor* src) {}
template <typename T>
void convert_to_hwc(
T* chw_data, T* hwc_data, int num, int channel, int height, int width) {
int chw = channel * height * width;
int wc = width * channel;
int index = 0;
for (int n = 0; n < num; n++) {
for (int c = 0; c < channel; c++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
hwc_data[n * chw + h * wc + w * channel + c] = chw_data[index];
index++;
}
}
}
}
}
template <typename T>
void hwc_to_chw(
T* chw_data, T* hwc_data, int num, int channel, int height, int width) {
int chw = channel * height * width;
int wc = width * channel;
int wh = width * height;
int index = 0;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
chw_data[n * chw + c * wh + h * width + w] = hwc_data[index];
index++;
}
}
}
}
}
void TransHwcToChw(Tensor* dest, const Tensor* src) {
if (src->ZynqTensor()->dataType() == zynqmp::FP32) {
float* chw = dest->mutable_data<float>();
float* hwc = const_cast<float*>(src->data<float>());
int num = dest->dims()[0];
int channel = dest->dims()[1];
int height = 1;
if (dest->dims().size() > 2) {
height = dest->dims()[2];
}
int width = 1;
if (dest->dims().size() > 3) {
width = dest->dims()[3];
}
hwc_to_chw<float>(chw, hwc, num, channel, height, width);
}
if (src->ZynqTensor()->dataType() == zynqmp::FP16) {
float16* chw = dest->mutable_data<float16>();
float16* hwc = const_cast<float16*>(src->data<float16>());
int num = dest->dims()[0];
int channel = dest->dims()[1];
int height = 1;
if (dest->dims().size() > 2) {
height = dest->dims()[2];
}
int width = 1;
if (dest->dims().size() > 3) {
width = dest->dims()[3];
}
hwc_to_chw<float16>(chw, hwc, num, channel, height, width);
}
}
void TransChwToHwc(Tensor* dest, const Tensor* src) {
std::cout << "chw to hwc \n";
exit(-1);
}
class TransHwcToChwCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kAny), DATALAYOUT(kNHWC)> {
public:
void Run() override {
auto& param = Param<operators::LayoutParam>();
auto out_data = param.y->mutable_data<float16>(TARGET(kFPGA));
param.x->ZynqTensor()->syncToCPU();
TransHwcToChw(param.y, param.x);
param.y->ZynqTensor()->flush();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
auto out_lod = param.y->mutable_lod();
*out_lod = param.x->lod();
}
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
......@@ -97,6 +176,22 @@ REGISTER_LITE_KERNEL(layout,
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kFPGA,
kAny,
kNHWC,
paddle::lite::kernels::fpga::TransHwcToChwCompute,
hwc_to_chw_arm_float)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kFPGA,
kAny,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/mul_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void MulCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
// ====================================================
zynqmp::FullyConnectedParam& fc_param = pe_.param();
param.output->mutable_data<float16>();
fc_param.input = param.x->ZynqTensor();
fc_param.output = param.output->ZynqTensor();
fc_param.filter = param.y->ZynqTensor();
fc_param.bias = &bias_;
int channel = fc_param.filter->shape().channel();
zynqmp::Shape bias_shape(zynqmp::N, {channel});
float* bias_data =
fc_param.bias->mutableData<float>(zynqmp::FP32, bias_shape);
memset(bias_data, 0, channel * sizeof(float));
bias_.flush();
pe_.init();
pe_.apply();
}
void mul(MulCompute* k) {
auto& param = k->Param<operators::MulParam>();
int num = param.x->dims()[0];
int channel = param.x->dims()[1];
int fn = param.y->dims()[1];
float16* out_data = param.output->mutable_data<float16>();
int g_index = 0;
for (int n = 0; n < 1; n++) {
for (int on = 0; on < fn; on++) {
float sum = 0;
int si = 0;
for (int c = 0; c < channel; c++) {
float value = zynqmp::half_to_float(param.x->data<float16>()[si]);
int index = c * fn + on;
float weight = param.y->data<float>()[index];
sum += value * weight;
si++;
}
out_data[g_index] = zynqmp::float_to_half(sum);
g_index++;
}
}
}
void MulCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::FullyConnectedParam& fc_param = pe_.param();
Debugger::get_instance().registerOutput("mul", fc_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
mul, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::MulCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/fully_connected_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class MulCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::MulParam;
void PrepareForRun() override;
void Run() override;
virtual ~MulCompute() = default;
private:
zynqmp::FullyConnectedPE pe_;
zynqmp::Tensor bias_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/multiclass_nms_compute.h"
#include <map>
#include <utility>
#include <vector>
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static void GetMaxScoreIndex(const std::vector<T>& scores,
const T threshold,
int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(),
sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
T PolyIoU(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
LOG(FATAL) << "PolyIoU not implement.";
}
template <class T>
void SliceOneClass(const Tensor& items,
const int class_id,
Tensor* one_class_item) {
T* item_data = one_class_item->mutable_data<T>();
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
const int64_t class_num = items.dims()[1];
if (items.dims().size() == 3) {
int64_t item_size = items.dims()[2];
for (int i = 0; i < num_item; ++i) {
std::memcpy(item_data + i * item_size,
items_data + i * class_num * item_size + class_id * item_size,
sizeof(T) * item_size);
}
} else {
for (int i = 0; i < num_item; ++i) {
item_data[i] = items_data[i * class_num + class_id];
}
}
}
template <typename T>
void NMSFast(const Tensor& bbox,
const Tensor& scores,
const T score_threshold,
const T nms_threshold,
const T eta,
const int64_t top_k,
std::vector<int>* selected_indices,
const bool normalized) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
// 8: [x1 y1 x2 y2 x3 y3 x4 y4]
// 16, 24, or 32: [x1 y1 x2 y2 ... xn yn], n = 8, 12 or 16
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = T(0.);
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
normalized);
}
// 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
if (box_size == 8 || box_size == 16 || box_size == 24 ||
box_size == 32) {
overlap = PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
box_size,
normalized);
}
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T>
void MultiClassNMS(const operators::MulticlassNmsParam& param,
const Tensor& scores,
const Tensor& bboxes,
const int scores_size,
std::map<int, std::vector<int>>* indices,
int* num_nmsed_out) {
int64_t background_label = param.background_label;
int64_t nms_top_k = param.nms_top_k;
int64_t keep_top_k = param.keep_top_k;
bool normalized = param.normalized;
T nms_threshold = static_cast<T>(param.nms_threshold);
T nms_eta = static_cast<T>(param.nms_eta);
T score_threshold = static_cast<T>(param.score_threshold);
int num_det = 0;
int64_t class_num = scores_size == 3 ? scores.dims()[0] : scores.dims()[1];
for (int64_t c = 0; c < class_num; ++c) {
Tensor bbox_slice, score_slice;
if (c == background_label) continue;
if (scores_size == 3) {
scores.Slice<T>(score_slice, c, c + 1);
bbox_slice = bboxes;
} else {
score_slice.Resize({scores.dims()[0], 1});
bbox_slice.Resize({scores.dims()[0], 4});
SliceOneClass<T>(scores, c, &score_slice);
SliceOneClass<T>(bboxes, c, &bbox_slice);
}
NMSFast(bboxes,
score_slice,
score_threshold,
nms_threshold,
nms_eta,
nms_top_k,
&((*indices)[c]),
normalized);
if (scores_size == 2) {
std::stable_sort((*indices)[c].begin(), (*indices)[c].end());
}
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
Tensor score_slice;
const T* sdata;
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
if (scores_size == 3) {
sdata = scores_data + label * scores.dims()[1];
} else {
score_slice.Resize({scores.dims()[0], 1});
SliceOneClass<T>(scores, label, &score_slice);
sdata = score_slice.data<T>();
}
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(),
score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
if (scores_size == 2) {
for (const auto& it : new_indices) {
int label = it.first;
std::stable_sort(new_indices[label].begin(), new_indices[label].end());
}
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T>
void MultiClassOutput(const Tensor& scores,
const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size,
Tensor* outs) {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
if (scores_size == 2) {
box_size = bboxes.dims()[2];
}
int64_t out_dim = box_size + 2;
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->mutable_data<T>();
const T* sdata;
Tensor bbox;
bbox.Resize({scores.dims()[0], box_size});
int count = 0;
for (const auto& it : selected_indices) {
int label = it.first;
const std::vector<int>& indices = it.second;
if (scores_size == 2) {
SliceOneClass<T>(bboxes, label, &bbox);
} else {
sdata = scores_data + label * predict_dim;
}
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
odata[count * out_dim] = label; // label
const T* bdata;
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++;
}
}
}
void MulticlassNmsCompute::Run() {
auto& param = Param<operators::MulticlassNmsParam>();
auto* boxes = param.bboxes;
auto* scores = param.scores;
auto* outs = param.out;
outs->mutable_data<float>();
auto score_dims = scores->dims();
auto score_size = score_dims.size();
auto box_dims = boxes->dims();
int64_t box_dim = boxes->dims()[2];
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<uint64_t> batch_starts = {0};
int64_t batch_size = score_dims[0];
int64_t out_dim = box_dim + 2;
int num_nmsed_out = 0;
Tensor boxes_slice, scores_slice;
int n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores->Slice<float>(scores_slice, i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes->Slice<float>(boxes_slice, i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
} else {
auto boxes_lod = boxes->lod().back();
scores->Slice<float>(scores_slice, boxes_lod[i], boxes_lod[i + 1]);
boxes->Slice<float>(boxes_slice, boxes_lod[i], boxes_lod[i + 1]);
}
std::map<int, std::vector<int>> indices;
MultiClassNMS<float>(
param, scores_slice, boxes_slice, score_size, &indices, &num_nmsed_out);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
uint64_t num_kept = batch_starts.back();
if (num_kept == 0) {
outs->Resize({1, 1});
float* od = outs->mutable_data<float>();
od[0] = -1;
batch_starts = {0, 1};
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores->Slice<float>(scores_slice, i, i + 1);
boxes->Slice<float>(boxes_slice, i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice.Resize({score_dims[2], box_dim});
} else {
auto boxes_lod = boxes->lod().back();
scores->Slice<float>(scores_slice, boxes_lod[i], boxes_lod[i + 1]);
boxes->Slice<float>(boxes_slice, boxes_lod[i], boxes_lod[i + 1]);
}
int64_t s = static_cast<int64_t>(batch_starts[i]);
int64_t e = static_cast<int64_t>(batch_starts[i + 1]);
if (e > s) {
Tensor out;
outs->Slice<float>(out, s, e);
MultiClassOutput<float>(
scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
outs->ZynqTensor()->copyFrom(out.ZynqTensor());
}
}
}
LoD lod;
lod.emplace_back(batch_starts);
outs->set_lod(lod);
#ifdef FPGA_PRINT_TENSOR
Debugger::get_instance().registerOutput("boxes", boxes->ZynqTensor());
Debugger::get_instance().registerOutput("scores", scores->ZynqTensor());
Debugger::get_instance().registerOutput("nms", outs->ZynqTensor());
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::MulticlassNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::MulticlassNmsCompute,
def2)
.BindInput("BBoxes",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Scores",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class MulticlassNmsCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void Run() override;
virtual ~MulticlassNmsCompute() = default;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/norm_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void NormCompute::PrepareForRun() {
auto& param = this->Param<operators::NormParam>();
param.Out->mutable_data<float16>();
zynqmp::NormParam& norm_param = pe_.param();
norm_param.input = param.X->ZynqTensor();
norm_param.output = param.Out->ZynqTensor();
norm_param.epsilon = param.epsilon;
pe_.init();
pe_.apply();
}
void NormCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::NormParam& norm_param = pe_.param();
Debugger::get_instance().registerOutput("norm", norm_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
norm, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::NormCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Norm",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
// #include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/norm_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class NormCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::NormParam;
void PrepareForRun() override;
void Run() override;
~NormCompute() {}
private:
zynqmp::NormPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -18,6 +18,8 @@
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -26,26 +28,32 @@ namespace fpga {
using float16 = zynqmp::float16;
void PoolCompute::PrepareForRun() {
zynqmp::PoolingParam& pool_param = pe_.param();
auto& param = Param<operators::PoolParam>();
param.output->mutable_data<float16>();
zynqmp::PoolingParam& pool_param = pe_.param();
pool_param.input = param.x->ZynqTensor();
pool_param.output = param.output->ZynqTensor();
pool_param.relu.enabled = false;
pool_param.type = param.pooling_type == "max" ? zynqmp::PoolingType::MAX
: zynqmp::PoolingType::AVERAGE;
pool_param.globalPooling = param.global_pooling;
pool_param.kernelSize = param.ksize;
pool_param.strides = param.strides;
pool_param.paddings = param.paddings;
int pad_h = (*param.paddings)[0];
int pad_w = (*param.paddings)[2];
pool_param.paddings = std::vector<int>({pad_h, pad_w});
pe_.init();
pe_.apply();
}
void PoolCompute::Run() { pe_.dispatch(); }
void PoolCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::PoolingParam& pool_param = pe_.param();
Debugger::get_instance().registerOutput("pooling", pool_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
......
......@@ -46,7 +46,7 @@ std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
if (param_->global_pooling) {
ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
param_->paddings[i] = 0;
(*param_->paddings)[i] = 0;
ksize[i] = static_cast<int>(x_dims[i + 2]);
}
}
......@@ -59,7 +59,7 @@ std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
for (size_t i = 0; i < param_->ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(x_dims[i + 2],
param_->ksize[i],
param_->paddings[i],
(*param_->paddings)[i],
param_->strides[i],
param_->ceil_mode));
}
......@@ -76,7 +76,7 @@ void pool_compute_ref(const operators::PoolParam& param) {
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = param.paddings;
std::vector<int> paddings = *param.paddings;
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
......@@ -103,7 +103,7 @@ void pool_compute_ref(const operators::PoolParam& param) {
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int pad_w = paddings[2];
if (global_pooling == true) {
for (int n = 0; n < in_n; ++n) {
......@@ -230,7 +230,7 @@ TEST(pool_fpga, compute) {
}
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
*param.paddings = {pad, pad, pad, pad};
param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/kernels/fpga/prior_box_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip,
std::vector<float>* output_aspect_ratior) {
constexpr float epsilon = 1e-6;
output_aspect_ratior->clear();
output_aspect_ratior->push_back(1.0f);
for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
float ar = input_aspect_ratior[i];
bool already_exist = false;
for (size_t j = 0; j < output_aspect_ratior->size(); ++j) {
if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) {
already_exist = true;
break;
}
}
if (!already_exist) {
output_aspect_ratior->push_back(ar);
if (flip) {
output_aspect_ratior->push_back(1.0f / ar);
}
}
}
}
void PriorBoxCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
bool is_flip = param.flip;
bool is_clip = param.clip;
std::vector<float> min_size = param.min_sizes;
std::vector<float> max_size = param.max_sizes;
std::vector<float> aspect_ratio = param.aspect_ratios;
std::vector<float> variance = param.variances_;
int img_w = param.img_w;
int img_h = param.img_h;
float step_w = param.step_w;
float step_h = param.step_h;
float offset = param.offset;
std::vector<float> aspect_ratios_vec;
ExpandAspectRatios(aspect_ratio, is_flip, &aspect_ratios_vec);
size_t prior_num = aspect_ratios_vec.size() * min_size.size();
prior_num += max_size.size();
std::vector<std::string> order = param.order;
bool min_max_aspect_ratios_order = param.min_max_aspect_ratios_order;
int win1 = param.input->dims()[3];
int hin1 = param.input->dims()[2];
DDim shape_out({hin1, win1, prior_num, 4});
param.boxes->Resize(shape_out);
param.variances->Resize(shape_out);
param.boxes->mutable_data<float>();
param.variances->mutable_data<float>();
// ====================================================
zynqmp::PriorBoxParam& priobox_param = pe_.param();
priobox_param.input = param.input->ZynqTensor();
priobox_param.image = param.image->ZynqTensor();
priobox_param.outputBoxes = param.boxes->ZynqTensor();
priobox_param.outputVariances = param.variances->ZynqTensor();
priobox_param.minSizes = param.min_sizes;
priobox_param.maxSizes = param.max_sizes;
priobox_param.aspectRatios = param.aspect_ratios;
priobox_param.variances = param.variances_;
priobox_param.minMaxAspectRatiosOrder = min_max_aspect_ratios_order;
priobox_param.flip = param.flip;
priobox_param.clip = param.clip;
priobox_param.stepW = param.step_w;
priobox_param.stepH = param.step_h;
priobox_param.offset = param.offset;
pe_.init();
pe_.apply();
}
void PriorBoxCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::PriorBoxParam& priobox_param = pe_.param();
Debugger::get_instance().registerOutput("pb_boxes",
priobox_param.outputBoxes);
Debugger::get_instance().registerOutput("pb_variances",
priobox_param.outputVariances);
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(prior_box,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::PriorBoxCompute,
def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Image",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// REGISTER_LITE_KERNEL(prior_box,
// kFPGA,
// kFP16,
// kNHWC,
// paddle::lite::kernels::fpga::PriorBoxCompute,
// def)
// .BindInput("Input", {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindInput("Image", {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
// .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/prior_box_pe.hpp"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class PriorBoxCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::PriorBoxParam;
void PrepareForRun() override;
void Run() override;
virtual ~PriorBoxCompute() = default;
private:
zynqmp::PriorBoxPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/reshape_compute.h"
#include <vector>
#include "lite/operators/reshape_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void ReshapeCompute::Run() {
auto& param = Param<operators::ReshapeParam>();
param.output->mutable_data<float16>();
auto x = param.x;
// auto actual_shape = param.actual_shape;
Tensor* actual_shape = nullptr; // TODO(chonwhite) change it.
auto output = param.output;
bool inplace = param.inplace;
auto x_dims = x->dims();
auto output_dims = output->dims();
if (actual_shape) {
auto actual_shape_dims = actual_shape->dims();
auto* actual_shape_data = actual_shape->data<int>();
auto shape = std::vector<int>(
actual_shape_data, actual_shape_data + actual_shape_dims.production());
output_dims = lite::operators::ValidateShape(shape, x_dims);
output->Resize(output_dims);
}
if (inplace) {
output->ShareDataWith(*x);
} else {
output->CopyDataFrom(*x);
}
param.x->ZynqTensor()->saveToFile("reshape_in", true);
output->ZynqTensor()->saveToFile("reshape_out", true);
output->Resize(output_dims);
}
// void ReshapeComputeFpgaToHost::Run() {
// auto& param = Param<operators::ReshapeParam>();
// param.output->mutable_data<float>();
// auto x = param.x;
// // auto actual_shape = param.actual_shape;
// Tensor* actual_shape = nullptr; // TODO(chonwhite) change it.
// auto output = param.output;
// bool inplace = param.inplace;
// auto x_dims = x->dims();
// auto output_dims = output->dims();
// if (actual_shape) {
// auto actual_shape_dims = actual_shape->dims();
// auto* actual_shape_data = actual_shape->data<int>();
// auto shape = std::vector<int>(
// actual_shape_data, actual_shape_data +
// actual_shape_dims.production());
// output_dims = lite::operators::ValidateShape(shape, x_dims);
// output->Resize(output_dims);
// }
// if (inplace) {
// output->ShareDataWith(*x);
// } else {
// output->CopyDataFrom(*x);
// }
// output->Resize(output_dims);
// }
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(reshape,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(reshape2,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(flatten,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(flatten2,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class ReshapeCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void Run() override;
virtual ~ReshapeCompute() = default;
};
class ReshapeComputeFpgaToHost
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void Run() override;
virtual ~ReshapeComputeFpgaToHost() = default;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -19,7 +19,37 @@ namespace lite {
namespace kernels {
namespace fpga {
void ScaleCompute::Run() {}
void ScaleCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
param.output->mutable_data<float16>();
zynqmp::ScaleParam& scale_param = pe_.param();
scale_param.input = param.x->ZynqTensor();
scale_param.output = param.output->ZynqTensor();
int channel = scale_param.input->shape().channel();
zynqmp::Tensor* scale = new zynqmp::Tensor();
zynqmp::Tensor* bias = new zynqmp::Tensor();
zynqmp::Shape shape(zynqmp::N, {channel});
float* scale_data = scale->mutableData<float>(zynqmp::FP32, shape);
float* bias_data = bias->mutableData<float>(zynqmp::FP32, shape);
float scale_value = param.scale;
float bias_value = param.bias_after_scale ? param.bias : 0;
for (int i = 0; i < channel; ++i) {
scale_data[i] = scale_value;
bias_data[i] = bias_value;
}
scale_param.scale = scale;
scale_param.bias = bias;
pe_.init();
pe_.apply();
}
void ScaleCompute::Run() { pe_.dispatch(); }
} // namespace fpga
} // namespace kernels
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/scale_pe.hpp"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
......@@ -21,12 +23,20 @@ namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
class ScaleCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ScaleParam;
void PrepareForRun() override;
void Run() override;
virtual ~ScaleCompute() = default;
private:
zynqmp::ScalePE pe_;
};
} // namespace fpga
......
......@@ -33,7 +33,13 @@ void SoftmaxCompute::PrepareForRun() {
pe_.apply();
}
void SoftmaxCompute::Run() { pe_.dispatch(); }
void SoftmaxCompute::Run() {
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::SoftmaxParam& softmax_param = pe_.param();
Debugger::get_instance().registerOutput("softmax", softmax_param.output);
#endif
}
} // namespace fpga
} // namespace kernels
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/kernels/fpga/transpose_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
using float16 = zynqmp::float16;
void transposeCompute(operators::TransposeParam param) {
// copy from;
const auto* input_x = param.x;
const auto input_x_dims = input_x->dims();
input_x->ZynqTensor()->invalidate();
input_x->ZynqTensor()->unalignImage();
Tensor float_input;
float_input.Resize(input_x_dims);
float_input.mutable_data<float>();
float_input.ZynqTensor()->copyFrom(input_x->ZynqTensor());
// const auto* input_x_data = input_x->data<float>();
const auto* input_x_data = float_input.data<float>();
// auto& param = this->Param<param_t>();
auto* out = param.output;
const auto axis = param.axis;
auto* out_data = out->mutable_data<float>();
size_t ndim = axis.size();
std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]];
xstride[j] = 1;
for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k];
}
xout[j] = xstride[j] * xdim[j];
}
auto numel = input_x->numel();
size_t pind = 0;
std::vector<int> ind(ndim);
for (int i = 0; i < numel; i++) {
out_data[i] = input_x_data[pind];
ind[0]++;
pind += xstride[0];
for (int j = 0; j < ndim - 1; j++) {
if (ind[j] == xdim[j]) {
ind[j + 1]++;
ind[j] = 0;
pind += xstride[j + 1];
pind -= xout[j];
} else {
break;
}
}
}
}
// Transpose
void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
// param.output->mutable_data<float16>();
}
// Transpose2
void Transpose2Compute::Run() {
auto& param = this->Param<param_t>();
param.output->mutable_data<float>();
param.x->ZynqTensor()->invalidate();
param.x->ZynqTensor()->unalignImage();
if (param.x->dims().size() != 4) {
transposeCompute(param);
// auto out = param.Out();
// auto out_data = out->data<half>();
// int num = input_x_dims[1];
// int channel = input_x_dims[2];
// int index = 0;
// for (int n = 0; n < num; n++) {
// for (int c = 0; c < channel; c++) {
// out_data[c * num + n] = input_x_data[n * channel + c];
// index++;
// }
// }
// param.output->ZynqTensor()->copyFrom(param.x->ZynqTensor());
} else {
param.x->ZynqTensor()->saveToFile("tx", true);
param.output->ZynqTensor()->copyFrom(param.x->ZynqTensor());
param.output->ZynqTensor()->saveToFile("to", true);
}
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
// Transpose
REGISTER_LITE_KERNEL(transpose,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::TransposeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Transpose2
REGISTER_LITE_KERNEL(transpose2,
kFPGA,
kFP16,
kNHWC,
paddle::lite::kernels::fpga::Transpose2Compute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/kernel.h"
#include "lite/operators/transpose_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
// Transpose
class TransposeCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::TransposeParam;
void Run() override;
virtual ~TransposeCompute() = default;
};
// Transpose2
class Transpose2Compute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::TransposeParam;
void Run() override;
virtual ~Transpose2Compute() = default;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册