Commit 52c6c1a8 authored by jameswu2014

format modify

Parent 2f507f76
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "fpga/V1/api.h"
#include <memory>
#include "fpga/V1/bias_scale.h"
#include "fpga/V1/deconv_filter.h"
#include "fpga/V1/filter.h"
......@@ -368,7 +369,8 @@ void expand_conv_arg(ConvArgs *arg) {
auto filter_pad_width_mul_channel =
args.image.pad_width * args.image.channels;
auto image_amount_per_row_multi_win_first =
image_amount_per_row * (ROW_PARALLEL_NUM * args.kernel.stride_h - args.image.pad_height);
image_amount_per_row *
(ROW_PARALLEL_NUM * args.kernel.stride_h - args.image.pad_height);
auto image_amount_per_row_multi_win =
image_amount_per_row * (ROW_PARALLEL_NUM * args.kernel.stride_h);
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <utility>
#include "common/enforce.h"
#include "fpga/common/driver.h"
......@@ -147,8 +148,6 @@ int fpga_regpoll(uint64_t reg, uint64_t val, int time) {
}
}
void memory_release(struct fpga_memory *memory) {
void *ptr = nullptr;
......@@ -160,8 +159,6 @@ void memory_release(struct fpga_memory *memory) {
}
}
uint64_t vaddr_to_paddr_driver(void *address) {
uint64_t paddr = 0;
auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(address);
......@@ -215,8 +212,8 @@ void *fpga_malloc_driver(size_t size) {
g_fpgainfo.fd_mem, FPGA_MEM_PHY_ADDR);
PADDLE_MOBILE_ENFORCE(ret != (void *)-1, "Should not be -1");
args.pVM= (void *)ret;
args.pPHY =(void *)0;
args.pVM = reinterpret_cast<void *>(ret);
args.pPHY = reinterpret_cast<void *>(0);
do_ioctl(IOCTL_MEMORY_VM2PHY, &args);
phy_addr = (uint64_t)args.pPHY;
......@@ -237,9 +234,8 @@ void fpga_free_driver(void *ptr) {
g_fpgainfo.fpga_addr2size_map.erase(iter);
munmap(ptr, size);
p_addr = vaddr_to_paddr_driver(ptr);
pos = (p_addr - g_fpgainfo.memory_info->mem_start) / FPGA_PAGE_SIZE;
// p_addr = vaddr_to_paddr_driver(ptr);
// pos = (p_addr - g_fpgainfo.memory_info->mem_start) / FPGA_PAGE_SIZE;
auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(ptr);
if (iter != g_fpgainfo.fpga_vaddr2paddr_map.end()) {
......@@ -299,7 +295,7 @@ int open_device_driver() {
g_fpgainfo.FpgaRegVirAddr =
(uint64_t *)fpga_reg_malloc(FPGA_REG_SIZE); // NOLINT
//fpga_memory_add();
// fpga_memory_add();
pl_init();
......
......@@ -53,8 +53,8 @@ struct MemoryCacheArgs {
};
struct MemoryVM2PHYArgs {
void* pVM;
void* pPHY;
void *pVM;
void *pPHY;
};
#define IOCTL_FPGA_MAGIC 'F'
......@@ -62,7 +62,6 @@ struct MemoryVM2PHYArgs {
#define IOCTL_MEMCACHE_FLUSH _IOW(IOCTL_FPGA_MAGIC, 13, struct MemoryCacheArgs)
#define IOCTL_MEMORY_VM2PHY _IOWR(IOCTL_FPGA_MAGIC, 15, struct MemoryVM2PHYArgs)
struct fpga_pe {
char type_name[MAX_TYPE_NAME_LENTH + 1];
struct pe_data_s *outer;
......
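For context, the fpga_malloc_driver hunk above maps a chunk of FPGA memory with mmap and then resolves its physical address through the IOCTL_MEMORY_VM2PHY call declared in this header. Below is a minimal sketch of that round trip; which descriptor do_ioctl targets and the mmap prot/flags are elided from the diff, so both are assumptions here:

#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <cstddef>
#include <cstdint>

struct MemoryVM2PHYArgs {
  void *pVM;
  void *pPHY;
};
#define IOCTL_FPGA_MAGIC 'F'
#define IOCTL_MEMORY_VM2PHY _IOWR(IOCTL_FPGA_MAGIC, 15, struct MemoryVM2PHYArgs)

// Map `size` bytes of FPGA memory at physical offset `phy_base`, then ask
// the driver to translate the new virtual address back to a physical one.
// Returns 0 on failure; the descriptor and mapping flags are assumptions.
uint64_t map_and_translate(int fd_mem, size_t size, off_t phy_base,
                           void **vaddr_out) {
  void *ret = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_mem,
                   phy_base);
  if (ret == MAP_FAILED) return 0;
  MemoryVM2PHYArgs args;
  args.pVM = ret;
  args.pPHY = reinterpret_cast<void *>(0);
  if (ioctl(fd_mem, IOCTL_MEMORY_VM2PHY, &args) < 0) {
    munmap(ret, size);
    return 0;
  }
  *vaddr_out = ret;
  return reinterpret_cast<uint64_t>(args.pPHY);
}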
......@@ -82,7 +82,6 @@ void RoiAlignPoolOp<DeviceType, T>::InferShape() const {
}
#endif
#ifdef ROI_PERSPECTIVE_OP
template <typename DeviceType, typename T>
void RoiPerspectiveOp<DeviceType, T>::InferShape() const {
......
......@@ -38,7 +38,6 @@ DECLARE_OPERATOR(PSRoiPool, PSRoiPoolParam, PSRoiPoolKernel);
DECLARE_OPERATOR(RoiAlignPool, RoiAlignPoolParam, RoiAlignPoolKernel);
#endif
#ifdef ROI_PERSPECTIVE_OP
DECLARE_OPERATOR(RoiPerspective, RoiPerspectiveParam, RoiPerspectiveKernel);
#endif
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -157,13 +158,15 @@ DECLARE_KERNEL(PSRoiPool, PSRoiPoolParam);
template <typename Dtype>
class RoiAlignPoolParam : public OpParam {
public:
RoiAlignPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope *scope)
RoiAlignPoolParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
pooled_width_ = OpParam::GetAttr<int>("pooled_width", attrs);
......@@ -183,7 +186,6 @@ class RoiAlignPoolParam : public OpParam {
std::shared_ptr<Tensor> float_input, float_output;
fpga::BypassArgs input_arg, output_arg;
#endif
};
DECLARE_KERNEL(RoiAlignPool, RoiAlignPoolParam);
......
......@@ -56,7 +56,7 @@ void dealign(float *src, float *dst, int input_c, int input_h, int input_w) {
}
template <>
void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
auto input = param.InputX();
auto input = const_cast<Tensor *>(param.InputX());
if (input->type() == typeid(float)) {
auto output = param.Out();
output->ShareDataWith(*input);
......@@ -74,14 +74,13 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
param.Out()->fpga_data_num * sizeof(float));
if(param.Out()->fpga_data_num != product(input->dims())){
if (param.Out()->fpga_data_num != product(input->dims())) {
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
free(data_tmp);
}
}
template class FetchKernel<FPGA, float>;
......
......@@ -74,10 +74,11 @@ void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) {
auto *output = param.Output();
auto in = input->data<float>();
auto N = input->dims()[0];
output->Resize({N, output->dims()[1], output->dims()[2], output->dims()[3]});
output->Resize(
{N, output->dims()[1], output->dims()[2], output->dims()[3]});
auto len = output->numel();
auto out = output->mutable_data<float>();
int C = input->dims()[1], H = input->dims()[2],//N = input->dims()[0],
int C = input->dims()[1], H = input->dims()[2], // N = input->dims()[0],
W = input->dims()[3];
int HW = H * W, CHW = C * H * W, WC = W * C;
......
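The shorthands in this hunk (HW = H * W, CHW = C * H * W, WC = W * C) are the flat-index strides used to walk an NHWC buffer while writing NCHW output. A minimal standalone sketch of that index mapping, with all names assumed purely for illustration:

// Copy an NHWC buffer into NCHW order using the same stride shorthands
// as the pooling kernel above. Illustrative only; names are assumed.
void nhwc_to_nchw(const float *in, float *out, int N, int C, int H, int W) {
  const int HW = H * W, CHW = C * H * W, WC = W * C;
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          // source (NHWC): n*CHW + h*WC + w*C + c
          // destination (NCHW): n*CHW + c*HW + h*W + w
          out[n * CHW + c * HW + h * W + w] =
              in[n * CHW + h * WC + w * C + c];
}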
......@@ -65,14 +65,13 @@ bool ProposalKernel<FPGA, float>::Init(ProposalParam<FPGA> *param) {
args.output.scale_address = param->float_score->scale;
param->score_arg = args;
param->score_index_= std::make_shared<Tensor>();
param->score_index_ = std::make_shared<Tensor>();
param->score_index_->mutable_data<int32_t>({input->numel()});
auto score_index = param->score_index_->data<int32_t>();
for (int i = 0; i < input->numel(); ++i){
for (int i = 0; i < input->numel(); ++i) {
score_index[i] = i;
}
return true;
}
template <typename T>
......@@ -342,9 +341,8 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
const Tensor &im_info_slice, const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
const Tensor &score_index,
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) {
const Tensor &score_index, int pre_nms_top_n, int post_nms_top_n,
float nms_thresh, float min_size, float eta) {
auto *scores_data = scores_slice.data<T>();
// Sort index
......@@ -354,7 +352,8 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
/*for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}*/
std::memcpy(index,score_index.data<int32_t>(),scores_slice.numel()*sizeof(int) );
std::memcpy(index, score_index.data<int32_t>(),
scores_slice.numel() * sizeof(int));
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
......@@ -504,7 +503,7 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
auto score_index = *(param.score_index_.get());
int pre_nms_top_n = param.pre_nms_topn_;
int post_nms_top_n = 100;//param.post_nms_topn_;
int post_nms_top_n = 100; // param.post_nms_topn_;
float nms_thresh = param.nms_thresh_;
float min_size = param.min_size_;
float eta = param.eta_;
......@@ -541,8 +540,8 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair = ProposalForOneImage<float>(
im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice,score_index,
pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta);
im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice,
score_index, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta);
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
......
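The score_index_ tensor set up in Init above is just the identity permutation 0..N-1, built once so that each call to ProposalForOneImage can memcpy it instead of refilling the index array (the commented-out loop in the hunk). A minimal sketch of the pattern with plain vectors in place of Tensor; the partial sort down to pre_nms_top_n is an assumption based on the compare lambda, since the sort call itself is elided:

#include <algorithm>
#include <cstring>
#include <numeric>
#include <vector>

// Init-time: build the identity index once.
std::vector<int> make_score_index(int n) {
  std::vector<int> index(n);
  std::iota(index.begin(), index.end(), 0);
  return index;
}

// Per image: copy the precomputed index, then keep the pre_nms_top_n
// highest-scoring entries in descending score order (assumed sort).
std::vector<int> top_score_indices(const std::vector<int> &score_index,
                                   const float *scores, int pre_nms_top_n) {
  std::vector<int> index(score_index.size());
  std::memcpy(index.data(), score_index.data(),
              score_index.size() * sizeof(int));
  auto compare = [scores](int i, int j) { return scores[i] > scores[j]; };
  size_t k = std::min<size_t>(pre_nms_top_n, index.size());
  std::partial_sort(index.begin(), index.begin() + k, index.end(), compare);
  index.resize(k);
  return index;
}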
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef PSROI_POOL_OP
#include <cmath>
#include <memory>
#include <vector>
#include "operators/kernel/detection_kernel.h"
......@@ -72,22 +73,17 @@ bool PSRoiPoolKernel<FPGA, float>::Init(PSRoiPoolParam<FPGA>* param) {
}
template <typename Dtype>
void PSROIPooling(
const Dtype* bottom_data, const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const Dtype* bottom_rois, const int output_dim,
const int group_size, Dtype* top_data,
int index, int nid,
const Dtype Bin_size_h,
const Dtype Bin_size_w,
const Dtype roi_start_h,
const Dtype roi_start_w,
const int ctop, const int ph, const int roi_batch_ind)
{
void PSROIPooling(const Dtype* bottom_data, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const Dtype* bottom_rois,
const int output_dim, const int group_size, Dtype* top_data,
int index, int nid, const Dtype Bin_size_h,
const Dtype Bin_size_w, const Dtype roi_start_h,
const Dtype roi_start_w, const int ctop, const int ph,
const int roi_batch_ind) {
int pw = index;
int hstart = floor(static_cast<Dtype>(ph) * Bin_size_h + roi_start_h);
int wstart = floor(static_cast<Dtype>(pw)* Bin_size_w + roi_start_w);
int wstart = floor(static_cast<Dtype>(pw) * Bin_size_w + roi_start_w);
int hend = ceil(static_cast<Dtype>(ph + 1) * Bin_size_h + roi_start_h);
int wend = ceil(static_cast<Dtype>(pw + 1) * Bin_size_w + roi_start_w);
......@@ -98,9 +94,9 @@ const int ctop, const int ph, const int roi_batch_ind)
wend = std::min(std::max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c = (ctop*group_size + ph)*group_size + pw;
int c = (ctop * group_size + ph) * group_size + pw;
Dtype bin_area = (hend - hstart)*(wend - wstart);
Dtype bin_area = (hend - hstart) * (wend - wstart);
bottom_data += (roi_batch_ind * channels + c) * height * width;
Dtype out_sum = 0;
for (int h = hstart; h < hend; ++h) {
......@@ -110,15 +106,14 @@ const int ctop, const int ph, const int roi_batch_ind)
}
}
top_data[nid + index] = is_empty? 0. : out_sum/bin_area;
top_data[nid + index] = is_empty ? 0. : out_sum / bin_area;
}
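Each PSROIPooling call above produces a single output value: the bin [hstart, hend) x [wstart, wend) derived from (ph, pw) is clipped to the feature map, the position-sensitive channel c = (ctop * group_size + ph) * group_size + pw is selected, and the bin is mean-pooled. A standalone restatement of that bin average, assuming `data` already points at channel c of a CHW feature map:

#include <algorithm>
#include <cmath>

// Mean-pool one position-sensitive bin; a restatement of the PSROIPooling
// body above. `data` is assumed to point at the selected channel.
float psroi_bin_average(const float *data, int height, int width,
                        float bin_size_h, float bin_size_w, float roi_start_h,
                        float roi_start_w, int ph, int pw) {
  int hstart = std::floor(ph * bin_size_h + roi_start_h);
  int wstart = std::floor(pw * bin_size_w + roi_start_w);
  int hend = std::ceil((ph + 1) * bin_size_h + roi_start_h);
  int wend = std::ceil((pw + 1) * bin_size_w + roi_start_w);
  // Clip the bin to the feature map, as the kernel does.
  hstart = std::min(std::max(hstart, 0), height);
  hend = std::min(std::max(hend, 0), height);
  wstart = std::min(std::max(wstart, 0), width);
  wend = std::min(std::max(wend, 0), width);
  if (hend <= hstart || wend <= wstart) return 0.f;  // empty bin
  float sum = 0.f;
  for (int h = hstart; h < hend; ++h)
    for (int w = wstart; w < wend; ++w) sum += data[h * width + w];
  return sum / ((hend - hstart) * (wend - wstart));
}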
void convert_to_chw(float **data_in, int channel, int height, int width,
void convert_to_chw(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float *data_tmp =
(float *)fpga::fpga_malloc(channel * height * width * sizeof(float)); // NOLINT
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(channel * height * width * sizeof(float))); // NOLINT
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
......@@ -134,10 +129,10 @@ void convert_to_chw(float **data_in, int channel, int height, int width,
fpga::fpga_free(data_in_tmp);
}
void convert_to_hwc(float **data_in, int channel, int height, int width,
void convert_to_hwc(float** data_in, int channel, int height, int width,
int num) {
float* data_in_tmp = *data_in;
float *data_tmp = reinterpret_cast<float *>(
float* data_tmp = reinterpret_cast<float*>(
fpga::fpga_malloc(num * channel * height * width * sizeof(float)));
int64_t amount_per_row = width * channel;
for (int n = 0; n < num; n++) {
......@@ -155,7 +150,6 @@ void convert_to_hwc(float **data_in, int channel, int height, int width,
fpga::fpga_free(data_in_tmp);
}
template <>
void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
......@@ -180,13 +174,14 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
int rois_num = rois->dims()[0];
auto data_nhwc = in->mutable_data<float>();
convert_to_chw(&data_nhwc, input_channels, height, width, 1);
fpga::image::convert_to_chw(&data_nhwc, input_channels, height, width, 1);
framework::DDim dims_out_new = framework::make_ddim(
{rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])),
(param.output_)->dims()[3]});
(param.output_)->Resize(dims_out_new);
const float* input_data = data_nhwc; // in->data<float>();
float* input_data = data_nhwc; // in->data<float>();
// shared_ptr<float> input_data(data_nhwc);
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
auto rois_batch_id_data = rois_batch_id_list.mutable_data<int>();
......@@ -208,9 +203,9 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
//for (int n = 0; n < rois_batch_size; ++n) {
//for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
//rois_batch_id_data[i] = n;
// for (int n = 0; n < rois_batch_size; ++n) {
// for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
// rois_batch_id_data[i] = n;
// }
//}
auto output_data = out->mutable_data<float>();
......@@ -220,10 +215,14 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
for (int n = 0; n < rois_num; ++n) {
// [start, end) interval for spatial sampling
auto offset_input_rois = input_rois + n * 4;
auto roi_start_w = static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
auto roi_start_h = static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
auto roi_end_w = static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
auto roi_end_h = static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;
auto roi_start_w =
static_cast<float>(round(offset_input_rois[0])) * spatial_scale;
auto roi_start_h =
static_cast<float>(round(offset_input_rois[1])) * spatial_scale;
auto roi_end_w =
static_cast<float>(round(offset_input_rois[2]) + 1.) * spatial_scale;
auto roi_end_h =
static_cast<float>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
auto roi_height = std::max(roi_end_h - roi_start_h, 0.1f); // avoid 0
......@@ -233,22 +232,25 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
auto bin_size_h = roi_height / static_cast<float>(pooled_height);
auto bin_size_w = roi_width / static_cast<float>(pooled_width);
int roi_batch_ind = 0;//rois_batch_id_data[n];
//std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;
for(int c = 0; c < output_channels; ++c){
for(int ph = 0; ph < pooled_height; ph++){
int roi_batch_ind = 0; // rois_batch_id_data[n];
// std::cout << "roi_batch_ind: " << roi_batch_ind << std::endl;
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < pooled_height; ph++) {
int index = pooled_width;
int nid = n * output_channels * pooled_height * pooled_width + c * pooled_width * pooled_height + ph * pooled_width;
for(int idx = 0; idx < index; idx++){
PSROIPooling<float>(input_data,input_channels,height,width,pooled_height,pooled_width,
input_rois,output_channels,pooled_height,output_data, idx, nid, bin_size_h, bin_size_w, roi_start_h, roi_start_w, c, ph, roi_batch_ind);
int nid = n * output_channels * pooled_height * pooled_width +
c * pooled_width * pooled_height + ph * pooled_width;
for (int idx = 0; idx < index; idx++) {
PSROIPooling<float>(input_data, input_channels, height, width,
pooled_height, pooled_width, input_rois,
output_channels, pooled_height, output_data, idx,
nid, bin_size_h, bin_size_w, roi_start_h,
roi_start_w, c, ph, roi_batch_ind);
}
}
}
}
convert_to_hwc(&output_data, output_channels, pooled_height,
fpga::fpga_free(input_data);
fpga::image::convert_to_hwc(&output_data, output_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(output_data);
}
......@@ -257,4 +259,3 @@ void PSRoiPoolKernel<FPGA, float>::Compute(const PSRoiPoolParam<FPGA>& param) {
} // namespace paddle_mobile
#endif // PSROI_POOL_OP
......@@ -24,10 +24,8 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
auto dims = param->input_x_->dims();
PADDLE_MOBILE_ENFORCE(dims[1] * dims[3] % IMAGE_ALIGNMENT == 0,
"data not aligned");
......@@ -58,11 +56,9 @@ bool RoiAlignPoolKernel<FPGA, float>::Init(RoiAlignPoolParam<FPGA>* param) {
param->output_->mutable_data<float>(dims_out_new);
return true;
}
template <typename T>
struct PreCalc {
int pos1;
......@@ -77,19 +73,11 @@ struct PreCalc {
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
const int height, const int width, const int pooled_height,
const int pooled_width, const int iy_upper, const int ix_upper,
T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
int roi_bin_grid_h, int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) { // NOLINT
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
......@@ -128,8 +116,8 @@ void pre_calc_for_bilinear_interpolate(
x = 0;
}
int y_low = (int)y;
int x_low = (int)x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
......@@ -172,22 +160,13 @@ void pre_calc_for_bilinear_interpolate(
}
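Only pos1 of PreCalc and the truncation to y_low/x_low are visible in this hunk; the remaining fields and the weight formulas below are assumptions based on the standard ROIAlign bilinear interpolation, where each sampling point (y, x) contributes through its four integer neighbors:

#include <algorithm>

// Bilinear positions and weights for one sampling point (y, x) on a
// height x width grid. The four pos/w pairs mirror what PreCalc presumably
// stores; the fields beyond pos1 are elided above, so this is an assumption.
template <typename T>
void bilinear_weights(T y, T x, int height, int width, int pos[4], T w[4]) {
  int y_low = static_cast<int>(y);
  int x_low = static_cast<int>(x);
  int y_high = std::min(y_low + 1, height - 1);
  int x_high = std::min(x_low + 1, width - 1);
  T ly = y - y_low, lx = x - x_low;  // fractional offsets inside the cell
  T hy = 1 - ly, hx = 1 - lx;
  pos[0] = y_low * width + x_low;
  pos[1] = y_low * width + x_high;
  pos[2] = y_high * width + x_low;
  pos[3] = y_high * width + x_high;
  w[0] = hy * hx;  // weight toward the top-left neighbor
  w[1] = hy * lx;
  w[2] = ly * hx;
  w[3] = ly * lx;
}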
template <typename T>
void ROIAlignForward(
const int nthreads,
const T* bottom_data,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* bottom_rois,
T* top_data) {
void ROIAlignForward(const int nthreads, const T* bottom_data,
const T& spatial_scale, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int sampling_ratio,
const T* bottom_rois, T* top_data) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
......@@ -227,23 +206,12 @@ void ROIAlignForward(
// we want to precalculate indices and weights shared by all channels,
// this is the key point of optimization
std::vector<PreCalc<T>> pre_calc(
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
pooled_width * pooled_height);
pre_calc_for_bilinear_interpolate(
height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_start_h,
roi_start_w,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
pre_calc);
height, width, pooled_height, pooled_width, roi_bin_grid_h,
roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
roi_bin_grid_h, roi_bin_grid_w, pre_calc);
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
......@@ -276,10 +244,9 @@ void ROIAlignForward(
} // for n
}
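The inner accumulation of ROIAlignForward is elided above; assuming the standard formulation, each (c, ph, pw) output cell averages roi_bin_grid_h * roi_bin_grid_w pre-calculated bilinear samples. A hedged sketch of that loop, with the PreCalc fields beyond pos1 assumed as in the weight sketch above:

// Assumed inner loop of ROIAlignForward: average the pre-calculated
// bilinear samples for one output cell. `count` is
// roi_bin_grid_h * roi_bin_grid_w; PreCalcSketch mirrors PreCalc.
template <typename T>
struct PreCalcSketch {
  int pos1, pos2, pos3, pos4;
  T w1, w2, w3, w4;
};

template <typename T>
T average_bin(const T *channel_data, const PreCalcSketch<T> *pc, int count) {
  T out = 0;
  for (int i = 0; i < count; ++i) {
    const PreCalcSketch<T> &p = pc[i];
    out += p.w1 * channel_data[p.pos1] + p.w2 * channel_data[p.pos2] +
           p.w3 * channel_data[p.pos3] + p.w4 * channel_data[p.pos4];
  }
  return out / count;
}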
template <>
void RoiAlignPoolKernel<FPGA, float>::Compute(const RoiAlignPoolParam<FPGA>& param) {
void RoiAlignPoolKernel<FPGA, float>::Compute(
const RoiAlignPoolParam<FPGA>& param) {
auto input_tensor = param.float_input.get();
fpga::PerformBypass(param.input_arg);
fpga::fpga_invalidate(input_tensor->data<float>(),
......@@ -312,19 +279,18 @@ void RoiAlignPoolKernel<FPGA, float>::Compute(const RoiAlignPoolParam<FPGA>& par
const int index = input_channels * pooled_height * pooled_width * rois_num;
auto rois_data = rois->data<float>();
auto top_data = param.output_->mutable_data<float>();
for (int i = 0; i < index; ++i){
ROIAlignForward<float>( index,data_nhwc,spatial_scale,input_channels,height,width,
pooled_height,pooled_width,sampe_ratio,rois_data,top_data);
for (int i = 0; i < index; ++i) {
ROIAlignForward<float>(index, data_nhwc, spatial_scale, input_channels,
height, width, pooled_height, pooled_width,
sampe_ratio, rois_data, top_data);
}
fpga::image::convert_to_hwc(&top_data, input_channels, pooled_height,
pooled_width, rois_num);
out->reset_data_ptr(top_data);
}
} // namespace operators
} // namespace paddle_mobile
#endif // ROIALIGN_POOL_OP
......@@ -105,7 +105,8 @@ void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
} else {
if (param.FpgaArgs().output.activation.activation_type != fpga::SOFTMAX) {
Tensor *out = param.Out();
out->Resize({in_x->dims()[0], out->dims()[1], out->dims()[2], out->dims()[3]});
out->Resize(
{in_x->dims()[0], out->dims()[1], out->dims()[2], out->dims()[3]});
math::SoftmaxFuntor<CPU, float>()(in_x, out);
}
}
......
......@@ -45,7 +45,8 @@ void Transpose2Kernel<FPGA, float>::Compute(
auto input = param.InputX();
auto output = param.Out();
output->Resize({input->dims()[0], output->dims()[1], output->dims()[2], output->dims()[3]});
output->Resize({input->dims()[0], output->dims()[1], output->dims()[2],
output->dims()[3]});
}
} // namespace operators
......