提交 52c6c1a8 编写于 作者: J jameswu2014

format modify

上级 2f507f76
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "fpga/V1/api.h"
#include <memory>
#include "fpga/V1/bias_scale.h"
#include "fpga/V1/deconv_filter.h"
#include "fpga/V1/filter.h"
......@@ -368,7 +369,8 @@ void expand_conv_arg(ConvArgs *arg) {
auto filter_pad_width_mul_channel =
args.image.pad_width * args.image.channels;
auto image_amount_per_row_multi_win_first =
image_amount_per_row * (ROW_PARALLEL_NUM * args.kernel.stride_h - args.image.pad_height);
image_amount_per_row *
(ROW_PARALLEL_NUM * args.kernel.stride_h - args.image.pad_height);
auto image_amount_per_row_multi_win =
image_amount_per_row * (ROW_PARALLEL_NUM * args.kernel.stride_h);
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <utility>
#include "common/enforce.h"
#include "fpga/common/driver.h"
......@@ -147,8 +148,6 @@ int fpga_regpoll(uint64_t reg, uint64_t val, int time) {
}
}
void memory_release(struct fpga_memory *memory) {
void *ptr = nullptr;
......@@ -160,8 +159,6 @@ void memory_release(struct fpga_memory *memory) {
}
}
uint64_t vaddr_to_paddr_driver(void *address) {
uint64_t paddr = 0;
auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(address);
......@@ -209,14 +206,14 @@ void *fpga_malloc_driver(size_t size) {
struct MemoryVM2PHYArgs args;
struct MemoryCacheArgs args_c;
// memory_request(g_fpgainfo.memory_info, size, &phy_addr);
// memory_request(g_fpgainfo.memory_info, size, &phy_addr);
ret = mmap64(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED,
g_fpgainfo.fd_mem, FPGA_MEM_PHY_ADDR);
PADDLE_MOBILE_ENFORCE(ret != (void *)-1, "Should not be -1");
args.pVM= (void *)ret;
args.pPHY =(void *)0;
args.pVM = reinterpret_cast<void *>(ret);
args.pPHY = reinterpret_cast<void *>(0);
do_ioctl(IOCTL_MEMORY_VM2PHY, &args);
phy_addr = (uint64_t)args.pPHY;
......@@ -237,9 +234,8 @@ void fpga_free_driver(void *ptr) {
g_fpgainfo.fpga_addr2size_map.erase(iter);
munmap(ptr, size);
p_addr = vaddr_to_paddr_driver(ptr);
pos = (p_addr - g_fpgainfo.memory_info->mem_start) / FPGA_PAGE_SIZE;
// p_addr = vaddr_to_paddr_driver(ptr);
// pos = (p_addr - g_fpgainfo.memory_info->mem_start) / FPGA_PAGE_SIZE;
auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(ptr);
if (iter != g_fpgainfo.fpga_vaddr2paddr_map.end()) {
......@@ -299,7 +295,7 @@ int open_device_driver() {
g_fpgainfo.FpgaRegVirAddr =
(uint64_t *)fpga_reg_malloc(FPGA_REG_SIZE); // NOLINT
//fpga_memory_add();
// fpga_memory_add();
pl_init();
......@@ -310,7 +306,7 @@ int close_device_driver() {
pl_destroy();
fpga_reg_free(g_fpgainfo.FpgaRegVirAddr);
memory_release(g_fpgainfo.memory_info);
return 0;
}
......
......@@ -53,15 +53,14 @@ struct MemoryCacheArgs {
};
struct MemoryVM2PHYArgs {
void* pVM;
void* pPHY;
void *pVM;
void *pPHY;
};
#define IOCTL_FPGA_MAGIC 'F'
#define IOCTL_MEMCACHE_INVAL _IOW(IOCTL_FPGA_MAGIC, 12, struct MemoryCacheArgs)
#define IOCTL_MEMCACHE_FLUSH _IOW(IOCTL_FPGA_MAGIC, 13, struct MemoryCacheArgs)
#define IOCTL_MEMORY_VM2PHY _IOWR(IOCTL_FPGA_MAGIC, 15, struct MemoryVM2PHYArgs)
#define IOCTL_MEMORY_VM2PHY _IOWR(IOCTL_FPGA_MAGIC, 15, struct MemoryVM2PHYArgs)
struct fpga_pe {
char type_name[MAX_TYPE_NAME_LENTH + 1];
......
......@@ -25,7 +25,7 @@ limitations under the License. */
#define FILTER_ELEMENT_ALIGNMENT (16) // Filter element number aligned to 16
#define BS_NUM_ALIGNMENT (8)
#define BIAS_NUM_ALIGNMENT (16)
#define ROW_PARALLEL_NUM (3)
#define ROW_PARALLEL_NUM (3)
#endif
namespace paddle_mobile {
......
......@@ -74,15 +74,14 @@ void RoiAlignPoolOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.input_x_->dims();
out_dims[0] = rois_dims[0];
// out_dims[1] =
// output_channels; // input_dims[1] / (pooled_height * pooled_width);
// out_dims[1] =
// output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
this->param_.output_->Resize(out_dims);
}
#endif
#ifdef ROI_PERSPECTIVE_OP
template <typename DeviceType, typename T>
void RoiPerspectiveOp<DeviceType, T>::InferShape() const {
......
......@@ -38,7 +38,6 @@ DECLARE_OPERATOR(PSRoiPool, PSRoiPoolParam, PSRoiPoolKernel);
DECLARE_OPERATOR(RoiAlignPool, RoiAlignPoolParam, RoiAlignPoolKernel);
#endif
#ifdef ROI_PERSPECTIVE_OP
DECLARE_OPERATOR(RoiPerspective, RoiPerspectiveParam, RoiPerspectiveKernel);
#endif
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -157,18 +158,20 @@ DECLARE_KERNEL(PSRoiPool, PSRoiPoolParam);
template <typename Dtype>
class RoiAlignPoolParam : public OpParam {
public:
RoiAlignPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
RoiAlignPoolParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
pooled_width_ = OpParam::GetAttr<int>("pooled_width", attrs);
spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
sampling_ratio_ = OpParam::GetAttr<float>("sampling_ratio", attrs);
sampling_ratio_ = OpParam::GetAttr<float>("sampling_ratio", attrs);
}
public:
......@@ -180,10 +183,9 @@ class RoiAlignPoolParam : public OpParam {
float spatial_scale_;
int sampling_ratio_;
#ifdef PADDLE_MOBILE_FPGA
std::shared_ptr<Tensor> float_input, float_output;
fpga::BypassArgs input_arg, output_arg;
std::shared_ptr<Tensor> float_input, float_output;
fpga::BypassArgs input_arg, output_arg;
#endif
};
DECLARE_KERNEL(RoiAlignPool, RoiAlignPoolParam);
......
......@@ -56,7 +56,7 @@ void dealign(float *src, float *dst, int input_c, int input_h, int input_w) {
}
template <>
void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
auto input = param.InputX();
auto input = const_cast<Tensor *>(param.InputX());
if (input->type() == typeid(float)) {
auto output = param.Out();
output->ShareDataWith(*input);
......@@ -73,15 +73,14 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
reinterpret_cast<float *>(param.fpga_bypass_args.output.address);
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
param.Out()->fpga_data_num * sizeof(float));
if(param.Out()->fpga_data_num != product(input->dims())){
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
free(data_tmp);
if (param.Out()->fpga_data_num != product(input->dims())) {
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
free(data_tmp);
}
}
template class FetchKernel<FPGA, float>;
......
......@@ -74,10 +74,11 @@ void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) {
auto *output = param.Output();
auto in = input->data<float>();
auto N = input->dims()[0];
output->Resize({N, output->dims()[1], output->dims()[2], output->dims()[3]});
output->Resize(
{N, output->dims()[1], output->dims()[2], output->dims()[3]});
auto len = output->numel();
auto out = output->mutable_data<float>();
int C = input->dims()[1], H = input->dims()[2],//N = input->dims()[0],
int C = input->dims()[1], H = input->dims()[2], // N = input->dims()[0],
W = input->dims()[3];
int HW = H * W, CHW = C * H * W, WC = W * C;
......
......@@ -65,13 +65,12 @@ bool ProposalKernel<FPGA, float>::Init(ProposalParam<FPGA> *param) {
args.output.scale_address = param->float_score->scale;
param->score_arg = args;
param->score_index_= std::make_shared<Tensor>();
param->score_index_ = std::make_shared<Tensor>();
param->score_index_->mutable_data<int32_t>({input->numel()});
auto score_index = param->score_index_->data<int32_t>();
for (int i = 0; i < input->numel(); ++i){
score_index[i] = i;
for (int i = 0; i < input->numel(); ++i) {
score_index[i] = i;
}
return true;
}
......@@ -342,9 +341,8 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
const Tensor &im_info_slice, const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
const Tensor &score_index,
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) {
const Tensor &score_index, int pre_nms_top_n, int post_nms_top_n,
float nms_thresh, float min_size, float eta) {
auto *scores_data = scores_slice.data<T>();
// Sort index
......@@ -354,8 +352,9 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
/*for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}*/
std::memcpy(index,score_index.data<int32_t>(),scores_slice.numel()*sizeof(int) );
std::memcpy(index, score_index.data<int32_t>(),
scores_slice.numel() * sizeof(int));
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
......@@ -504,7 +503,7 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
auto score_index = *(param.score_index_.get());
int pre_nms_top_n = param.pre_nms_topn_;
int post_nms_top_n = 100;//param.post_nms_topn_;
int post_nms_top_n = 100; // param.post_nms_topn_;
float nms_thresh = param.nms_thresh_;
float min_size = param.min_size_;
float eta = param.eta_;
......@@ -541,8 +540,8 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair = ProposalForOneImage<float>(
im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice,score_index,
pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta);
im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice,
score_index, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta);
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PSROI_POOL_OP
#include <cmath>
#include <vector>
#include "operators/kernel/detection_kernel.h"
#include "fpga/V1/api.h"
#include "fpga/V1/image.h"
namespace paddle_mobile {
namespace operators {
template <>
bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
auto dims = param->input_x_->dims();
PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
"data not aligned");
param->float_input = std::make_shared<Tensor>();
param->float_input->mutable_data<float>(param->input_x_->dims());
// param->float_output = std::make_shared<Tensor>();
auto input = param->input_x_;
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_layout_type = fpga::LAYOUT_HWC;
args.output_layout_type = fpga::LAYOUT_HWC;
args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32;
args.image.address = input->data<half>();
args.image.height = (uint32_t)input->dims()[2];
args.image.width = (uint32_t)input->dims()[3];
args.image.channels = (uint32_t)input->dims()[1];
args.output.address = param->float_input->mutable_data<float>();
args.output.scale_address = param->float_input->scale;
param->input_arg = args;
auto* rois = param->input_rois_;
int rois_num = rois->dims()[0];
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, param->output_->dims()[1], param->output_->dims()[2],
param->output_->dims()[3]});
param->output_->Resize(dims_out_new);
// fpga::format_fp16_ofm(param->output_);
param->output_->mutable_data<float>(dims_out_new);
// auto output = param->float_output.get();
// param->output_ = output;
/* args.input_data_type = fpga::DATA_TYPE_FP32;
args.output_data_type = fpga::DATA_TYPE_FP16;
args.image.address = output->data<float>();
args.image.height = (uint32_t)output->dims()[2];
args.image.width = (uint32_t)output->dims()[3];
args.image.channels = (uint32_t)output->dims()[1] ;
args.output.address = param->output_->mutable_data<half>();
args.output.scale_address = param->output_->scale;
param->output_arg = args;*/
return true;
}
template <typename Dtype>
void PSROIPooling(
const Dtype* bottom_data, const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const Dtype* bottom_rois, const int output_dim,
const int group_size, Dtype* top_data,
int index, int nid,
const Dtype Bin_size_h,
const Dtype Bin_size_w,
const Dtype roi_start_h,
const Dtype roi_start_w,
const int ctop, const int ph, const int roi_batch_ind)
{
int pw = index;
int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
int wstart = floor(static_cast<Dtype>(pw)* Bin_size_w + roi_start_w);
int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
int wend = ceil(static_cast<Dtype>(pw + 1) * Bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c = (ctop*group_size + ph)*group_size + pw;
Dtype bin_area = (hend - hstart)*(wend - wstart);
bottom_data += (roi_batch_ind * channels + c) * height * width;
Dtype out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int bottom_index = h * width + w;
out_sum += bottom_data[bottom_index];
}
}
top_data[nid + index] = is_empty? 0. : out_sum/bin_area;
}
void convert_to_chw(float **data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float *data_tmp =
(float *)fpga::fpga_malloc(channel * height * width * sizeof(float)); // NOLINT
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PSROI_POOL_OP
#include <cmath>
#include <memory>
#include <vector>
#include "operators/kernel/detection_kernel.h"
#include "fpga/V1/api.h"
#include "fpga/V1/image.h"
namespace paddle_mobile {
namespace operators {
template <>
bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
auto dims = param->input_x_->dims();
PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
"data not aligned");
param->float_input = std::make_shared<Tensor>();
param->float_input->mutable_data<float>(param->input_x_->dims());
// param->float_output = std::make_shared<Tensor>();
auto input = param->input_x_;
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_layout_type = fpga::LAYOUT_HWC;
args.output_layout_type = fpga::LAYOUT_HWC;
args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32;
args.image.address = input->data<half>();
args.image.height = (uint32_t)input->dims()[2];
args.image.width = (uint32_t)input->dims()[3];
args.image.channels = (uint32_t)input->dims()[1];
args.output.address = param->float_input->mutable_data<float>();
args.output.scale_address = param->float_input->scale;
param->input_arg = args;
auto* rois = param->input_rois_;
int rois_num = rois->dims()[0];
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, param->output_->dims()[1], param->output_->dims()[2],
param->output_->dims()[3]});
param->output_->Resize(dims_out_new);
// fpga::format_fp16_ofm(param->output_);
param->output_->mutable_data<float>(dims_out_new);
// auto output = param->float_output.get();
// param->output_ = output;
/* args.input_data_type = fpga::DATA_TYPE_FP32;
args.output_data_type = fpga::DATA_TYPE_FP16;
args.image.address = output->data<float>();
args.image.height = (uint32_t)output->dims()[2];
args.image.width = (uint32_t)output->dims()[3];
args.image.channels = (uint32_t)output->dims()[1] ;
args.output.address = param->output_->mutable_data<half>();
args.output.scale_address = param->output_->scale;
param->output_arg = args;*/
return true;
}
template <typename Dtype>
void PSROIPooling(const Dtype* bottom_data, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const Dtype* bottom_rois,
const int output_dim, const int group_size, Dtype* top_data,
int index, int nid, const Dtype Bin_size_h,
const Dtype Bin_size_w, const Dtype roi_start_h,
const Dtype roi_start_w, const int ctop, const int ph,
const int roi_batch_ind) {
int pw = index;
int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
int wstart = floor(static_cast<Dtype>(pw) * Bin_size_w + roi_start_w);
int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
int wend = ceil(static_cast<Dtype>(pw + 1) * Bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart, 0), height);
hend = std::min(std::max(hend, 0), height);
wstart = std::min(std::max(wstart, 0), width);
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c = (ctop * group_size + ph) * group_size + pw;
Dtype bin_area = (hend - hstart) * (wend - wstart);
bottom_data += (roi_batch_ind * channels + c) * height * width;
Dtype out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int bottom_index = h * width + w;
out_sum += bottom_data[bottom_index];
}
}
top_data[nid + index] = is_empty ? 0. : out_sum / bin_area;
}
void convert_to_chw(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(channel * height * width * sizeof(float))); // NOLINT
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
......@@ -130,15 +125,15 @@ void convert_to_chw(float **data_in, int channel, int height, int width,
}
}
}
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
}
void convert_to_hwc(float **data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float *data_tmp = reinterpret_cast<float *>(
fpga::fpga_malloc(num * channel * height * width * sizeof(float)));
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
}
void convert_to_hwc(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(num * channel * height * width * sizeof(float)));
int64_t amount_per_row = width * channel;
for (int n = 0; n < num; n++) {
for (int c = 0; c < channel; c++) {
......@@ -151,110 +146,116 @@ void convert_to_hwc(float **data_in, int channel, int height, int width,
}
}
}
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
}
template <>
void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
fpga::PerformBypass(param.input_arg);
fpga::fpga_invalidate(input_tensor->data<float>(),
input_tensor->numel() * sizeof(float));
auto* in = input_tensor;
auto* rois = param.input_rois_;
auto* out = param.output_; // param.float_output.get();
auto pooled_height = param.pooled_height_;
auto pooled_width = param.pooled_width_;
auto spatial_scale = param.spatial_scale_;
auto output_channels = param.output_channels_;
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto data_nhwc = in->mutable_data<float>();
convert_to_chw(&data_nhwc, input_channels, height, width, 1);
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
(param.output_)->dims()[3]});
(param.output_)->Resize(dims_out_new);
const float* input_data = data_nhwc; // in->data<float>();
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();
PADDLE_MOBILE_ENFORCE(rois->NumLevels() > 0, "ROIS should not be empty");
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_MOBILE_ENFORCE(
rois_batch_size == batch_size,
"the rois_batch_size and input(X) batch_size should be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_MOBILE_ENFORCE(rois_num_with_lod == rois_num,
"the rois_num from input and lod must be the same");
PADDLE_MOBILE_ENFORCE(
input_channels == output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
//for (int n = 0; n < rois_batch_size; ++n) {
//for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
//rois_batch_id_data[i] = n;
// }
//}
auto output_data = out->mutable_data<float>();
auto input_rois = rois->data<float>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) {
// [start, end) interval for spatial sampling
auto offset_input_rois = input_rois + n * 4;
auto roi_start_w = static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
auto roi_start_h = static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
auto roi_end_w = static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
auto roi_end_h = static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
auto roi_height = std::max(roi_end_h - roi_start_h, 0.1f); // avoid 0
auto roi_width = std::max(roi_end_w - roi_start_w, 0.1f);
// Compute bin size w and h at input feature map
auto bin_size_h = roi_height / static_cast<float>(pooled_height);
auto bin_size_w = roi_width / static_cast<float>(pooled_width);
int roi_batch_ind = 0;//rois_batch_id_data[n];
//std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;
for(int c = 0; c < output_channels; ++c){
for(int ph = 0; ph < pooled_height; ph++){
int index = pooled_width;
int nid = n * output_channels * pooled_height * pooled_width + c * pooled_width * pooled_height + ph * pooled_width;
for(int idx = 0; idx < index; idx++){
PSROIPooling<float>(input_data,input_channels,height,width,pooled_height,pooled_width,
input_rois,output_channels,pooled_height,output_data, idx, nid, bin_size_h, bin_size_w, roi_start_h, roi_start_w, c, ph, roi_batch_ind);
}
}
}
}
convert_to_hwc(&output_data, output_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(output_data);
}
} // namespace operators
} // namespace paddle_mobile
#endif // PSROI_POOL_OP
*data_in = data_tmp;
fpga::fpga_free(data_in_tmp);
}
template <>
void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
fpga::PerformBypass(param.input_arg);
fpga::fpga_invalidate(input_tensor->data<float>(),
input_tensor->numel() * sizeof(float));
auto* in = input_tensor;
auto* rois = param.input_rois_;
auto* out = param.output_; // param.float_output.get();
auto pooled_height = param.pooled_height_;
auto pooled_width = param.pooled_width_;
auto spatial_scale = param.spatial_scale_;
auto output_channels = param.output_channels_;
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto data_nhwc = in->mutable_data<float>();
fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width, 1);
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
(param.output_)->dims()[3]});
(param.output_)->Resize(dims_out_new);
float* input_data = data_nhwc; // in->data<float>();
// shared_ptr<float> input_data(data_nhwc);
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();
PADDLE_MOBILE_ENFORCE(rois->NumLevels() > 0, "ROIS should not be empty");
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_MOBILE_ENFORCE(
rois_batch_size == batch_size,
"the rois_batch_size and input(X) batch_size should be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_MOBILE_ENFORCE(rois_num_with_lod == rois_num,
"the rois_num from input and lod must be the same");
PADDLE_MOBILE_ENFORCE(
input_channels == output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
// for (int n = 0; n < rois_batch_size; ++n) {
// for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
// rois_batch_id_data[i] = n;
// }
//}
auto output_data = out->mutable_data<float>();
auto input_rois = rois->data<float>();
// calculate psroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) {
// [start, end) interval for spatial sampling
auto offset_input_rois = input_rois + n * 4;
auto roi_start_w =
static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
auto roi_start_h =
static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
auto roi_end_w =
static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
auto roi_end_h =
static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
auto roi_height = std::max(roi_end_h - roi_start_h, 0.1f); // avoid 0
auto roi_width = std::max(roi_end_w - roi_start_w, 0.1f);
// Compute bin size w and h at input feature map
auto bin_size_h = roi_height / static_cast<float>(pooled_height);
auto bin_size_w = roi_width / static_cast<float>(pooled_width);
int roi_batch_ind = 0; // rois_batch_id_data[n];
// std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < pooled_height; ph++) {
int index = pooled_width;
int nid = n * output_channels * pooled_height * pooled_width +
c * pooled_width * pooled_height + ph * pooled_width;
for (int idx = 0; idx < index; idx++) {
PSROIPooling<float>(input_data, input_channels, height, width,
pooled_height, pooled_width, input_rois,
output_channels, pooled_height, output_data, idx,
nid, bin_size_h, bin_size_w, roi_start_h,
roi_start_w, c, ph, roi_batch_ind);
}
}
}
}
fpga::fpga_free(input_data);
fpga::image::convert_to_hwc(&output_data, output_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(output_data);
}
} // namespace operators
} // namespace paddle_mobile
#endif // PSROI_POOL_OP
......@@ -24,10 +24,8 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
auto dims = param->input_x_->dims();
PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
"data not aligned");
......@@ -58,11 +56,9 @@ bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
param->output_->mutable_data<float>(dims_out_new);
return true;
}
template <typename T>
struct PreCalc {
int pos1;
......@@ -77,30 +73,22 @@ struct PreCalc {
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
const int height, const int width, const int pooled_height,
const int pooled_width, const int iy_upper, const int ix_upper,
T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
int roi_bin_grid_h, int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) { // NOLINT
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
const T yy = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < ix_upper; ix++) {
const T xx = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T x = xx;
T y = yy;
......@@ -128,8 +116,8 @@ void pre_calc_for_bilinear_interpolate(
x = 0;
}
int y_low = (int)y;
int x_low = (int)x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
......@@ -172,22 +160,13 @@ void pre_calc_for_bilinear_interpolate(
}
template <typename T>
void ROIAlignForward(
const int nthreads,
const T* bottom_data,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* bottom_rois,
T* top_data) {
void ROIAlignForward(const int nthreads, const T* bottom_data,
const T& spatial_scale, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int sampling_ratio,
const T* bottom_rois, T* top_data) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
......@@ -195,8 +174,8 @@ void ROIAlignForward(
const T* offset_bottom_rois = bottom_rois + n * 4;
int roi_batch_ind = 0;
// if (roi_cols == 5) {
// roi_batch_ind = offset_bottom_rois[0];
// offset_bottom_rois++;
// roi_batch_ind = offset_bottom_rois[0];
// offset_bottom_rois++;
// }
// Do not using rounding; this implementation detail is critical
......@@ -217,70 +196,58 @@ void ROIAlignForward(
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
// we want to precalculate indeces and weights shared by all chanels,
// this is the key point of optimiation
std::vector<PreCalc<T>> pre_calc(
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
pooled_width * pooled_height);
pre_calc_for_bilinear_interpolate(
height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_start_h,
roi_start_w,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
pre_calc);
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
pc.w2 * offset_bottom_data[pc.pos2] +
pc.w3 * offset_bottom_data[pc.pos3] +
pc.w4 * offset_bottom_data[pc.pos4];
pre_calc_index += 1;
}
height, width, pooled_height, pooled_width, roi_bin_grid_h,
roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
roi_bin_grid_h, roi_bin_grid_w, pre_calc);
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
pc.w2 * offset_bottom_data[pc.pos2] +
pc.w3 * offset_bottom_data[pc.pos3] +
pc.w4 * offset_bottom_data[pc.pos4];
pre_calc_index += 1;
}
output_val /= count;
}
output_val /= count;
top_data[index] = output_val;
} // for pw
} // for ph
} // for c
} // for n
top_data[index] = output_val;
} // for pw
} // for ph
} // for c
} // for n
}
template <>
void RoiAlignPoolKernel<FPGA, float>::Compute(const RoiAlignPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
void RoiAlignPoolKernel<FPGA, float>::Compute(
const RoiAlignPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
fpga::PerformBypass(param.input_arg);
fpga::fpga_invalidate(input_tensor->data<float>(),
input_tensor->numel() * sizeof(float));
......@@ -308,23 +275,22 @@ void RoiAlignPoolKernel<FPGA, float>::Compute(const RoiAlignPoolParam<FPGA>& par
{rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
(param.output_)->dims()[3]});
(param.output_)->Resize(dims_out_new);
const int index = input_channels * pooled_height * pooled_width * rois_num;
auto rois_data = rois->data<float>();
auto top_data = param.output_->mutable_data<float>();
for (int i = 0; i < index; ++i){
ROIAlignForward<float>( index,data_nhwc,spatial_scale,input_channels,height,width,
pooled_height,pooled_width,sampe_ratio,rois_data,top_data);
for (int i = 0; i < index; ++i) {
ROIAlignForward<float>(index, data_nhwc, spatial_scale, input_channels,
height, width, pooled_height, pooled_width,
sampe_ratio, rois_data, top_data);
}
fpga::image::convert_to_hwc(&top_data, input_channels, pooled_height,
fpga::image::convert_to_hwc(&top_data, input_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(top_data);
out->reset_data_ptr(top_data);
}
} // namespace operators
} // namespace paddle_mobile
#endif // ROIALIGN_POOL_OP
......@@ -105,7 +105,8 @@ void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
} else {
if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
Tensor *out = param.Out();
out->Resize({in_x->dims()[0], out->dims()[1], out->dims()[2], out->dims()[3]});
out->Resize(
{in_x->dims()[0], out->dims()[1], out->dims()[2], out->dims()[3]});
math::SoftmaxFuntor<CPU, float>()(in_x, out);
}
}
......
......@@ -44,8 +44,9 @@ void Transpose2Kernel<FPGA, float>::Compute(
// Transpose2Compute<float>(param);
auto input = param.InputX();
auto output = param.Out();
output->Resize({input->dims()[0], output->dims()[1], output->dims()[2], output->dims()[3]});
output->Resize({input->dims()[0], output->dims()[1], output->dims()[2],
output->dims()[3]});
}
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册