Commit 52c6c1a8 authored by jameswu2014

format modify

Parent 2f507f76
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "fpga/V1/api.h"
#include <memory>
#include "fpga/V1/bias_scale.h"
#include "fpga/V1/deconv_filter.h"
#include "fpga/V1/filter.h"
@@ -368,7 +369,8 @@ void expand_conv_arg(ConvArgs *arg) {
  auto filter_pad_width_mul_channel =
      args.image.pad_width * args.image.channels;
  auto image_amount_per_row_multi_win_first =
      image_amount_per_row *
      (ROW_PARALLEL_NUM * args.kernel.stride_h - args.image.pad_height);
  auto image_amount_per_row_multi_win =
      image_amount_per_row * (ROW_PARALLEL_NUM * args.kernel.stride_h);
...
@@ -26,6 +26,7 @@ limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <utility>
#include "common/enforce.h"
#include "fpga/common/driver.h"
@@ -147,8 +148,6 @@ int fpga_regpoll(uint64_t reg, uint64_t val, int time) {
  }
}

void memory_release(struct fpga_memory *memory) {
  void *ptr = nullptr;
@@ -160,8 +159,6 @@ void memory_release(struct fpga_memory *memory) {
  }
}

uint64_t vaddr_to_paddr_driver(void *address) {
  uint64_t paddr = 0;
  auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(address);
@@ -209,14 +206,14 @@ void *fpga_malloc_driver(size_t size) {
  struct MemoryVM2PHYArgs args;
  struct MemoryCacheArgs args_c;

  // memory_request(g_fpgainfo.memory_info, size, &phy_addr);
  ret = mmap64(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED,
               g_fpgainfo.fd_mem, FPGA_MEM_PHY_ADDR);
  PADDLE_MOBILE_ENFORCE(ret != (void *)-1, "Should not be -1");

  args.pVM = reinterpret_cast<void *>(ret);
  args.pPHY = reinterpret_cast<void *>(0);
  do_ioctl(IOCTL_MEMORY_VM2PHY, &args);
  phy_addr = (uint64_t)args.pPHY;
@@ -237,9 +234,8 @@ void fpga_free_driver(void *ptr) {
    g_fpgainfo.fpga_addr2size_map.erase(iter);
    munmap(ptr, size);

    // p_addr = vaddr_to_paddr_driver(ptr);
    // pos = (p_addr - g_fpgainfo.memory_info->mem_start) / FPGA_PAGE_SIZE;

    auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(ptr);
    if (iter != g_fpgainfo.fpga_vaddr2paddr_map.end()) {
@@ -299,7 +295,7 @@ int open_device_driver() {
  g_fpgainfo.FpgaRegVirAddr =
      (uint64_t *)fpga_reg_malloc(FPGA_REG_SIZE);  // NOLINT
  // fpga_memory_add();

  pl_init();
@@ -310,7 +306,7 @@ int close_device_driver() {
  pl_destroy();
  fpga_reg_free(g_fpgainfo.FpgaRegVirAddr);
  memory_release(g_fpgainfo.memory_info);

  return 0;
}
...
@@ -53,15 +53,14 @@ struct MemoryCacheArgs {
};

struct MemoryVM2PHYArgs {
  void *pVM;
  void *pPHY;
};

#define IOCTL_FPGA_MAGIC 'F'
#define IOCTL_MEMCACHE_INVAL _IOW(IOCTL_FPGA_MAGIC, 12, struct MemoryCacheArgs)
#define IOCTL_MEMCACHE_FLUSH _IOW(IOCTL_FPGA_MAGIC, 13, struct MemoryCacheArgs)
#define IOCTL_MEMORY_VM2PHY _IOWR(IOCTL_FPGA_MAGIC, 15, struct MemoryVM2PHYArgs)
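[Editorial note] These _IOW/_IOWR numbers define the user/kernel contract; IOCTL_MEMORY_VM2PHY is read-write because the driver returns the physical address through the same struct. A minimal user-space sketch of that round trip (the device node and its open call are assumptions, not shown in this diff):

    #include <cstdint>
    #include <sys/ioctl.h>

    // Returns the physical address backing vaddr, or 0 on failure.
    // fd is an open descriptor on the FPGA device node (hypothetical,
    // e.g. open("/dev/fpgadrv0", O_RDWR)); the real open lives in driver.cpp.
    uint64_t query_phys(int fd, void *vaddr) {
      struct MemoryVM2PHYArgs args;
      args.pVM = vaddr;     // in: virtual address to translate
      args.pPHY = nullptr;  // out: written by the kernel driver
      if (ioctl(fd, IOCTL_MEMORY_VM2PHY, &args) != 0) return 0;
      return reinterpret_cast<uint64_t>(args.pPHY);
    }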
struct fpga_pe {
  char type_name[MAX_TYPE_NAME_LENTH + 1];
...
@@ -25,7 +25,7 @@ limitations under the License. */
#define FILTER_ELEMENT_ALIGNMENT (16)  // Filter element number aligned to 16
#define BS_NUM_ALIGNMENT (8)
#define BIAS_NUM_ALIGNMENT (16)
#define ROW_PARALLEL_NUM (3)
#endif

namespace paddle_mobile {
...
@@ -74,15 +74,14 @@ void RoiAlignPoolOp<DeviceType, T>::InferShape() const {
  auto out_dims = this->param_.input_x_->dims();
  out_dims[0] = rois_dims[0];
  // out_dims[1] =
  //     output_channels;  // input_dims[1] / (pooled_height * pooled_width);
  out_dims[2] = pooled_height;
  out_dims[3] = pooled_width;
  this->param_.output_->Resize(out_dims);
}
#endif
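[Editorial note] The shape logic above in one worked example (sizes are hypothetical, not taken from the diff):

    // X = {1, 490, 60, 40}, ROIs = {300, 4}, pooled_height = pooled_width = 7.
    // out_dims starts as X's dims, then:
    //   out_dims[0] = 300 (one output slice per ROI)
    //   out_dims[2] = out_dims[3] = 7 (pooled size)
    // so Out = {300, 490, 7, 7}; channels pass through unchanged (the
    // commented-out line would instead divide them, PSROI-style).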
#ifdef ROI_PERSPECTIVE_OP
template <typename DeviceType, typename T>
void RoiPerspectiveOp<DeviceType, T>::InferShape() const {
...
@@ -38,7 +38,6 @@ DECLARE_OPERATOR(PSRoiPool, PSRoiPoolParam, PSRoiPoolKernel);
DECLARE_OPERATOR(RoiAlignPool, RoiAlignPoolParam, RoiAlignPoolKernel);
#endif

#ifdef ROI_PERSPECTIVE_OP
DECLARE_OPERATOR(RoiPerspective, RoiPerspectiveParam, RoiPerspectiveKernel);
#endif
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once

#include <memory>
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
@@ -157,18 +158,20 @@ DECLARE_KERNEL(PSRoiPool, PSRoiPoolParam);
template <typename Dtype>
class RoiAlignPoolParam : public OpParam {
 public:
  RoiAlignPoolParam(const VariableNameMap &inputs,
                    const VariableNameMap &outputs, const AttributeMap &attrs,
                    Scope *scope)
      : OpParam(inputs, outputs, attrs, scope) {
    input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
    input_rois_ =
        OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
    output_ =
        OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
    pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
    pooled_width_ = OpParam::GetAttr<int>("pooled_width", attrs);
    spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
    sampling_ratio_ = OpParam::GetAttr<float>("sampling_ratio", attrs);
  }

 public:
@@ -180,10 +183,9 @@ class RoiAlignPoolParam : public OpParam {
  float spatial_scale_;
  int sampling_ratio_;
#ifdef PADDLE_MOBILE_FPGA
  std::shared_ptr<Tensor> float_input, float_output;
  fpga::BypassArgs input_arg, output_arg;
#endif
};

DECLARE_KERNEL(RoiAlignPool, RoiAlignPoolParam);
...
@@ -56,7 +56,7 @@ void dealign(float *src, float *dst, int input_c, int input_h, int input_w) {
}

template <>
void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
  auto input = const_cast<Tensor *>(param.InputX());
  if (input->type() == typeid(float)) {
    auto output = param.Out();
    output->ShareDataWith(*input);
@@ -73,15 +73,14 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
      reinterpret_cast<float *>(param.fpga_bypass_args.output.address);
  fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
                        param.Out()->fpga_data_num * sizeof(float));

  if (param.Out()->fpga_data_num != product(input->dims())) {
    float *data_tmp =
        reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
    dealign(outdata_ptr, data_tmp, outC, outH, outW);
    memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
    free(data_tmp);
  }
}
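[Editorial note] dealign's body sits outside this hunk; a plausible sketch of what it does, assuming the FPGA pads the channel dimension of each HWC pixel up to IMAGE_ALIGNMENT (an assumption inferred from fpga_data_num != product(dims), not shown in the diff):

    #include <cstring>

    // Hypothetical sketch: copy each pixel's first input_c channels,
    // dropping the per-pixel alignment padding.
    void dealign_sketch(float *src, float *dst, int input_c, int input_h,
                        int input_w) {
      int aligned_c = (input_c + IMAGE_ALIGNMENT - 1) / IMAGE_ALIGNMENT *
                      IMAGE_ALIGNMENT;  // round channels up to the alignment
      for (int i = 0; i < input_h * input_w; i++) {
        memcpy(dst + i * input_c, src + i * aligned_c,
               input_c * sizeof(float));
      }
    }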
template class FetchKernel<FPGA, float>;
...
@@ -74,10 +74,11 @@ void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) {
  auto *output = param.Output();
  auto in = input->data<float>();
  auto N = input->dims()[0];
  output->Resize(
      {N, output->dims()[1], output->dims()[2], output->dims()[3]});
  auto len = output->numel();
  auto out = output->mutable_data<float>();
  int C = input->dims()[1], H = input->dims()[2],  // N = input->dims()[0],
      W = input->dims()[3];
  int HW = H * W, CHW = C * H * W, WC = W * C;
...
@@ -65,13 +65,12 @@ bool ProposalKernel<FPGA, float>::Init(ProposalParam<FPGA> *param) {
  args.output.scale_address = param->float_score->scale;
  param->score_arg = args;

  param->score_index_ = std::make_shared<Tensor>();
  param->score_index_->mutable_data<int32_t>({input->numel()});
  auto score_index = param->score_index_->data<int32_t>();
  for (int i = 0; i < input->numel(); ++i) {
    score_index[i] = i;
  }

  return true;
}
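[Editorial note] The cached buffer is just the identity permutation, built once in Init so that Compute can memcpy it instead of refilling it on every call; the loop above is equivalent to:

    #include <numeric>
    std::iota(score_index, score_index + input->numel(), 0);  // 0, 1, 2, ...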
@@ -342,9 +341,8 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
    const Tensor &im_info_slice, const Tensor &anchors, const Tensor &variances,
    const Tensor &bbox_deltas_slice,  // [M, 4]
    const Tensor &scores_slice,       // [N, 1]
    const Tensor &score_index, int pre_nms_top_n, int post_nms_top_n,
    float nms_thresh, float min_size, float eta) {
  auto *scores_data = scores_slice.data<T>();

  // Sort index
@@ -354,8 +352,9 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
  /*for (int i = 0; i < scores_slice.numel(); ++i) {
    index[i] = i;
  }*/
  std::memcpy(index, score_index.data<int32_t>(),
              scores_slice.numel() * sizeof(int));
  auto compare = [scores_data](const int64_t &i, const int64_t &j) {
    return scores_data[i] > scores_data[j];
  };
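[Editorial note] The copied-in identity permutation is then ordered by descending score with this comparator. The actual sort call is in the collapsed context below this hunk, but it has the shape of (a sketch, not the verbatim call):

    // Sketch only: afterwards index[0] refers to the highest-scoring anchor.
    std::sort(index, index + scores_slice.numel(), compare);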
@@ -504,7 +503,7 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
  auto score_index = *(param.score_index_.get());

  int pre_nms_top_n = param.pre_nms_topn_;
  int post_nms_top_n = 100;  // param.post_nms_topn_;
  float nms_thresh = param.nms_thresh_;
  float min_size = param.min_size_;
  float eta = param.eta_;
@@ -541,8 +540,8 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
  scores_slice.Resize({h_score * w_score * c_score, 1});

  std::pair<Tensor, Tensor> tensor_pair = ProposalForOneImage<float>(
      im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice,
      score_index, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta);
  Tensor &proposals = tensor_pair.first;
  Tensor &scores = tensor_pair.second;
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PSROI_POOL_OP

#include <cmath>
#include <memory>
#include <vector>
#include "operators/kernel/detection_kernel.h"

#include "fpga/V1/api.h"
#include "fpga/V1/image.h"

namespace paddle_mobile {
namespace operators {

template <>
bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
  auto dims = param->input_x_->dims();
  PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
                        "data not aligned");

  param->float_input = std::make_shared<Tensor>();
  param->float_input->mutable_data<float>(param->input_x_->dims());
  // param->float_output = std::make_shared<Tensor>();

  auto input = param->input_x_;
  fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
  args.input_layout_type = fpga::LAYOUT_HWC;
  args.output_layout_type = fpga::LAYOUT_HWC;
  args.input_data_type = fpga::DATA_TYPE_FP16;
  args.output_data_type = fpga::DATA_TYPE_FP32;
  args.image.address = input->data<half>();
  args.image.height = (uint32_t)input->dims()[2];
  args.image.width = (uint32_t)input->dims()[3];
  args.image.channels = (uint32_t)input->dims()[1];
  args.output.address = param->float_input->mutable_data<float>();
  args.output.scale_address = param->float_input->scale;
  param->input_arg = args;

  auto* rois = param->input_rois_;
  int rois_num = rois->dims()[0];
  framework::DDim dims_out_new = framework::make_ddim(
      {rois_num, param->output_->dims()[1], param->output_->dims()[2],
       param->output_->dims()[3]});
  param->output_->Resize(dims_out_new);
  // fpga::format_fp16_ofm(param->output_);

  param->output_->mutable_data<float>(dims_out_new);
  // auto output = param->float_output.get();
  // param->output_ = output;
  /* args.input_data_type = fpga::DATA_TYPE_FP32;
  args.output_data_type = fpga::DATA_TYPE_FP16;
  args.image.address = output->data<float>();
  args.image.height = (uint32_t)output->dims()[2];
  args.image.width = (uint32_t)output->dims()[3];
  args.image.channels = (uint32_t)output->dims()[1] ;
  args.output.address = param->output_->mutable_data<half>();
  args.output.scale_address = param->output_->scale;
  param->output_arg = args;*/

  return true;
}

template <typename Dtype>
void PSROIPooling(const Dtype* bottom_data, const int channels,
                  const int height, const int width, const int pooled_height,
                  const int pooled_width, const Dtype* bottom_rois,
                  const int output_dim, const int group_size, Dtype* top_data,
                  int index, int nid, const Dtype Bin_size_h,
                  const Dtype Bin_size_w, const Dtype roi_start_h,
                  const Dtype roi_start_w, const int ctop, const int ph,
                  const int roi_batch_ind) {
  int pw = index;
  int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
  int wstart = floor(static_cast<Dtype>(pw) * Bin_size_w + roi_start_w);
  int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
  int wend = ceil(static_cast<Dtype>(pw + 1) * Bin_size_w + roi_start_w);

  // Add roi offsets and clip to input boundaries
  hstart = std::min(std::max(hstart, 0), height);
  hend = std::min(std::max(hend, 0), height);
  wstart = std::min(std::max(wstart, 0), width);
  wend = std::min(std::max(wend, 0), width);
  bool is_empty = (hend <= hstart) || (wend <= wstart);

  int c = (ctop * group_size + ph) * group_size + pw;

  Dtype bin_area = (hend - hstart) * (wend - wstart);
  bottom_data += (roi_batch_ind * channels + c) * height * width;
  Dtype out_sum = 0;
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      int bottom_index = h * width + w;
      out_sum += bottom_data[bottom_index];
    }
  }

  top_data[nid + index] = is_empty ? 0. : out_sum / bin_area;
}
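[Editorial note] The position-sensitive part of PSROIPooling above is the channel indexing: every output bin reads its own dedicated input channel. A worked example with hypothetical sizes:

    // group_size = 7, output_dim = 10, so the input must carry
    // 10 * 7 * 7 = 490 channels. Output bin (ctop = 3, ph = 2, pw = 5)
    // pools only input channel
    //   c = (ctop * group_size + ph) * group_size + pw
    //     = (3 * 7 + 2) * 7 + 5 = 166
    // averaged over its clipped [hstart, hend) x [wstart, wend) window;
    // empty windows produce 0.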
void convert_to_chw(float** data_in, int channel, int height, int width,
                    int num) {
  float* data_in_tmp = *data_in;
  float* data_tmp = reinterpret_cast<float*>(
      fpga::fpga_malloc(channel * height * width * sizeof(float)));  // NOLINT
  int64_t amount_per_side = width * height;
  for (int n = 0; n < num; n++) {
    for (int h = 0; h < height; h++) {
@@ -130,15 +125,15 @@ void convert_to_chw(float **data_in, int channel, int height, int width,
      }
    }
  }
  *data_in = data_tmp;
  fpga::fpga_free(data_in_tmp);
}

void convert_to_hwc(float** data_in, int channel, int height, int width,
                    int num) {
  float* data_in_tmp = *data_in;
  float* data_tmp = reinterpret_cast<float*>(
      fpga::fpga_malloc(num * channel * height * width * sizeof(float)));
  int64_t amount_per_row = width * channel;
  for (int n = 0; n < num; n++) {
    for (int c = 0; c < channel; c++) {
@@ -151,110 +146,116 @@ void convert_to_hwc(float **data_in, int channel, int height, int width,
      }
    }
  }
  *data_in = data_tmp;
  fpga::fpga_free(data_in_tmp);
}
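[Editorial note] The inner loop bodies of both converters are collapsed above. A self-contained sketch of the HWC -> CHW case for num == 1 (the layout arithmetic is standard; the actual bodies are not visible in this hunk):

    void convert_to_chw_sketch(float** data_in, int channel, int height,
                               int width) {
      float* src = *data_in;
      float* dst = reinterpret_cast<float*>(
          fpga::fpga_malloc(channel * height * width * sizeof(float)));
      for (int h = 0; h < height; h++) {
        for (int w = 0; w < width; w++) {
          for (int c = 0; c < channel; c++) {
            // HWC offset (h, w, c) -> CHW offset (c, h, w)
            dst[c * height * width + h * width + w] =
                src[(h * width + w) * channel + c];
          }
        }
      }
      *data_in = dst;
      fpga::fpga_free(src);
    }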
template <>
void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
  auto input_tensor = param.float_input.get();
  fpga::PerformBypass(param.input_arg);
  fpga::fpga_invalidate(input_tensor->data<float>(),
                        input_tensor->numel() * sizeof(float));

  auto* in = input_tensor;
  auto* rois = param.input_rois_;
  auto* out = param.output_;  // param.float_output.get();

  auto pooled_height = param.pooled_height_;
  auto pooled_width = param.pooled_width_;
  auto spatial_scale = param.spatial_scale_;
  auto output_channels = param.output_channels_;

  auto in_dims = in->dims();
  int batch_size = in_dims[0];
  int input_channels = in_dims[1];
  int height = in_dims[2];
  int width = in_dims[3];
  int rois_num = rois->dims()[0];

  auto data_nhwc = in->mutable_data<float>();
  fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width, 1);
  framework::DDim dims_out_new = framework::make_ddim(
      {rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
       (param.output_)->dims()[3]});
  (param.output_)->Resize(dims_out_new);

  float* input_data = data_nhwc;  // in->data<float>();
  // shared_ptr<float> input_data(data_nhwc);
  framework::Tensor rois_batch_id_list;
  rois_batch_id_list.Resize({rois_num});
  auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();

  PADDLE_MOBILE_ENFORCE(rois->NumLevels() > 0, "ROIS should not be empty");
  auto rois_lod = rois->lod().back();
  int rois_batch_size = rois_lod.size() - 1;
  PADDLE_MOBILE_ENFORCE(
      rois_batch_size == batch_size,
      "the rois_batch_size and input(X) batch_size should be the same.");
  int rois_num_with_lod = rois_lod[rois_batch_size];
  PADDLE_MOBILE_ENFORCE(rois_num_with_lod == rois_num,
                        "the rois_num from input and lod must be the same");
  PADDLE_MOBILE_ENFORCE(
      input_channels == output_channels * pooled_height * pooled_width,
      "the channels of input X should equal the product of "
      "output_channels x pooled_height x pooled_width");

  // calculate batch id index for each roi according to LoD
  // for (int n = 0; n < rois_batch_size; ++n) {
  //   for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
  //     rois_batch_id_data[i] = n;
  //   }
  // }
  auto output_data = out->mutable_data<float>();
  auto input_rois = rois->data<float>();

  // calculate psroipooling, parallel processing can be implemented per ROI
  for (int n = 0; n < rois_num; ++n) {
    // [start, end) interval for spatial sampling
    auto offset_input_rois = input_rois + n * 4;
    auto roi_start_w =
        static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
    auto roi_start_h =
        static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
    auto roi_end_w =
        static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
    auto roi_end_h =
        static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;

    // Force too small rois to be 1 x 1
    auto roi_height = std::max(roi_end_h - roi_start_h, 0.1f);  // avoid 0
    auto roi_width = std::max(roi_end_w - roi_start_w, 0.1f);

    // Compute bin size w and h at input feature map
    auto bin_size_h = roi_height / static_cast<float>(pooled_height);
    auto bin_size_w = roi_width / static_cast<float>(pooled_width);
    int roi_batch_ind = 0;  // rois_batch_id_data[n];
    // std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;

    for (int c = 0; c < output_channels; ++c) {
      for (int ph = 0; ph < pooled_height; ph++) {
        int index = pooled_width;
        int nid = n * output_channels * pooled_height * pooled_width +
                  c * pooled_width * pooled_height + ph * pooled_width;
        for (int idx = 0; idx < index; idx++) {
          PSROIPooling<float>(input_data, input_channels, height, width,
                              pooled_height, pooled_width, input_rois,
                              output_channels, pooled_height, output_data, idx,
                              nid, bin_size_h, bin_size_w, roi_start_h,
                              roi_start_w, c, ph, roi_batch_ind);
        }
      }
    }
  }
  fpga::fpga_free(input_data);
  fpga::image::convert_to_hwc(&output_data, output_channels, pooled_height,
                              pooled_width, rois_num);
  out->reset_data_ptr(output_data);
}

}  // namespace operators
}  // namespace paddle_mobile

#endif  // PSROI_POOL_OP
@@ -24,10 +24,8 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {

template <>
bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
  auto dims = param->input_x_->dims();
  PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
                        "data not aligned");
@@ -58,11 +56,9 @@ bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
  param->output_->mutable_data<float>(dims_out_new);

  return true;
}

template <typename T>
struct PreCalc {
  int pos1;
@@ -77,30 +73,22 @@ struct PreCalc {
template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int iy_upper, const int ix_upper,
    T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
    int roi_bin_grid_h, int roi_bin_grid_w,
    std::vector<PreCalc<T>>& pre_calc) {  // NOLINT
  int pre_calc_index = 0;
  for (int ph = 0; ph < pooled_height; ph++) {
    for (int pw = 0; pw < pooled_width; pw++) {
      for (int iy = 0; iy < iy_upper; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
                     static_cast<T>(iy + .5f) * bin_size_h /
                         static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
        for (int ix = 0; ix < ix_upper; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
                       static_cast<T>(ix + .5f) * bin_size_w /
                           static_cast<T>(roi_bin_grid_w);
          T x = xx;
          T y = yy;
@@ -128,8 +116,8 @@ void pre_calc_for_bilinear_interpolate(
            x = 0;
          }

          int y_low = static_cast<int>(y);
          int x_low = static_cast<int>(x);
          int y_high;
          int x_high;
@@ -172,22 +160,13 @@ void pre_calc_for_bilinear_interpolate(
}
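[Editorial note] The cached PreCalc entries hold the four neighbours and weights of standard bilinear interpolation; the collapsed tail of the function above computes, in essence (a sketch reconstructed from the pc.w1..w4 / pc.pos1..pos4 usage visible further down):

    // ly/lx are the fractional parts of the sample point inside the low cell.
    T ly = y - y_low, lx = x - x_low;
    T hy = static_cast<T>(1) - ly, hx = static_cast<T>(1) - lx;
    PreCalc<T> pc;
    pc.pos1 = y_low * width + x_low;    // top-left
    pc.pos2 = y_low * width + x_high;   // top-right
    pc.pos3 = y_high * width + x_low;   // bottom-left
    pc.pos4 = y_high * width + x_high;  // bottom-right
    pc.w1 = hy * hx;
    pc.w2 = hy * lx;
    pc.w3 = ly * hx;
    pc.w4 = ly * lx;
    pre_calc[pre_calc_index] = pc;
    pre_calc_index += 1;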
template <typename T>
void ROIAlignForward(const int nthreads, const T* bottom_data,
                     const T& spatial_scale, const int channels,
                     const int height, const int width, const int pooled_height,
                     const int pooled_width, const int sampling_ratio,
                     const T* bottom_rois, T* top_data) {
  int n_rois = nthreads / channels / pooled_width / pooled_height;

  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;
@@ -195,8 +174,8 @@ void ROIAlignForward(
    const T* offset_bottom_rois = bottom_rois + n * 4;
    int roi_batch_ind = 0;
    // if (roi_cols == 5) {
    //   roi_batch_ind = offset_bottom_rois[0];
    //   offset_bottom_rois++;
    // }

    // Do not using rounding; this implementation detail is critical
@@ -217,70 +196,58 @@ void ROIAlignForward(
    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceil(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    // we want to precalculate indeces and weights shared by all chanels,
    // this is the key point of optimiation
    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
                                     pooled_width * pooled_height);
    pre_calc_for_bilinear_interpolate(
        height, width, pooled_height, pooled_width, roi_bin_grid_h,
        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
        roi_bin_grid_h, roi_bin_grid_w, pre_calc);

    for (int c = 0; c < channels; c++) {
      int index_n_c = index_n + c * pooled_width * pooled_height;
      const T* offset_bottom_data =
          bottom_data + (roi_batch_ind * channels + c) * height * width;
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          int index = index_n_c + ph * pooled_width + pw;
          T output_val = 0.;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              PreCalc<T> pc = pre_calc[pre_calc_index];
              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
                            pc.w2 * offset_bottom_data[pc.pos2] +
                            pc.w3 * offset_bottom_data[pc.pos3] +
                            pc.w4 * offset_bottom_data[pc.pos4];

              pre_calc_index += 1;
            }
          }
          output_val /= count;

          top_data[index] = output_val;
        }  // for pw
      }    // for ph
    }      // for c
  }        // for n
}
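[Editorial note] A small numeric example of the sampling grid above (values are hypothetical):

    // With sampling_ratio = 2: roi_bin_grid_h = roi_bin_grid_w = 2, so each
    // pooled cell averages count = 4 bilinear samples. With sampling_ratio
    // <= 0 and an ROI spanning 13.6 x 9.2 input cells pooled to 7 x 7, the
    // grid adapts: roi_bin_grid_h = ceil(13.6 / 7) = 2 and
    // roi_bin_grid_w = ceil(9.2 / 7) = 2.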
template <>
void RoiAlignPoolKernel<FPGA, float>::Compute(
    const RoiAlignPoolParam<FPGA>& param) {
  auto input_tensor = param.float_input.get();
  fpga::PerformBypass(param.input_arg);
  fpga::fpga_invalidate(input_tensor->data<float>(),
                        input_tensor->numel() * sizeof(float));
@@ -308,23 +275,22 @@ void RoiAlignPoolKernel<FPGA, float>::Compute(
      {rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
       (param.output_)->dims()[3]});
  (param.output_)->Resize(dims_out_new);

  const int index = input_channels * pooled_height * pooled_width * rois_num;
  auto rois_data = rois->data<float>();
  auto top_data = param.output_->mutable_data<float>();
  for (int i = 0; i < index; ++i) {
    ROIAlignForward<float>(index, data_nhwc, spatial_scale, input_channels,
                           height, width, pooled_height, pooled_width,
                           sampe_ratio, rois_data, top_data);
  }

  fpga::image::convert_to_hwc(&top_data, input_channels, pooled_height,
                              pooled_width, rois_num);
  out->reset_data_ptr(top_data);
}

}  // namespace operators
}  // namespace paddle_mobile

#endif  // ROIALIGN_POOL_OP
@@ -105,7 +105,8 @@ void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
  } else {
    if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
      Tensor *out = param.Out();
      out->Resize(
          {in_x->dims()[0], out->dims()[1], out->dims()[2], out->dims()[3]});
      math::SoftmaxFuntor<CPU, float>()(in_x, out);
    }
  }
...
@@ -44,8 +44,9 @@ void Transpose2Kernel<FPGA, float>::Compute(
  // Transpose2Compute<float>(param);
  auto input = param.InputX();
  auto output = param.Out();

  output->Resize({input->dims()[0], output->dims()[1], output->dims()[2],
                  output->dims()[3]});
}

}  // namespace operators
...