Unverified commit 9cc38d33, authored by zhangyang0701, committed by GitHub

Merge pull request #1445 from jameswu2014/my-cool-stuff

Add support for RFCN for the FPGA track. Closes #1432.
@@ -30,9 +30,9 @@ void format_image(framework::Tensor *image_tensor) {
   auto data_ptr = image_tensor->data<float>();
   auto external_ptr = reinterpret_cast<float *>(image_tensor->external_data);
   float *p_data = external_ptr == nullptr ? data_ptr : external_ptr;
-  float *old_p = p_data;
   image::format_image(&p_data, channel, height, width);
-  if (old_p != p_data) {
+  if (p_data != data_ptr) {
     image_tensor->reset_data_ptr(p_data);
   }
 }
@@ -48,9 +48,9 @@ void format_fp16_ofm(framework::Tensor *ofm_tensor) {
   auto dims = ofm_tensor->dims();
   size_t memory_size = 0;
   if (dims.size() == 4) {
-    auto channel = dims[1], height = dims[2], width = dims[3];
-    memory_size =
-        height * align_to_x(channel * width, IMAGE_ALIGNMENT) * sizeof(half);
+    auto channel = dims[1], height = dims[2], width = dims[3], num = dims[0];
+    memory_size = num * height * align_to_x(channel * width, IMAGE_ALIGNMENT) *
+                  sizeof(half);
   } else if (dims.size() == 2) {
     memory_size = align_to_x(dims[1], IMAGE_ALIGNMENT) * sizeof(half);
   } else {
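
Note on the `format_fp16_ofm` change: the old size computation dropped the batch dimension, under-allocating the fp16 output feature map whenever `num > 1`. A minimal standalone sketch of the corrected, batch-aware computation; `align_to_x` and `IMAGE_ALIGNMENT` are re-declared here to match their roles in this codebase, with `int16_t` standing in for `half`:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

constexpr int64_t IMAGE_ALIGNMENT = 16;  // row alignment used by the FPGA track

// Round n up to the nearest multiple of x (same role as fpga::align_to_x).
inline int64_t align_to_x(int64_t n, int64_t x) { return (n + x - 1) / x * x; }

int main() {
  int64_t num = 2, channel = 3, height = 224, width = 224;
  // Each NHWC row holds channel * width fp16 values, padded to the alignment;
  // the batch dimension multiplies the whole padded plane.
  size_t memory_size =
      num * height * align_to_x(channel * width, IMAGE_ALIGNMENT) *
      sizeof(int16_t);
  std::cout << "bytes to allocate: " << memory_size << "\n";
}
```
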
@@ -960,10 +960,10 @@ void fill_DWDeconv_arg(struct DWDeconvArgs *arg, framework::Tensor *input,
                        sizeof(int16_t));
     arg->dw_conv_args[i]->output.scale_address =
         static_cast<float *>(fpga_malloc(2 * sizeof(float)));
-    arg->vector_dw_conv_space.push_back(std::shared_ptr<char>(
+    arg->vector_dw_conv_space.push_back(std::shared_ptr<char>(  // NOLINT
         reinterpret_cast<char *>(arg->dw_conv_args[i]->output.address),
         deleter));
-    arg->vector_dw_conv_space.push_back(std::shared_ptr<char>(
+    arg->vector_dw_conv_space.push_back(std::shared_ptr<char>(  // NOLINT
         reinterpret_cast<char *>(arg->dw_conv_args[i]->output.scale_address),
         deleter));
   }
...
@@ -21,15 +21,37 @@ namespace paddle_mobile {
 namespace fpga {
 namespace image {
-void convert_to_hwc(float **data_in, int channel, int height, int width) {
+void convert_to_hwc(float **data_in, int channel, int height, int width,
+                    int num) {
+  float *data_tmp = reinterpret_cast<float *>(
+      fpga_malloc(num * channel * height * width * sizeof(float)));
+  int64_t amount_per_row = width * channel;
+  for (int n = 0; n < num; n++) {
+    for (int c = 0; c < channel; c++) {
+      for (int h = 0; h < height; h++) {
+        int64_t offset_height = h * amount_per_row;
+        for (int w = 0; w < width; w++) {
+          *(data_tmp + n * channel * height * width + offset_height +
+            w * channel + c) = *((*data_in)++);
+        }
+      }
+    }
+  }
+  *data_in = data_tmp;
+}
+
+void convert_to_chw(float **data_in, int channel, int height, int width,
+                    int num) {
   float *data_tmp =
       (float *)fpga_malloc(channel * height * width * sizeof(float));  // NOLINT
-  int64_t amount_per_row = width * channel;
-  for (int c = 0; c < channel; c++) {
+  int64_t amount_per_side = width * height;
+  for (int n = 0; n < num; n++) {
     for (int h = 0; h < height; h++) {
-      int64_t offset_height = h * amount_per_row;
       for (int w = 0; w < width; w++) {
-        *(data_tmp + offset_height + w * channel + c) = *((*data_in)++);
+        for (int c = 0; c < channel; c++) {
+          *(data_tmp + n * height * width * channel + c * amount_per_side +
+            width * h + w) = *((*data_in)++);
+        }
       }
     }
   }
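
The two converters are inverses over a batch: `convert_to_hwc` consumes a CHW-ordered buffer sequentially and scatters into HWC slots, while `convert_to_chw` does the reverse. A self-contained round-trip check of the same loop structures (a sketch with `std::vector` in place of `fpga_malloc`, not the library code):

```cpp
#include <cassert>
#include <vector>

// Gather a CHW-ordered batch into HWC order (mirrors convert_to_hwc's loops).
std::vector<float> chw_to_hwc(const std::vector<float>& in, int num,
                              int channel, int height, int width) {
  std::vector<float> out(in.size());
  const float* src = in.data();
  for (int n = 0; n < num; n++)
    for (int c = 0; c < channel; c++)
      for (int h = 0; h < height; h++)
        for (int w = 0; w < width; w++)
          out[((n * height + h) * width + w) * channel + c] = *src++;
  return out;
}

// Gather an HWC-ordered batch back into CHW order (mirrors convert_to_chw).
std::vector<float> hwc_to_chw(const std::vector<float>& in, int num,
                              int channel, int height, int width) {
  std::vector<float> out(in.size());
  const float* src = in.data();
  for (int n = 0; n < num; n++)
    for (int h = 0; h < height; h++)
      for (int w = 0; w < width; w++)
        for (int c = 0; c < channel; c++)
          out[((n * channel + c) * height + h) * width + w] = *src++;
  return out;
}

int main() {
  int num = 2, channel = 3, height = 4, width = 5;
  std::vector<float> img(num * channel * height * width);
  for (size_t i = 0; i < img.size(); i++) img[i] = static_cast<float>(i);
  // HWC -> CHW -> HWC must reproduce the input exactly.
  assert(chw_to_hwc(hwc_to_chw(img, num, channel, height, width),
                    num, channel, height, width) == img);
}
```
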
@@ -55,7 +77,7 @@ void align_element_conv(float **data_in, int height, int cw) {
 }
 
 void format_image(float **data_in, int channel, int height, int width) {
-  convert_to_hwc(data_in, channel, height, width);
+  // convert_to_hwc(data_in, channel, height, width);
   int cw = channel * width;
   int align_cw = align_to_x(cw, IMAGE_ALIGNMENT);
   if (align_cw != cw) {
@@ -132,8 +154,8 @@ void split_image(int16_t *image_in, const float *scale_in, void **images_out,
     for (int i = 0; i < image_num; i++) {
       des_offset = h * align_to_x(channel_nums[i] * width, IMAGE_ALIGNMENT) +
                    w * channel_nums[i];
-      memcpy((int16_t *)images_out[i] + des_offset, image_in + src_offset,
-             channel_nums[i] * sizeof(int16_t));
+      memcpy(reinterpret_cast<int16_t *>(images_out[i]) + des_offset,
+             image_in + src_offset, channel_nums[i] * sizeof(int16_t));
       src_offset += channel_nums[i];
     }
   }
...
@@ -20,7 +20,11 @@ namespace paddle_mobile {
 namespace fpga {
 namespace image {
-void convert_to_hwc(float** data_in, int channel, int height, int width);
+void convert_to_hwc(float** data_in, int channel, int height, int width,
+                    int num = 1);
+void convert_to_chw(float** data_in, int channel, int height, int width,
+                    int num = 1);
 
 void align_element_conv(float** data_in, int height, int cw);
 void format_image(float** data_in, int channel, int height, int width);
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "framework/operator.h"
+#include <memory>
 #include "operators/op_param.h"
 
 namespace paddle_mobile {
 namespace framework {
@@ -70,7 +70,12 @@ void OperatorBase<Dtype>::Run() {
       auto vari = this->scope_->FindVar(var_vec_in[i]);
       if (vari->IsInitialized()) {
         const Tensor *tensor = vari->template Get<framework::LoDTensor>();
-        if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
+        if (tensor) {
+          DLOG << type_ << " input- " << key << "=" << *tensor;
+#ifdef PADDLE_MOBILE_FPGA
+          DLOG << var_vec_in[i];
+#endif
+        }
       }
     }
   }
@@ -80,7 +85,12 @@ void OperatorBase<Dtype>::Run() {
       auto vari = scope_->FindVar(var_vec_out[i]);
       if (vari->IsInitialized()) {
         const Tensor *tensor = vari->template Get<framework::LoDTensor>();
-        if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
+        if (tensor) {
+          DLOG << type_ << " output- " << key << "=" << *tensor;
+#ifdef PADDLE_MOBILE_FPGA
+          DLOG << var_vec_out[i];
+#endif
+        }
       }
     }
   }
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <map>
+#include <memory>
 #include <string>
 #include <utility>
 #include <vector>
@@ -80,7 +81,9 @@ class OperatorBase {
   }
 #ifdef PADDLE_MOBILE_FPGA
   void InsertTensors();
+  void ChangeNameMap(string key, std::vector<string> value);
 #endif
+
  protected:
   std::shared_ptr<Scope> scope_;
   std::string type_;
@@ -95,6 +98,7 @@ class OperatorBase {
 template <typename Dtype, typename ParamType, typename KernelType>
 class OperatorWithKernel : public OperatorBase<Dtype> {
  public:
+#ifndef PADDLE_MOBILE_FPGA1
   OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                      const VariableNameMap &outputs, const AttributeMap &attrs,
                      std::shared_ptr<Scope> scope)
@@ -104,6 +108,25 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
     kernel_.InitCLHelper(scope->GetCLScpoe());
 #endif
   }
+#else
+  OperatorWithKernel(const std::string &type, const VariableNameMap inputs,
+                     const VariableNameMap &outputs, const AttributeMap &attrs,
+                     std::shared_ptr<Scope> scope)
+      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {
+    static int feed_num = 0;
+    static int fetch_num = 0;
+    if (type == "feed") {
+      auto new_name = string("feed") + std::to_string(feed_num++);
+      auto var = scope->Var(new_name);
+      (const_cast<VariableNameMap &>(inputs)).at("X") = {string(new_name)};
+    } else if (type == "fetch") {
+      auto new_name = string("fetch") + std::to_string(fetch_num++);
+      auto var = scope->Var(new_name);
+      (const_cast<VariableNameMap &>(outputs)).at("Out") = {string(new_name)};
+    }
+    param_ = ParamType(inputs, outputs, attrs, *scope);
+  }
+#endif
 
   virtual void RunImpl() { this->kernel_.Compute(this->param_); }
 
   virtual void InferShape() const = 0;
...
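
The FPGA-only constructor above renames each feed/fetch instance with a static counter, so that a multi-input, multi-output program (as RFCN is) does not have every feed op bind to the single scope variable "feed". A minimal sketch of the renaming idea over a plain map standing in for `VariableNameMap` (hypothetical stand-in, not the framework type):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for paddle-mobile's VariableNameMap: op slot -> variable names.
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

// Rewrite the "X" slot of a feed op so every feed instance gets its own
// scope variable ("feed0", "feed1", ...), mirroring the constructor above.
void RenameFeed(VariableNameMap* inputs) {
  static int feed_num = 0;
  (*inputs)["X"] = {"feed" + std::to_string(feed_num++)};
}

int main() {
  VariableNameMap a{{"X", {"feed"}}}, b{{"X", {"feed"}}};
  RenameFeed(&a);
  RenameFeed(&b);
  std::cout << a["X"][0] << " " << b["X"][0] << "\n";  // feed0 feed1
}
```
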
@@ -126,6 +126,8 @@ std::vector<Variable *> Scope::VarContain(const std::string substring) {
   return v;
 }
 
+void Scope::InsertVar(const std::string str, Variable *var) {}
+
 void Scope::print_vars() {
   DLOG << "====================start to print variables=================";
   for (auto pair : vars_) {
...
@@ -86,6 +86,7 @@ class Scope {
 #ifdef PADDLE_MOBILE_FPGA
   Variable *Var(const std::string &name, const int id);
   std::vector<Variable *> VarContain(const std::string substring);
+  void InsertVar(const std::string str, Variable *var);
   void print_vars();
 #endif
...
@@ -43,9 +43,11 @@ bool AnchorGeneratorKernel<FPGA, float>::Init(
   // DLOG << "stride_height: " << stride_height;
 
   for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
+    int offset0 = h_idx * feature_width * num_anchors * 4;
     for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
-      int offset = h_idx * w_idx * num_anchors * 4;
+      int offset1 = w_idx * num_anchors * 4;
       for (int idx = 0; idx < num_anchors; idx++) {
+        int offset = offset0 + offset1 + idx * 4;
         anchor_ptr[offset + 0] =
             anchors_offset[idx * 4 + 0] + w_idx * stride_width;
         anchor_ptr[offset + 1] =
...
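
The anchor-generator fix replaces `h_idx * w_idx * num_anchors * 4`, which is a product rather than a row-major offset (and ignored `idx` entirely), so anchors at different positions collided. A quick standalone check that the decomposed offset yields one unique slot per (h, w, anchor):

```cpp
#include <iostream>
#include <set>

int main() {
  const int feature_height = 4, feature_width = 5, num_anchors = 3;
  std::set<int> old_slots, new_slots;
  for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
    int offset0 = h_idx * feature_width * num_anchors * 4;
    for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
      int offset1 = w_idx * num_anchors * 4;
      for (int idx = 0; idx < num_anchors; ++idx) {
        old_slots.insert(h_idx * w_idx * num_anchors * 4);  // old: product, no idx term
        new_slots.insert(offset0 + offset1 + idx * 4);      // fixed: unique row-major slot
      }
    }
  }
  std::cout << "old unique slots: " << old_slots.size()      // far fewer than 60
            << ", fixed unique slots: " << new_slots.size()  // 60 = 4 * 5 * 3
            << std::endl;
}
```
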
@@ -25,11 +25,6 @@ bool FeedKernel<FPGA, float>::Init(FeedParam<FPGA> *param) {
   input->Resize(output->dims());
   if (output->dims().size() != 4) {
-    auto input_ptr = input->mutable_data<float>();
-    size_t size = output->numel() * sizeof(float);
-    auto p = fpga::fpga_malloc(size);
-    memcpy(p, input_ptr, size);
-    output->reset_data_ptr(p);
     return true;
   }
   fpga::format_fp16_ofm(output);
@@ -41,7 +36,14 @@ void FeedKernel<FPGA, float>::Compute(const FeedParam<FPGA> &param) {
   auto output = param.Out();
   auto input = const_cast<LoDTensor *>(param.InputX());
-  if (input->dims().size() != 4) {
+  if (output->dims().size() != 4) {
+    size_t size = output->numel() * sizeof(float);
+    auto output_ptr = output->data<float>();
+    auto input_ptr = input->data<float>();
+    auto external_ptr = reinterpret_cast<float *>(input->external_data);
+    float *p_data = external_ptr == nullptr ? input_ptr : external_ptr;
+    memcpy(output_ptr, p_data, size);
+    input->external_data = nullptr;
     return;
   }
...
@@ -49,17 +49,20 @@ bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) {
 
 template <>
 void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
-  auto input = param.InputX();
+  auto input = const_cast<Tensor *>(param.InputX());
   if (input->type() == typeid(float)) {
     auto output = param.Out();
     output->ShareDataWith(*input);
     return;
   }
-  fpga::PerformBypass(param.fpga_bypass_args);
+
+  fpga::BypassArgs args = param.fpga_bypass_args;
+  auto data = (input->mutable_data<half>());
+  args.image.address = static_cast<void *>(data);
+  fpga::PerformBypass(args);
   fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
                         param.fpga_bypass_args.image.channels * sizeof(float));
-  // TODO: DEalign: get rid of extra 0
+  // TODO(zhangyang): DEalign: get rid of extra 0
 }
 template class FetchKernel<FPGA, float>;
...
@@ -22,15 +22,29 @@ namespace operators {
 template <>
 bool PoolKernel<FPGA, float>::Init(PoolParam<FPGA> *param) {
   auto *input = const_cast<Tensor *>(param->Input());
-  auto input_ptr = input->data<half>();
-  Tensor *output = param->Output();
-  fpga::format_fp16_ofm(output);
-  auto output_ptr = output->mutable_data<half>();
+  auto *output = param->Output();
   vector<int> ksize = param->Ksize();
   vector<int> strides = param->Strides();
   vector<int> paddings = param->Paddings();
   std::string pooling_type = param->PoolingType();
+
+  if (input->type() == typeid(float)) {
+    int channels = input->dims()[1];
+    int height = input->dims()[2];
+    int width = input->dims()[3];
+    int num = input->dims()[0];
+    int out_width = (width + 2 * paddings[1] - ksize[1]) / strides[1] + 1;
+    int out_height = (height + 2 * paddings[0] - ksize[0]) / strides[0] + 1;
+    framework::DDim dim =
+        framework::make_ddim({num, channels, out_height, out_width});
+    output->mutable_data<float>(dim);
+    return true;
+  }
+
+  auto input_ptr = input->data<half>();
+  fpga::format_fp16_ofm(output);
+  auto output_ptr = output->mutable_data<half>();
+
   fpga::PoolingArgs poolArgs = {0};
   poolArgs.mode = pooling_type == "max" ? 0 : 1;  // max:0, avg:1
   poolArgs.kernel_reciprocal =
@@ -54,6 +68,31 @@ bool PoolKernel<FPGA, float>::Init(PoolParam<FPGA> *param) {
 
 template <>
 void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) {
+  auto *input = const_cast<Tensor *>(param.Input());
+  if (input->type() == typeid(float)) {
+    auto *output = param.Output();
+    auto in = input->data<float>();
+    auto len = output->numel();
+    auto out = output->mutable_data<float>();
+    int N = input->dims()[0], C = input->dims()[1], H = input->dims()[2],
+        W = input->dims()[3];
+    int HW = H * W, CHW = C * H * W, WC = W * C;
+
+    for (int n = 0; n < N; n++) {
+      for (int c = 0; c < C; c++) {
+        out[n * C + c] = 0;
+        for (int h = 0; h < H; h++) {
+          for (int w = 0; w < W; w++) {
+            out[n * C + c] += in[n * CHW + h * WC + w * C +
                                 c];  // in[n * CHW + c * HW + h * W + w]; //
+          }
+        }
+        out[n * C + c] /= HW;
+      }
+    }
+    return;
+  }
   fpga::ComputeFpgaPool(param.FpgaArgs());
 }
 }  // namespace operators
...
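
The new CPU path in `PoolKernel::Compute` averages the whole H×W plane per channel over NHWC-ordered floats, i.e. global average pooling into an N×C output (the commented index shows the NCHW equivalent). A standalone sketch of the same reduction:

```cpp
#include <iostream>
#include <vector>

// Global average pooling over an NHWC buffer, as in the CPU fallback above.
std::vector<float> global_avg_pool_nhwc(const std::vector<float>& in, int N,
                                        int C, int H, int W) {
  std::vector<float> out(N * C, 0.0f);
  const int HW = H * W, CHW = C * H * W, WC = W * C;
  for (int n = 0; n < N; n++)
    for (int c = 0; c < C; c++) {
      for (int h = 0; h < H; h++)
        for (int w = 0; w < W; w++)
          out[n * C + c] += in[n * CHW + h * WC + w * C + c];  // NHWC index
      out[n * C + c] /= HW;
    }
  return out;
}

int main() {
  // One image, two channels, 2x2 plane: channel 0 holds {1,2,3,4} -> mean 2.5.
  std::vector<float> in = {1, 10, 2, 20, 3, 30, 4, 40};  // NHWC, C = 2
  auto out = global_avg_pool_nhwc(in, 1, 2, 2, 2);
  std::cout << out[0] << " " << out[1] << "\n";  // 2.5 25
}
```
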
@@ -67,6 +67,30 @@ bool ProposalKernel<FPGA, float>::Init(ProposalParam<FPGA> *param) {
   return true;
 }
 
+template <typename T>
+void CPUGather(const Tensor &src, const Tensor &index, Tensor *output) {
+  PADDLE_MOBILE_ENFORCE(index.dims().size() == 1 ||
+                            (index.dims().size() == 2 && index.dims()[1] == 1),
+                        "Dim not correct");
+
+  int64_t index_size = index.dims()[0];
+
+  auto src_dims = src.dims();
+
+  const T *p_src = src.data<T>();
+  const int *p_index = index.data<int>();
+  T *p_output = output->data<T>();
+
+  // slice size
+  int slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int64_t i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
+  }
+}
+
 void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
   auto *out_data = dst->data<void>();
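
`CPUGather` copies whole row slices of `src` selected by an int32 index tensor, which is what lets the proposal path filter boxes and scores without per-element reindexing. The same slice copy on plain arrays:

```cpp
#include <cstddef>
#include <cstring>
#include <iostream>

int main() {
  // 4 rows of 4 floats (e.g. boxes); gather rows {2, 0}.
  const float src[4][4] = {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}};
  const int index[2] = {2, 0};
  float out[2][4];

  const int slice_size = 4;                               // elements per row
  const std::size_t slice_bytes = slice_size * sizeof(float);  // bytes per row
  for (int i = 0; i < 2; ++i) {
    std::memcpy(out[i], src[index[i]], slice_bytes);      // copy one selected row
  }
  std::cout << out[0][0] << " " << out[1][0] << "\n";     // 2 0
}
```
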
@@ -103,38 +127,49 @@ static inline void BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas,
     T bbox_center_x = 0, bbox_center_y = 0;
     T bbox_width = 0, bbox_height = 0;
 
-    if (variances) {
-      bbox_center_x =
-          variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
-          anchor_center_x;
-      bbox_center_y = variances_data[i * len + 1] *
-                          bbox_deltas_data[i * len + 1] * anchor_height +
-                      anchor_center_y;
-      bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
-                                            bbox_deltas_data[i * len + 2],
-                                        kBBoxClipDefault)) *
-                   anchor_width;
-      bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
-                                             bbox_deltas_data[i * len + 3],
-                                         kBBoxClipDefault)) *
-                    anchor_height;
-    } else {
-      bbox_center_x =
-          bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
-      bbox_center_y =
-          bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
-      bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
-                                        kBBoxClipDefault)) *
-                   anchor_width;
-      bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
-                                         kBBoxClipDefault)) *
-                    anchor_height;
-    }
+    /*
+    if (variances) {
+      bbox_center_x =
+          variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width
+          + anchor_center_x; bbox_center_y = variances_data[i * len + 1] *
+                          bbox_deltas_data[i * len + 1] * anchor_height +
+                      anchor_center_y;
+      bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
+                                            bbox_deltas_data[i * len + 2],
+                                        kBBoxClipDefault)) *
+                   anchor_width;
+      bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
+                                             bbox_deltas_data[i * len + 3],
+                                         kBBoxClipDefault)) *
+                    anchor_height;
+    } else {
+    */
+    bbox_center_x = bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
+    bbox_center_y =
+        bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
+
+    /*
+    bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
+                                      kBBoxClipDefault)) *
+                 anchor_width;
+    bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
+                                       kBBoxClipDefault)) *
+                  anchor_height;
+    */
+    bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
+    bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
+    // }
 
     proposals_data[i * len] = bbox_center_x - bbox_width / 2;
     proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
-    proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
-    proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
+    /*
+    //wong
+    proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
+    proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
+    //wong
+    */
+    proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
+    proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
   }
   // return proposals;
 }
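
With the variance branch and the `kBBoxClipDefault` clamp commented out, `BoxCoder` reduces to: center = delta · anchor_size + anchor_center, size = exp(delta) · anchor_size, corners = center ± size/2 (the end-inclusive `- 1` convention is also dropped). A worked example of that decoding (widths taken as x2 − x1 for simplicity):

```cpp
#include <cmath>
#include <iostream>

int main() {
  // Anchor (x1, y1, x2, y2) and one set of deltas (dx, dy, dw, dh).
  float ax1 = 0, ay1 = 0, ax2 = 10, ay2 = 20;
  float dx = 0.1f, dy = -0.2f, dw = 0.0f, dh = std::log(2.0f);

  float anchor_w = ax2 - ax1, anchor_h = ay2 - ay1;
  float anchor_cx = ax1 + anchor_w / 2, anchor_cy = ay1 + anchor_h / 2;

  // Decode as in the simplified BoxCoder above (no variances, no exp clamp).
  float cx = dx * anchor_w + anchor_cx;  // 5 + 1 = 6
  float cy = dy * anchor_h + anchor_cy;  // 10 - 4 = 6
  float w = std::exp(dw) * anchor_w;     // 10
  float h = std::exp(dh) * anchor_h;     // 40

  std::cout << cx - w / 2 << "," << cy - h / 2 << "  "    // 1,-14
            << cx + w / 2 << "," << cy + h / 2 << "\n";   // 11,26
}
```
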
@@ -328,9 +363,12 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
   anchor_sel.mutable_data<T>({index_t.numel(), 4});
   var_sel.mutable_data<T>({index_t.numel(), 4});
 
+  CPUGather<T>(scores_slice, index_t, &scores_sel);
+  CPUGather<T>(bbox_deltas_slice, index_t, &bbox_sel);
+  CPUGather<T>(anchors, index_t, &anchor_sel);
   Tensor proposals;
   proposals.mutable_data<T>({index_t.numel(), 4});
-  BoxCoder<T>(&anchor_sel, &bbox_sel, &var_sel, &proposals);
+  BoxCoder<T>(&anchor_sel, &bbox_sel, nullptr, &proposals);
 
   ClipTiledBoxes<T>(im_info_slice, &proposals);
@@ -341,6 +379,8 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
   bbox_sel.mutable_data<T>({keep.numel(), 4});
   scores_filter.mutable_data<T>({keep.numel(), 1});
+  CPUGather<T>(proposals, keep, &bbox_sel);
+  CPUGather<T>(scores_sel, keep, &scores_filter);
   if (nms_thresh <= 0) {
     return std::make_pair(bbox_sel, scores_filter);
   }
@@ -351,14 +391,86 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
     keep_nms.Resize({post_nms_top_n});
   }
 
-  proposals.mutable_data<T>({keep_nms.numel(), 4});
-  scores_sel.mutable_data<T>({keep_nms.numel(), 1});
+  // proposals.mutable_data<T>({keep_nms.numel(), 4});//original
+  // scores_sel.mutable_data<T>({keep_nms.numel(), 1});//original
+  proposals.mutable_data<T>({post_nms_top_n, 4});   // wong
+  scores_sel.mutable_data<T>({post_nms_top_n, 1});  // wong
+
+  CPUGather<T>(bbox_sel, keep_nms, &proposals);
+  CPUGather<T>(scores_filter, keep_nms, &scores_sel);
 
   return std::make_pair(proposals, scores_sel);
 }
 template <>
 void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
+  auto input_score = param.scores_;
+  auto input_score_data = input_score->data<half>();
+  auto input_score_data_tmp = input_score->data<half>();
+  uint32_t score_n, score_height, score_width, score_channels;
+
+  auto input_bbox = param.bbox_deltas_;
+  auto input_bbox_data = input_bbox->data<half>();
+  auto input_bbox_data_tmp = input_bbox->data<half>();
+  uint32_t bbox_n, bbox_height, bbox_width, bbox_channels;
+
+  score_n = (uint32_t)(input_score->dims()[0]);
+  score_channels = (uint32_t)(input_score->dims()[1]);
+  score_height = (uint32_t)(input_score->dims()[2]);
+  score_width = (uint32_t)(input_score->dims()[3]);
+
+  bbox_n = (uint32_t)(input_bbox->dims()[0]);
+  bbox_channels = (uint32_t)(input_bbox->dims()[1]);
+  bbox_height = (uint32_t)(input_bbox->dims()[2]);
+  bbox_width = (uint32_t)(input_bbox->dims()[3]);
+
+  // score_tmp->init(typeid(half));
+  std::shared_ptr<Tensor> score_tmp = std::make_shared<Tensor>();
+  score_tmp->Resize(param.scores_->dims());
+  score_tmp->mutable_data<half>();
+
+  std::shared_ptr<Tensor> bbox_tmp = std::make_shared<Tensor>();
+  bbox_tmp->Resize(param.bbox_deltas_->dims());
+  bbox_tmp->mutable_data<half>();
+
+  auto score_tmp_data = score_tmp->data<half>();
+  auto bbox_tmp_data = bbox_tmp->data<half>();
+  int64_t amount_per_side = score_width * score_height;
+  int idx = 0;
+  fpga::fpga_invalidate(
+      input_score_data_tmp,
+      score_height * score_width * score_channels * sizeof(half));
+  for (int h = 0; h < score_height; h++) {
+    for (int w = 0; w < score_width; w++) {
+      for (int c = 0; c < score_channels; c++) {
+        idx++;
+        // DLOG << "wong input_score: "<<
+        // paddle_mobile::fpga::fp16_2_fp32(input_score_data[idx]);
+        *(score_tmp_data + c * amount_per_side + score_width * h + w) =
+            (*(input_score_data_tmp++));
+      }
+    }
+  }
+  amount_per_side = bbox_width * bbox_height;
+  fpga::fpga_invalidate(input_bbox_data_tmp, bbox_height * bbox_width *
+                                                 bbox_channels * sizeof(half));
+  for (int h = 0; h < bbox_height; h++) {
+    for (int w = 0; w < bbox_width; w++) {
+      for (int c = 0; c < bbox_channels; c++) {
+        idx++;
+        // DLOG << "wong input_score: "<<
+        // paddle_mobile::fpga::fp16_2_fp32(input_score_data[idx]);
+        *(bbox_tmp_data + c * amount_per_side + bbox_width * h + w) =
+            (*(input_bbox_data_tmp++));
+      }
+    }
+  }
+
+  struct paddle_mobile::fpga::BypassArgs temp_score_arg;
+  struct paddle_mobile::fpga::BypassArgs temp_bbox_arg;
+  temp_score_arg = param.score_arg;
+  temp_score_arg.image.address = score_tmp->data<half>();
+
+  temp_bbox_arg = param.bbox_arg;
+  temp_bbox_arg.image.address = bbox_tmp->data<half>();
   auto score_tensor = param.float_score.get();
   fpga::PerformBypass(param.score_arg);
   fpga::fpga_invalidate(score_tensor->data<float>(),
@@ -396,23 +508,23 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
   int64_t w_bbox = bbox_dim[3];
 
   //
-  Tensor bbox_deltas_swap, scores_swap;
-  bbox_deltas_swap.mutable_data<float>({num, h_bbox, w_bbox, c_bbox});
-  scores_swap.mutable_data<float>({num, h_score, w_score, c_score});
+  rpn_rois->mutable_data<float>({bbox_deltas->numel(), 4});
+  rpn_roi_probs->mutable_data<float>({scores->numel(), 1});
 
   framework::LoD lod;
   lod.resize(1);
   auto &lod0 = lod[0];
   lod0.push_back(0);
-  anchors.Resize({anchors.numel() / 4, 4});
+  anchors.Resize({anchors.numel(), 4});
+  variances.Resize({variances.numel(), 4});
 
   int64_t num_proposals = 0;
   for (int64_t i = 0; i < num; ++i) {
     Tensor im_info_slice = im_info->Slice(i, i + 1);
-    Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
-    Tensor scores_slice = scores_swap.Slice(i, i + 1);
+    Tensor bbox_deltas_slice = (*bbox_tensor).Slice(i, i + 1);
+    Tensor scores_slice = (*score_tensor).Slice(i, i + 1);
 
-    bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
+    bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox, 4});
     scores_slice.Resize({h_score * w_score * c_score, 1});
 
     std::pair<Tensor, Tensor> tensor_pair = ProposalForOneImage<float>(
...
@@ -18,6 +18,8 @@ limitations under the License. */
 #include <vector>
 #include "operators/kernel/detection_kernel.h"
 
+#include "fpga/V1/api.h"
+#include "fpga/V1/image.h"
 namespace paddle_mobile {
 namespace operators {
@@ -29,8 +31,7 @@ bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
   param->float_input = std::make_shared<Tensor>();
   param->float_input->mutable_data<float>(param->input_x_->dims());
-  param->float_output = std::make_shared<Tensor>();
-  param->float_output->mutable_data<float>(param->output_->dims());
+  // param->float_output = std::make_shared<Tensor>();
 
   auto input = param->input_x_;
   fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
@@ -46,22 +47,90 @@ bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
   args.output.scale_address = param->float_input->scale;
   param->input_arg = args;
 
-  fpga::format_fp16_ofm(param->output_);
+  auto* rois = param->input_rois_;
+  int rois_num = rois->dims()[0];
+  framework::DDim dims_out_new = framework::make_ddim(
+      {rois_num, param->output_->dims()[1], param->output_->dims()[2],
+       param->output_->dims()[3]});
+  param->output_->Resize(dims_out_new);
+  // fpga::format_fp16_ofm(param->output_);
 
-  input = param->float_output.get();
-  args.input_data_type = fpga::DATA_TYPE_FP32;
-  args.output_data_type = fpga::DATA_TYPE_FP16;
-  args.image.address = input->data<float>();
-  args.image.height = (uint32_t)input->dims()[2];
-  args.image.width = (uint32_t)input->dims()[3];
-  args.image.channels = (uint32_t)input->dims()[1];
-  args.output.address = param->output_->mutable_data<half>();
-  args.output.scale_address = param->output_->scale;
-  param->input_arg = args;
+  param->output_->mutable_data<float>(dims_out_new);
+  // auto output = param->float_output.get();
+  // param->output_ = output;
+  /*  args.input_data_type = fpga::DATA_TYPE_FP32;
+    args.output_data_type = fpga::DATA_TYPE_FP16;
+    args.image.address = output->data<float>();
+    args.image.height = (uint32_t)output->dims()[2];
+    args.image.width = (uint32_t)output->dims()[3];
+    args.image.channels = (uint32_t)output->dims()[1] ;
+    args.output.address = param->output_->mutable_data<half>();
+    args.output.scale_address = param->output_->scale;
+    param->output_arg = args;*/
 
   return true;
 }
+
+template <typename Dtype>
+void PSROIPooling(const Dtype* bottom_data, const Dtype spatial_scale,
+                  const int channels, const int height, const int width,
+                  const int pooled_height, const int pooled_width,
+                  const Dtype* bottom_rois, const int output_dim,
+                  const int group_size, Dtype* top_data,
+                  // int* mapping_channel,
+                  int index, int* rois_batch_id) {
+  // The output is in order (n, ctop, ph, pw)
+  // static int cnt = 0;
+  int pw = index % pooled_width;
+  int ph = (index / pooled_width) % pooled_height;
+  int ctop = (index / pooled_width / pooled_height) % output_dim;
+  int n = index / pooled_width / pooled_height / output_dim;
+
+  // [start, end) interval for spatial sampling
+  bottom_rois += n * 4;
+  int roi_batch_ind = rois_batch_id[n];  // bottom_rois[0];
+  Dtype roi_start_w = static_cast<Dtype>(round(bottom_rois[0])) * spatial_scale;
+  Dtype roi_start_h = static_cast<Dtype>(round(bottom_rois[1])) * spatial_scale;
+  Dtype roi_end_w =
+      static_cast<Dtype>(round(bottom_rois[2]) + 1.) * spatial_scale;
+  Dtype roi_end_h =
+      static_cast<Dtype>(round(bottom_rois[3]) + 1.) * spatial_scale;
+
+  // Force too small ROIs to be 1x1
+  Dtype roi_width = std::max(roi_end_w - roi_start_w, 0.1f);  // avoid 0
+  Dtype roi_height = std::max(roi_end_h - roi_start_h, 0.1f);
+
+  // Compute w and h at bottom
+  Dtype bin_size_h = roi_height / static_cast<Dtype>(pooled_height);
+  Dtype bin_size_w = roi_width / static_cast<Dtype>(pooled_width);
+
+  int hstart = floor(static_cast<Dtype>(ph) * bin_size_h + roi_start_h);
+  int wstart = floor(static_cast<Dtype>(pw) * bin_size_w + roi_start_w);
+  int hend = ceil(static_cast<Dtype>(ph + 1) * bin_size_h + roi_start_h);
+  int wend = ceil(static_cast<Dtype>(pw + 1) * bin_size_w + roi_start_w);
+
+  // Add roi offsets and clip to input boundaries
+  hstart = std::min(std::max(hstart, 0), height);
+  hend = std::min(std::max(hend, 0), height);
+  wstart = std::min(std::max(wstart, 0), width);
+  wend = std::min(std::max(wend, 0), width);
+  bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+  int gw = pw;
+  int gh = ph;
+  int c = (ctop * group_size + gh) * group_size + gw;
+
+  bottom_data += (roi_batch_ind * channels + c) * height * width;
+  Dtype out_sum = 0;
+  for (int h = hstart; h < hend; ++h) {
+    for (int w = wstart; w < wend; ++w) {
+      int bottom_index = h * width + w;
+      out_sum += bottom_data[bottom_index];
+    }
+  }
+
+  Dtype bin_area = (hend - hstart) * (wend - wstart);
+  top_data[index] = is_empty ? 0. : out_sum / bin_area;
+}
+
 template <>
 void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
   auto input_tensor = param.float_input.get();
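
`PSROIPooling` is written per output element: the flat `index` is decomposed into (n, ctop, ph, pw), and each (ph, pw) bin reads its own input channel `(ctop * group_size + gh) * group_size + gw`, which is what makes the pooling position-sensitive. A standalone check of the decomposition:

```cpp
#include <iostream>

int main() {
  const int pooled_width = 7, pooled_height = 7, output_dim = 10;
  // Build a flat output index for (n=3, ctop=4, ph=5, pw=6), then decompose it
  // exactly as PSROIPooling does.
  int index = 3 * pooled_width * pooled_height * output_dim  // roi n = 3
              + 4 * pooled_width * pooled_height             // class ctop = 4
              + 5 * pooled_width                             // bin row ph = 5
              + 6;                                           // bin col pw = 6
  int pw = index % pooled_width;
  int ph = (index / pooled_width) % pooled_height;
  int ctop = (index / pooled_width / pooled_height) % output_dim;
  int n = index / pooled_width / pooled_height / output_dim;

  // Each (ph, pw) bin reads a distinct input channel: position-sensitive pooling.
  int group_size = pooled_height;
  int c = (ctop * group_size + ph) * group_size + pw;
  std::cout << n << " " << ctop << " " << ph << " " << pw << " c=" << c << "\n";
}
```
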
@@ -71,7 +140,7 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
 
   auto* in = input_tensor;
   auto* rois = param.input_rois_;
-  auto* out = param.float_output.get();
+  auto* out = param.output_;  // param.float_output.get();
 
   auto pooled_height = param.pooled_height_;
   auto pooled_width = param.pooled_width_;
@@ -85,18 +154,17 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
   int width = in_dims[3];
   int rois_num = rois->dims()[0];
 
-  // TODO auto in_stride = framework::stride(in_dims);
-  // TODO auto out_stride = framework::stride(out->dims());
-  auto in_stride =
-      framework::stride({batch_size, height, width, input_channels});
-  auto out_stride = framework::stride(
-      {out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]});
-
-  const float* input_data = in->data<float>();
+  auto data_nhwc = in->mutable_data<float>();
+  fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width);
+  framework::DDim dims_out_new = framework::make_ddim(
+      {rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
+       (param.output_)->dims()[3]});
+  (param.output_)->Resize(dims_out_new);
+
+  const float* input_data = data_nhwc;  // in->data<float>();
 
   framework::Tensor rois_batch_id_list;
   rois_batch_id_list.Resize({rois_num});
   auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();
-  return;
 
   PADDLE_MOBILE_ENFORCE(rois->NumLevels() > 0, "ROIS should not be empty");
@@ -124,78 +192,18 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
   auto input_rois = rois->data<float>();
 
   // calculate psroipooling, parallel processing can be implemented per ROI
-  for (int n = 0; n < rois_num; ++n) {
-    // set roi batch id
-    int roi_batch_id = rois_batch_id_data[n];
-
-    // [start, end) interval for spatial sampling
-    auto offset_input_rois = input_rois + n * 4;
-    auto roi_start_w =
-        static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
-    auto roi_start_h =
-        static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
-    auto roi_end_w =
-        static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
-    auto roi_end_h =
-        static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;
-    // Force too small rois to be 1 x 1
-    auto roi_height = std::max(roi_end_h - roi_start_h, 0.1f);  // avoid 0
-    auto roi_width = std::max(roi_end_w - roi_start_w, 0.1f);
-
-    // Compute bin size w and h at input feature map
-    auto bin_size_h = roi_height / static_cast<float>(pooled_height);
-    auto bin_size_w = roi_width / static_cast<float>(pooled_width);
-    DLOG << 3;
-
-    // calculate each pixel of the output feature map.
-    int out_roi_offset = n * out_stride[0];
-    for (int c = 0; c < output_channels; ++c) {
-      // per category
-      // int out_plane_offset = out_roi_offset + c * out_stride[1];
-      int out_plane_offset = out_roi_offset + c;
-      for (int ph = 0; ph < pooled_height; ++ph) {
-        // TODO int out_row_offset = out_plane_offset + ph *
-        // out_stride[2];
-        int out_row_offset = out_plane_offset + ph * out_stride[1];
-        for (int pw = 0; pw < pooled_width; ++pw) {
-          // calculate w and h at input feature map
-          int hstart = floor(static_cast<float>(ph) * bin_size_h + roi_start_h);
-          int wstart = floor(static_cast<float>(pw) * bin_size_w + roi_start_w);
-          int hend =
-              ceil(static_cast<float>(ph + 1) * bin_size_h + roi_start_h);
-          int wend =
-              ceil(static_cast<float>(pw + 1) * bin_size_w + roi_start_w);
-          // Add roi offsets and clip to input boundaries
-          hstart = std::min(std::max(hstart, 0), height);
-          wstart = std::min(std::max(wstart, 0), width);
-          hend = std::min(std::max(hend, 0), height);
-          wend = std::min(std::max(wend, 0), width);
-
-          // TODO int output_index = out_row_offset + pw;
-          int output_index = out_row_offset + pw * output_channels;
-          int input_channel = (c * pooled_height + ph) * pooled_width + pw;
-          // TODO int input_plane_offset =
-          // TODO roi_batch_id * in_stride[0] + input_channel *
-          // in_stride[1];
-          int input_plane_offset = roi_batch_id * in_stride[0] + input_channel;
-          auto offset_input_data = input_data + input_plane_offset;
-          float out_sum = 0.;
-          bool is_empty = (hend <= hstart) || (wend <= wstart);
-          for (int ih = hstart; ih < hend; ++ih) {
-            for (int iw = wstart; iw < wend; ++iw) {
-              int input_index = ih * in_stride[1] + iw * input_channel;
-              out_sum += offset_input_data[input_index];
-            }
-          }
-          float bin_area = (hend - hstart) * (wend - wstart);
-          output_data[output_index] = is_empty ? 0. : out_sum / bin_area;
-        }
-      }
-    }
-  }
-  fpga::format_image(out);
-  fpga::PerformBypass(param.output_arg);
+  int index = pooled_height * pooled_width * output_channels * rois_num;
+  for (int idx = 0; idx < index; idx++) {
+    PSROIPooling<float>(input_data, spatial_scale, input_channels, height,
+                        width, pooled_height, pooled_width, input_rois,
+                        output_channels, pooled_height, output_data, idx,
+                        rois_batch_id_data);
+  }
+  //
+  fpga::image::convert_to_hwc(&output_data, output_channels, pooled_height,
+                              pooled_width, rois_num);
+  out->reset_data_ptr(output_data);
 }
 }  // namespace operators
...
@@ -47,21 +47,11 @@ bool Reshape2Kernel<FPGA, float>::Init(Reshape2Param<FPGA> *param) {
 
 void reshape(LoDTensor *input, LoDTensor *output) {
   // Subscript r means after reshape
-  // TODO zhangyang verify this function
 
-  float *input_ptr_f, *output_ptr_f;
-  half *input_ptr_h, *output_ptr_h;
-  bool is_float = false;
-
-  if (input->type() == typeid(float)) {
-    input_ptr_f = input->data<float>();
-    output_ptr_f = output->data<float>();
-    is_float = true;
-  } else {
-    input_ptr_h = input->data<half>();
-    output_ptr_h = output->data<half>();
-  }
+  auto input_ptr = input->data<half>();
+  auto output_ptr = output->data<half>();
+  output->scale[0] = input->scale[0];
+  output->scale[1] = input->scale[1];
 
   auto C = static_cast<int>(input->dims()[1]);
   auto H = static_cast<int>(input->dims()[2]);
@@ -77,6 +67,8 @@ void reshape(LoDTensor *input, LoDTensor *output) {
   auto WCr_align = fpga::align_to_x(WCr, IMAGE_ALIGNMENT);
   auto HWr = Hr * Wr;
 
+  fpga::fpga_invalidate(input_ptr, H * WC_align * sizeof(half));
+
   int offset_align = 0;
   int offset_r = 0, offset_align_r = 0;
   int cr = 0, hr = 0, wr = 0;
@@ -87,21 +79,17 @@ void reshape(LoDTensor *input, LoDTensor *output) {
       int offset1 = w * C + offset0;
       for (int c = 0; c < C; c++) {
         offset_align = offset1 + c;
-        offset_r = c * HW + h * W + c;
+        offset_r = c * HW + h * W + w;
         cr = offset_r / HWr;
         hr = offset_r % HWr / Wr;
         wr = offset_r % Wr;
         offset_align_r = hr * WCr_align + wr * Cr + cr;
-        // DLOG << "hwc"<< h<< " " << w << " " << c;
-        // DLOG << "hrwrcr" << hr<< " " << wr << " " << cr;
-        if (is_float) {
-          output_ptr_f[offset_align_r] = input_ptr_f[offset_align];
-        } else {
-          output_ptr_h[offset_align_r] = input_ptr_h[offset_align];
-        }
+        output_ptr[offset_align_r] = input_ptr[offset_align];
       }
     }
   }
+  fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
 }
 
 template <>
@@ -123,6 +111,9 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
   output->Resize(framework::make_ddim(shape));
   if (output->dims() == input->dims()) {
     DLOG << "No need to reshape";
+    output->ShareDataWith(*input);
+    framework::LoD lod = input->lod();
+    output->set_lod(lod);
     return;
   }
...
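
The essential reshape fix is the index typo `offset_r = c * HW + h * W + c`, which collapsed every column of a row onto one source offset; the last term must be `w`. A small standalone demonstration of the collision:

```cpp
#include <iostream>
#include <set>

int main() {
  const int C = 2, H = 3, W = 4, HW = H * W;
  std::set<int> buggy, fixed;
  for (int h = 0; h < H; h++)
    for (int w = 0; w < W; w++)
      for (int c = 0; c < C; c++) {
        buggy.insert(c * HW + h * W + c);  // typo: last term should be w
        fixed.insert(c * HW + h * W + w);  // every (c, h, w) gets a unique slot
      }
  // The typo maps all W columns of a row to one offset.
  std::cout << "buggy unique: " << buggy.size()    // C * H = 6
            << ", fixed unique: " << fixed.size()  // C * H * W = 24
            << "\n";
}
```
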
@@ -33,13 +33,18 @@ bool SliceKernel<FPGA, float>::Init(SliceParam<FPGA>* param) {
 template <>
 void SliceKernel<FPGA, float>::Compute(const SliceParam<FPGA>& param) {
   // Only support slicing in channel dimension
+  // Only support half data
+  // W must be aligned to 16
+
   auto input = param.input_;
-  DLOG << input;
+  auto output = param.output_;
   int HW = input->dims()[2] * input->dims()[3];
   int channel = input->dims()[1];
   auto input_ptr = input->data<half>();
-  auto output_ptr = param.output_->data<half>();
+  auto output_ptr = output->data<half>();
+  output->scale[0] = input->scale[0];
+  output->scale[1] = input->scale[1];
 
   int start = param.starts_[0], end = param.ends_[0];
   start = start < 0 ? start + channel : start;
@@ -47,9 +52,10 @@ void SliceKernel<FPGA, float>::Compute(const SliceParam<FPGA>& param) {
   start = start > channel ? channel : start;
   end = end > channel ? channel : end;
   int len = end - start;
+  size_t size = len * sizeof(half);
 
   for (int i = 0; i < HW; i++) {
-    memcpy(output_ptr + len * i, input_ptr + i * channel + start, len);
+    memcpy(output_ptr + len * i, input_ptr + i * channel + start, size);
   }
 }
 }  // namespace operators
...
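
The slice fix addresses a classic `memcpy` unit bug: the old call passed `len` (an element count) where `memcpy` expects bytes, truncating each fp16 copy to half its payload. A sketch of channel slicing over HWC data with the byte count hoisted out (`int16_t` standing in for `half`):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  using half_t = int16_t;  // fp16 payload stand-in
  const int HW = 4, channel = 6, start = 1, end = 4;
  const int len = end - start;
  std::vector<half_t> in(HW * channel), out(HW * len);
  for (size_t i = 0; i < in.size(); i++) in[i] = static_cast<half_t>(i);

  const size_t size = len * sizeof(half_t);  // bytes, not elements
  for (int i = 0; i < HW; i++) {
    // Each HWC row contributes channels [start, end) to the output row.
    std::memcpy(out.data() + len * i, in.data() + i * channel + start, size);
  }
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 1 2 3
}
```
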
@@ -23,14 +23,21 @@ namespace operators {
 template <>
 bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
   auto input = const_cast<LoDTensor *>(param->InputX());
-  auto input_ptr = input->data<half>();
+  auto dims = framework::vectorize(input->dims());
+  half *input_ptr;
   auto out = param->Out();
+  if (input->type() == typeid(float)) {
+    out->Resize(framework::make_ddim(dims));
+    out->mutable_data<float>(framework::make_ddim(dims));
+  } else {
+    input_ptr = input->data<half>();
+  }
 
   auto float_input = new Tensor;
 
   PADDLE_MOBILE_ENFORCE(input->dims().size() == 4,
                         "Softmax should have 4-order input");
 
-  auto dims = framework::vectorize(input->dims());
   auto channel = dims[3];
   if (channel == 1) {  // This input is generated by FC op, dims = [N C 1 1]
     PADDLE_MOBILE_ENFORCE(dims[2] == 1, "Softmax input must come from FC op");
@@ -41,9 +48,12 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
     float_input->Resize(framework::make_ddim(dims));
 
   if (channel != 2) {  // Use CPU
+    out->Resize(framework::make_ddim(dims));
+    out->mutable_data<float>(framework::make_ddim(dims));
     float_input->init(typeid(float));
-    fpga::format_fp32_ofm(float_input);
-    fpga::format_fp32_ofm(out);
+    float_input->mutable_data<float>(framework::make_ddim(dims));
+    // fpga::format_fp32_ofm(float_input);
+    // fpga::format_fp32_ofm(out);
 
     fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
     args.input_layout_type = fpga::LAYOUT_HWC;
@@ -51,7 +61,7 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
     args.input_data_type = fpga::DATA_TYPE_FP16;
     args.output_data_type = fpga::DATA_TYPE_FP32;
     args.image.address = input_ptr;
-    args.image.height = (uint32_t)dims[1];
+    args.image.height = (uint32_t)dims[1] * dims[0];
     args.image.width = (uint32_t)dims[2];
     args.image.channels = (uint32_t)dims[3];
     args.output.address = float_input->data<float>();
@@ -80,14 +90,23 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
 
 template <>
 void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
-  fpga::PerformBypass(param.FpgaArgs());
-
-  if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
-    Tensor *out = param.Out();
-    Tensor *in_x = param.FloatInput();
-    fpga::fpga_invalidate(in_x->data<float>(), in_x->numel() * sizeof(float));
-    math::SoftmaxFuntor<CPU, float>()(in_x, out);
-    fpga::fpga_flush(out->data<float>(), out->memory_size());
+  auto *in_x = (param.InputX());
+  if (in_x->type() == typeid(half)) {
+    fpga::PerformBypass(param.FpgaArgs());
+    if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
+      Tensor *out = param.Out();
+      Tensor *in_x2 = param.FloatInput();
+      fpga::fpga_invalidate(in_x2->data<float>(),
+                            in_x2->numel() * sizeof(float));
+      math::SoftmaxFuntor<CPU, float>()(in_x2, out);
+      fpga::fpga_flush(out->data<float>(), out->memory_size());
+    }
+  } else {
+    if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
+      Tensor *out = param.Out();
+      math::SoftmaxFuntor<CPU, float>()(in_x, out);
+    }
   }
 }
...
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
 http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <fstream>
 #include <iomanip>
 #include <iostream>
 #include "../test_include.h"
 
 #ifdef PADDLE_MOBILE_FPGA_V1
 #include "fpga/V1/api.h"
 #endif
 #ifdef PADDLE_MOBILE_FPGA_V2
 #include "fpga/V2/api.h"
 #endif
 
 void readStream(std::string filename, float *buf) {
   std::ifstream in;
   in.open(filename, std::ios::in);
   if (!in.is_open()) {
     std::cout << "open File Failed." << std::endl;
     return;
   }
   string strOne;
   int i = 0;
   while (!in.eof()) {
     in >> buf[i];
     i++;
   }
   in.close();
 }
 
 void convert_to_chw(int16_t **data_in, int channel, int height, int width,
                     int16_t *data_tmp) {
   int64_t amount_per_side = width * height;
   for (int h = 0; h < height; h++) {
     for (int w = 0; w < width; w++) {
       for (int c = 0; c < channel; c++) {
         *(data_tmp + c * amount_per_side + width * h + w) = *((*data_in)++);
       }
     }
   }
 }
 
 void dump(std::string filename, Tensor input_tensor) {
   auto dataptr = reinterpret_cast<half *>(input_tensor.get_data());
   std::ofstream out(filename.c_str());
   float result = 0;
   for (int i = 0; i < input_tensor.numel(); ++i) {
     result = paddle_mobile::fpga::fp16_2_fp32(dataptr[i]);
     out << result << std::endl;
   }
   out.close();
 }
 
 void dump_stride_half(std::string filename, Tensor input_tensor,
                       const int dumpnum) {
   int c = (input_tensor.dims())[1];
   int h = (input_tensor.dims())[2];
   int w = (input_tensor.dims())[3];
   auto data_ptr = input_tensor.get_data();
   auto *data_tmp =
       reinterpret_cast<half *>(malloc(c * h * w * sizeof(int16_t)));
   auto *data_ptr_16 = reinterpret_cast<half *>(data_ptr);
   convert_to_chw(&data_ptr_16, c, h, w, data_tmp);
   std::ofstream out(filename.c_str());
   float result = 0;
   int stride = input_tensor.numel() / dumpnum;
   stride = stride > 0 ? stride : 1;
   for (int i = 0; i < input_tensor.numel(); i += stride) {
     result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
     out << result << std::endl;
   }
   out.close();
   free(data_tmp);
 }
 
 void dump_stride_float(std::string filename, Tensor input_tensor,
                        const int dumpnum) {
   auto data_ptr = reinterpret_cast<float *>(input_tensor.get_data());
   std::ofstream out(filename.c_str());
   float result = 0;
   int stride = input_tensor.numel() / dumpnum;
   stride = stride > 0 ? stride : 1;
   for (int i = 0; i < input_tensor.numel(); i += stride) {
     result = data_ptr[i];
     out << result << std::endl;
   }
   out.close();
 }
 
 static const char *g_resnet50 = "../models/resnet50";
-const std::string g_image_src_float = "../images/image_src_float";
+const std::string g_image_src_float = "../images/image_src_float";  // NOLINT
 int main() {
   paddle_mobile::fpga::open_device();
   paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
   if (paddle_mobile.Load(std::string(g_resnet50), true)) {
     Tensor input_tensor;
     SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(2),
                        static_cast<float>(2));
     readStream(g_image_src_float,
                input_tensor.mutable_data<float>({1, 3, 224, 224}));
     paddle_mobile.FeedData(input_tensor);
     paddle_mobile.Predict_To(-1);
     for (int i = 0; i < 73; i++) {
       auto tensor_ptr = paddle_mobile.FetchResult(i);
       std::string saveName = "resnet50_result_" + std::to_string(i);
       paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).get_data(),
                                            tensor_ptr->numel() * sizeof(half));
-      dump_stride_half(saveName, (*tensor_ptr), 20);
+      // dump_stride_half(saveName, (*tensor_ptr), 20);
       // dump(saveName, (*tensor_ptr));
     }
 
     auto tensor_ptr = paddle_mobile.FetchResult(73);
-    dump_stride_float("resnet50_result_73", (*tensor_ptr), 20);
+    // dump_stride_float("resnet50_result_73", (*tensor_ptr), 20);
     tensor_ptr = paddle_mobile.FetchResult(74);
-    dump_stride_float("resnet50_result_74", (*tensor_ptr), 9999);
+    // dump_stride_float("resnet50_result_74", (*tensor_ptr), 9999);
 
     float max = 0;
     auto data_ptr = tensor_ptr->data<float>();
     int maximumIdx = 0;
     for (int i = 0; i < (*tensor_ptr).numel(); i++) {
       if (data_ptr[i] > max) {
         maximumIdx = i;
         max = data_ptr[i];
       }
     }
     std::cout << "index : " << std::dec << maximumIdx << ", value : " << max
               << std::endl;
     std::cout << "Computation done" << std::endl;
     return 0;
   }
 }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#ifdef PADDLE_MOBILE_FPGA_V1
#include "fpga/V1/api.h"
#endif
#ifdef PADDLE_MOBILE_FPGA_V2
#include "fpga/V2/api.h"
#endif
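// Build note (editor's gloss): PADDLE_MOBILE_FPGA_V1 / PADDLE_MOBILE_FPGA_V2
// select which FPGA API this test compiles against; the build is expected to
// define exactly one of them (e.g. via -DPADDLE_MOBILE_FPGA_V1).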
#include <string>

void readStream(std::string filename, char *buf) {
  std::ifstream in;
  in.open(filename, std::ios::in | std::ios::binary);
  if (!in.is_open()) {
    std::cout << "open File Failed." << std::endl;
    return;
  }
  in.seekg(0, std::ios::end);  // go to the end
  auto length = in.tellg();    // report location (this is the length)
  in.seekg(0, std::ios::beg);  // go back to the beginning
  in.read(buf, length);
  DLOG << length;
  in.close();
}
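// Usage sketch (editor's illustration): readStream() copies the whole file in
// one read() call, so the caller must supply a buffer at least as large as
// the file -- e.g. the 768 x 1536 x 3 float image loaded in main() below:
//   auto *img = reinterpret_cast<char *>(
//       fpga::fpga_malloc(768 * 1536 * 3 * sizeof(float)));
//   readStream("../models/rfcn/data.bin", img);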
void convert_to_chw(int16_t **data_in, int channel, int height, int width,
                    int num, int16_t *data_tmp) {
  int64_t amount_per_side = width * height;
  for (int n = 0; n < num; n++) {
    for (int h = 0; h < height; h++) {
      for (int w = 0; w < width; w++) {
        for (int c = 0; c < channel; c++) {
          *(data_tmp + n * amount_per_side * channel + c * amount_per_side +
            width * h + w) = *((*data_in)++);
        }
      }
    }
  }
}
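// Index sketch (editor's illustration, not from the source): the input is
// walked in NHWC order, src[((n * H + h) * W + w) * C + c], and each element
// lands at dst[n * C * H * W + c * H * W + h * W + w]. For C = 2, H = 2,
// W = 2, the element at (n = 0, h = 1, w = 0, c = 1) moves from source
// offset 5 to destination offset 6.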
void dump_stride_half(std::string filename, Tensor input_tensor,
                      const int dumpnum, bool use_chw) {
  if (input_tensor.dims().size() != 4) return;
  int c = (input_tensor.dims())[1];
  int h = (input_tensor.dims())[2];
  int w = (input_tensor.dims())[3];
  int n = (input_tensor.dims())[0];
  auto data_ptr = input_tensor.get_data();
  auto *data_ptr_16 = reinterpret_cast<half *>(data_ptr);
  auto data_tmp = data_ptr_16;
  if (use_chw) {
    data_tmp =
        reinterpret_cast<half *>(malloc(n * c * h * w * sizeof(int16_t)));
    convert_to_chw(&data_ptr_16, c, h, w, n, data_tmp);
  }
  std::ofstream out(filename.c_str());
  float result = 0;
  int stride = input_tensor.numel() / dumpnum;
  stride = stride > 0 ? stride : 1;
  for (int i = 0; i < input_tensor.numel(); i += stride) {
    result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
    out << result << std::endl;
  }
  out.close();
  if (data_tmp != data_ptr_16) {
    free(data_tmp);
  }
}
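// Usage sketch (editor's illustration; the tensor name is hypothetical):
// dump about 20 evenly spaced values from a 4-D fp16 tensor, converting the
// FPGA's channel-interleaved layout back to CHW first:
//   dump_stride_half("conv5_out", *tensor_ptr, 20, /*use_chw=*/true);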
void dump_stride_float(std::string filename, Tensor input_tensor,
                       const int dumpnum) {
  auto data_ptr = reinterpret_cast<float *>(input_tensor.get_data());
  std::ofstream out(filename.c_str());
  float result = 0;
  int stride = input_tensor.numel() / dumpnum;
  stride = stride > 0 ? stride : 1;
  for (int i = 0; i < input_tensor.numel(); i += stride) {
    result = data_ptr[i];
    out << result << std::endl;
  }
  out.close();
}
void dump_stride(std::string filename, Tensor input_tensor, const int dumpnum,
                 bool use_chw) {
  static int i = 0;
  if (input_tensor.numel() == 0) {
    return;
  }
  if (input_tensor.type() == typeid(float)) {
    DLOG << "op: " << i++ << ", float data " << input_tensor.numel();
    dump_stride_float(filename, input_tensor, dumpnum);
  } else {
    DLOG << "op: " << i++ << ", half data " << input_tensor.numel();
    dump_stride_half(filename, input_tensor, dumpnum, use_chw);
  }
  DLOG << "dump input address: " << input_tensor.get_data();
}
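// Dispatch note (editor's gloss): dump_stride keys on the tensor's RTTI type,
// so callers need not know whether an op produced fp32 or fp16 output -- the
// loop in main() below calls it uniformly for results 55..68.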
static const char *g_rfcn_combine = "../models/rfcn";
static const char *g_image_src_float = "../models/rfcn/data.bin";
int main() {
  paddle_mobile::fpga::open_device();
  paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
  if (paddle_mobile.Load(std::string(g_rfcn_combine) + "/model",
                         std::string(g_rfcn_combine) + "/params", true, false,
                         1, true)) {
    float img_info[3] = {768, 1536, 768.0f / 960.0f};
    auto img = reinterpret_cast<float *>(
        fpga::fpga_malloc(768 * 1536 * 3 * sizeof(float)));
    readStream(g_image_src_float, reinterpret_cast<char *>(img));
    std::vector<void *> v(3, nullptr);
    paddle_mobile.FeedData({img_info, img});
    paddle_mobile.Predict_To(-1);
    for (int i = 55; i < 69; i++) {
      auto tensor_ptr = paddle_mobile.FetchResult(i);
      std::string saveName = "rfcn_" + std::to_string(i);
      paddle_mobile::fpga::fpga_invalidate(
          (*tensor_ptr).get_data(), tensor_ptr->numel() * sizeof(float));
      if ((i == 48) || (i == 47)) {  // never taken for the current 55..68 range
        dump_stride(saveName, (*tensor_ptr), 20, false);
      } else {
        dump_stride(saveName, (*tensor_ptr), tensor_ptr->numel(), true);
      }
      /* Disabled experiment: convert the fp16 softmax output to fp32 and
         save it to a file.
      float result = 0;
      std::string str = "softmax_output_data";
      float *data = static_cast<float *>(
          fpga::fpga_malloc(tensor_ptr->numel() * sizeof(float)));
      auto output_ptr = static_cast<half *>((*tensor_ptr).get_data());
      for (int idx = 0; idx < tensor_ptr->numel(); ++idx) {
        data[idx] = fpga::fp16_2_fp32(output_ptr[idx]);
      }
      fpga::savefile<float>(str, data, tensor_ptr->numel(), result);
      */
    }
    // paddle_mobile.GetResults(&v);
    DLOG << "Computation done";
    fpga::fpga_free(img);
  }
  return 0;
}