diff --git a/mobile/src/fpga/V2/api.cpp b/mobile/src/fpga/V2/api.cpp index f1d19364f89cfa7118397ab7f33db66c3a78785d..f39d012e08c124feacbd72fa2879e60b352c2785 100644 --- a/mobile/src/fpga/V2/api.cpp +++ b/mobile/src/fpga/V2/api.cpp @@ -359,7 +359,7 @@ void expand_conv_arg(ConvArgs *arg) { if (((res_win % 2) != 0) && (res_win != 1)) { res_win = res_win - 1; } - PADDLE_MOBILE_ENFORCE(res_win >= 2, "window too bigger than fpga volume"); + // PADDLE_MOBILE_ENFORCE(res_win >= 2, "window too bigger than fpga volume"); res_fit = res_win; auto block_num = (output_width + res_fit - 1) / res_fit; @@ -885,7 +885,7 @@ void fill_dwconv_arg(struct DWconvArgs *arg, framework::Tensor *input, int padding_h, int padding_w, float *bias_ptr) { auto filter_ptr = filter->data(); auto input_ptr = input->data(); - auto output_ptr = out->mutable_data(); + auto output_ptr = out->data(); arg->sub_conv_num = 1; arg->relu_enabled = relu_enabled; // arg->output.activation.activation_type = activation_enable; diff --git a/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp index ecc2577bd6ba9f8f21d4cccb94bdc27466b4a5d1..f19eea9915971a45cb5212345cbdf3cb413d539a 100644 --- a/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp @@ -30,16 +30,12 @@ bool ProposalKernel::Init(ProposalParam *param) { int64_t batch = param->scores_->dims()[0]; auto total = post_nms_top_n * batch; param->rpn_rois_->mutable_data({total, 4}); - param->rpn_probs_->mutable_data({total, 1}); + param->rpn_probs_->mutable_data({total, 1}); param->float_bbox = std::make_shared(); param->float_bbox->Resize(param->bbox_deltas_->dims()); param->float_bbox->init(type_id().hash_code()); fpga::format_fp32_ofm(param->float_bbox.get()); - param->float_score = std::make_shared(); - param->float_score->Resize(param->scores_->dims()); - param->float_score->init(type_id().hash_code()); - fpga::format_fp32_ofm(param->float_score.get()); auto input = param->scores_; param->score_index_ = std::make_shared(); @@ -87,8 +83,8 @@ void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) { } template -static inline void BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas, - Tensor *variances, Tensor *proposals) { +static inline void BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas, + Tensor *proposals) { T *proposals_data = proposals->mutable_data(); int64_t row = all_anchors->dims()[0]; @@ -96,10 +92,6 @@ static inline void BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas, auto *bbox_deltas_data = bbox_deltas->data(); auto *anchor_data = all_anchors->data(); - const T *variances_data = nullptr; - if (variances) { - variances_data = variances->data(); - } for (int64_t i = 0; i < row; ++i) { T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0; @@ -244,10 +236,10 @@ static inline Tensor NMS(Tensor *bbox, Tensor *scores, T nms_threshold, // 4: [xmin ymin xmax ymax] int64_t box_size = bbox->dims()[1]; - std::vector scores_data(num_boxes); - std::copy_n(scores->data(), num_boxes, scores_data.begin()); - std::vector> sorted_indices = - GetSortedScoreIndex(scores_data); + std::vector scores_data(num_boxes); + std::copy_n(scores->data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices = + GetSortedScoreIndex(scores_data); std::vector selected_indices; int selected_num = 0; @@ -284,8 +276,7 @@ std::pair ProposalForOneImage( const Tensor &scores_slice, // [N, 1] const Tensor &score_index, int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, float eta) { - auto *scores_data = scores_slice.data(); - + auto *scores_data = scores_slice.data(); // Sort index Tensor index_t; index_t.Resize({scores_slice.numel()}); @@ -306,17 +297,17 @@ std::pair ProposalForOneImage( } Tensor scores_sel, bbox_sel, anchor_sel, var_sel; - scores_sel.mutable_data({index_t.numel(), 1}); + scores_sel.mutable_data({index_t.numel(), 1}); bbox_sel.mutable_data({index_t.numel(), 4}); anchor_sel.mutable_data({index_t.numel(), 4}); var_sel.mutable_data({index_t.numel(), 4}); - CPUGather(scores_slice, index_t, &scores_sel); + CPUGather(scores_slice, index_t, &scores_sel); CPUGather(bbox_deltas_slice, index_t, &bbox_sel); CPUGather(anchors, index_t, &anchor_sel); Tensor proposals; proposals.mutable_data({index_t.numel(), 4}); - BoxCoder(&anchor_sel, &bbox_sel, nullptr, &proposals); + BoxCoder(&anchor_sel, &bbox_sel, &proposals); ClipTiledBoxes(im_info_slice, &proposals); @@ -325,10 +316,10 @@ std::pair ProposalForOneImage( Tensor scores_filter; bbox_sel.mutable_data({keep.numel(), 4}); - scores_filter.mutable_data({keep.numel(), 1}); + scores_filter.mutable_data({keep.numel(), 1}); CPUGather(proposals, keep, &bbox_sel); - CPUGather(scores_sel, keep, &scores_filter); + CPUGather(scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { return std::make_pair(bbox_sel, scores_filter); } @@ -341,10 +332,10 @@ std::pair ProposalForOneImage( } proposals.mutable_data({keep_nms.numel(), 4}); // original - scores_sel.mutable_data({keep_nms.numel(), 1}); // original + scores_sel.mutable_data({keep_nms.numel(), 1}); // original CPUGather(bbox_sel, keep_nms, &proposals); - CPUGather(scores_filter, keep_nms, &scores_sel); + CPUGather(scores_filter, keep_nms, &scores_sel); return std::make_pair(proposals, scores_sel); } @@ -368,69 +359,41 @@ void ProposalKernel::Compute(const ProposalParam ¶m) { bbox_height = (uint32_t)(input_bbox->dims()[2]); bbox_width = (uint32_t)(input_bbox->dims()[3]); - std::shared_ptr score_tmp = std::make_shared(); - score_tmp->Resize(param.scores_->dims()); - score_tmp->mutable_data(); - - std::shared_ptr bbox_tmp = std::make_shared(); - bbox_tmp->Resize(param.bbox_deltas_->dims()); - bbox_tmp->mutable_data(); - - auto score_tmp_data = score_tmp->data(); - auto bbox_tmp_data = bbox_tmp->data(); int64_t amount_per_side = score_width * score_height; - int idx = 0; + int alignedCW = fpga::align_to_x(score_width * score_channels, IMAGE_ALIGNMENT); int unalignedCW = score_width * score_channels; fpga::fpga_invalidate(input_score_data, score_height * alignedCW * sizeof(int8_t)); - for (int h = 0; h < score_height; h++) { - for (int w = 0; w < score_width; w++) { - for (int c = 0; c < score_channels; c++) { - if (alignedCW == unalignedCW) { - *(score_tmp_data + c * amount_per_side + score_width * h + w) = - (*(input_score_data++)); - } else { - idx = h * alignedCW + w * score_channels + c; - *(score_tmp_data + c * amount_per_side + score_width * h + w) = - input_score_data[idx]; - } + + Tensor score_tensor = *input_score; + for(int h = 0; h < score_height; h++){ + for(int w = 0; w < score_width; w++){ + for (int c = 0; c < score_channels; ++c) { + int dstidx = h*unalignedCW + w*score_channels + c; + int srcidx = h*alignedCW + w*score_channels + c; + score_tensor.data()[dstidx] = input_score_data[srcidx]; } } } + amount_per_side = bbox_width * bbox_height; alignedCW = fpga::align_to_x(bbox_width * bbox_channels, IMAGE_ALIGNMENT); unalignedCW = bbox_width * bbox_channels; fpga::fpga_invalidate(input_bbox_data, bbox_height * alignedCW * sizeof(int8_t)); - for (int h = 0; h < bbox_height; h++) { - for (int w = 0; w < bbox_width; w++) { - for (int c = 0; c < bbox_channels; c++) { - if (alignedCW == unalignedCW) { - *(bbox_tmp_data + c * amount_per_side + bbox_width * h + w) = - (*(input_bbox_data++)); - } else { - idx = h * alignedCW + w * bbox_channels + c; - *(bbox_tmp_data + c * amount_per_side + bbox_width * h + w) = - input_bbox_data[idx]; - } - } - } - } - auto score_tensor = param.float_score.get(); - for (int i = 0; i < score_height * score_width * score_channels; i++) { - score_tensor->data()[i] = - score_tmp_data[i] / 127.0 * input_score->scale[0]; - } auto bbox_tensor = param.float_bbox.get(); - for (int i = 0; i < bbox_height * bbox_width * bbox_channels; i++) { - bbox_tensor->data()[i] = - bbox_tmp_data[i] / 127.0 * input_bbox->scale[0]; + for(int h = 0; h < bbox_height; h++){ + for(int w = 0; w < bbox_width; w++){ + for (int c = 0; c < bbox_channels; ++c) { + int dstidx = h*unalignedCW + w*bbox_channels + c; + int srcidx = h*alignedCW + w*bbox_channels + c; + bbox_tensor->data()[dstidx] = ((int)(input_bbox_data[srcidx])) / 127.0 * input_bbox->scale[0]; + } + } } - auto *scores = param.float_score.get(); - auto *bbox_deltas = param.float_bbox.get(); auto *im_info = param.im_info_; auto anchors = *param.anchors_; auto variances = *param.variances_; @@ -447,37 +410,23 @@ void ProposalKernel::Compute(const ProposalParam ¶m) { float min_size = param.min_size_; float eta = param.eta_; - auto &scores_dim = scores->dims(); - int64_t num = scores_dim[0]; - int64_t c_score = scores_dim[1]; - int64_t h_score = scores_dim[2]; - int64_t w_score = scores_dim[3]; - - auto &bbox_dim = bbox_deltas->dims(); - int64_t c_bbox = bbox_dim[1]; - int64_t h_bbox = bbox_dim[2]; - int64_t w_bbox = bbox_dim[3]; - - // - rpn_rois->mutable_data({bbox_deltas->numel(), 4}); - rpn_roi_probs->mutable_data({scores->numel(), 1}); - + rpn_rois->mutable_data({bbox_tensor->numel()/4, 4}); + rpn_roi_probs->mutable_data({input_score->numel()/4, 1}); framework::LoD lod; lod.resize(1); auto &lod0 = lod[0]; lod0.push_back(0); - anchors.Resize({anchors.numel(), 4}); - variances.Resize({variances.numel(), 4}); + anchors.Resize({anchors.numel()/4, 4}); + variances.Resize({variances.numel()/4, 4}); int64_t num_proposals = 0; - for (int64_t i = 0; i < num; ++i) { + for (int64_t i = 0; i < score_n; ++i) { Tensor im_info_slice = im_info->Slice(i, i + 1); Tensor bbox_deltas_slice = (*bbox_tensor).Slice(i, i + 1); - Tensor scores_slice = (*score_tensor).Slice(i, i + 1); - - bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox, 4}); - scores_slice.Resize({h_score * w_score * c_score, 1}); + Tensor scores_slice = score_tensor.Slice(i, i + 1); + bbox_deltas_slice.Resize({bbox_height * bbox_width * bbox_channels / 4, 4}); + scores_slice.Resize({score_height * score_width * score_channels, 1}); std::pair tensor_pair = ProposalForOneImage( im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice, score_index, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta); diff --git a/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp index b8b5202e27369a74430aa130db68501ff6891eec..55816742e7ca2670ed2e2896430dff68afe5d104 100644 --- a/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp @@ -44,14 +44,14 @@ bool PSRoiPoolKernel::Init(PSRoiPoolParam* param) { } template -void PSROIPoolingForward(const Dtype* bottom_data, const int height, +void PSROIPoolingForward(const int8_t* bottom_data, const int height, const int width, const int input_channel, Dtype* top_data, const int pooled_height, const int pooled_width, const int output_channel, const Dtype* bottom_rois, const Dtype Bin_size_h, const Dtype Bin_size_w, const Dtype roi_start_h, const Dtype roi_start_w, const int pw, const int ph, - const int roi_batch_ind) { + float scale, const int roi_batch_ind) { int hstart = floor(static_cast(ph) * Bin_size_h + roi_start_h); int wstart = floor(static_cast(pw) * Bin_size_w + roi_start_w); int hend = ceil(static_cast(ph + 1) * Bin_size_h + roi_start_h); @@ -64,11 +64,12 @@ void PSROIPoolingForward(const Dtype* bottom_data, const int height, wend = std::min(std::max(wend, 0), width); bool is_empty = (hend <= hstart) || (wend <= wstart); - float sum_pixels_c[output_channel] = {0}; - float pixels_c[output_channel] = {0}; + float avg_pixels_c[output_channel] = {0}; + int sum_pixels_c[output_channel] = {0}; + int8_t pixels_c[output_channel] = {0}; if (!is_empty) { Dtype bin_area = (hend - hstart) * (wend - wstart); - float rec_bin_area = 1 / bin_area; + float scale_fuse = scale / bin_area; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -86,27 +87,21 @@ void PSROIPoolingForward(const Dtype* bottom_data, const int height, } } for (int output_c = 0; output_c < output_channel; output_c++) { - sum_pixels_c[output_c] *= rec_bin_area; + avg_pixels_c[output_c] = sum_pixels_c[output_c] * scale_fuse; } } int output_index_base = (ph * pooled_width + pw) * output_channel; top_data += output_index_base; - memcpy(top_data, sum_pixels_c, output_channel * 4); + memcpy(top_data, avg_pixels_c, output_channel * 4); } template <> void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { auto input_tensor = param.input_x_; auto input_data = input_tensor->data(); - auto Si = input_tensor->scale[0]; - auto float_input_tensor = param.float_input.get(); - auto float_input_data = float_input_tensor->data(); - for (int i = 0; i < float_input_tensor->numel(); i++) { - float_input_data[i] = input_data[i] / 127.0 * Si; - } - - auto* in = float_input_tensor; + auto scale = input_tensor->scale[0] / 127.0; + fpga::fpga_invalidate(input_data, input_tensor->numel() * sizeof(int8_t)); auto* rois = param.input_rois_; auto* out = param.output_; @@ -115,22 +110,19 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { auto spatial_scale = param.spatial_scale_; auto output_channels = param.output_channels_; - auto in_dims = in->dims(); + auto in_dims = input_tensor->dims(); int batch_size = in_dims[0]; int input_channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; - auto data_nhwc = in->mutable_data(); - framework::DDim dims_out_new = framework::make_ddim( {rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])), (param.output_)->dims()[3]}); (param.output_)->Resize(dims_out_new); - const float* input_data_tmp = data_nhwc; // in->data(); framework::Tensor rois_batch_id_list; rois_batch_id_list.Resize({rois_num}); auto rois_batch_id_data = rois_batch_id_list.mutable_data(); @@ -151,12 +143,7 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { "the channels of input X should equal the product of " "output_channels x pooled_height x pooled_width"); - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } + auto output_data = out->mutable_data(); auto input_rois = rois->data(); @@ -187,10 +174,10 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { for (int ph = 0; ph < pooled_height; ph++) { for (int pw = 0; pw < pooled_width; pw++) { PSROIPoolingForward( - input_data_tmp, height, width, input_channels, offset_output_data, + input_data, height, width, input_channels, offset_output_data, pooled_height, pooled_width, output_channels, input_rois, - bin_size_h, bin_size_w, roi_start_h, roi_start_w, pw, ph, - roi_batch_ind); + bin_size_h, bin_size_w, roi_start_h, roi_start_w, pw, ph, + scale, roi_batch_ind); } } } diff --git a/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp index ebaf3759400c60c9ecf36467d0eeb7adad140f46..ee10fff5a120792894bd951eb709fd753a076260 100644 --- a/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp @@ -25,6 +25,7 @@ bool Reshape2Kernel::Init(Reshape2Param *param) { auto input = const_cast(param->InputX()); auto output = param->Out(); auto shape = param->Shape(); + output->scale[0] = input->scale[0]; auto num_in = framework::product(input->dims()); auto num_shape = framework::product(framework::make_ddim(shape)); @@ -92,6 +93,29 @@ void reshape(LoDTensor *input, LoDTensor *output) { fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(int8_t)); } +static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){ + int input_dims_size = input_dims.size(); + int output_dims_size = output_dims.size(); + bool dims_flag2 = true; + auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims; + int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size; + for(int i = 0; i < temp_dims.size(); ++i){ + if(i < short_dims){ + if(input_dims[i] != output_dims[i]){ + dims_flag2 = false; + break; + } + } + else{ + if(temp_dims[i] != 1){ + dims_flag2 = false; + break; + } + } + } + return dims_flag2; + } + template <> void Reshape2Kernel::Compute(const Reshape2Param ¶m) { auto input = const_cast(param.InputX()); @@ -109,7 +133,17 @@ void Reshape2Kernel::Compute(const Reshape2Param ¶m) { } } output->Resize(framework::make_ddim(shape)); - if (output->dims() == input->dims()) { + auto input_dims = input->dims(); + auto output_dims = output->dims(); + + bool dims_flags = input_dims == output_dims; + bool dims_flag2 = true; + + if(!dims_flags){ + dims_flag2 = reshape2_judge(input_dims, output_dims); + } + + if (dims_flags || dims_flag2) { DLOG << "No need to reshape"; output->ShareDataWith(*input); framework::LoD lod = input->lod(); diff --git a/mobile/src/operators/kernel/fpga/V2/sigmoid_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/sigmoid_kernel.cpp index 194fd5a30565b866ca702b296981d0b8302a1c16..2235365bc3fe91693d84cf844fbdd035865a766c 100644 --- a/mobile/src/operators/kernel/fpga/V2/sigmoid_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/sigmoid_kernel.cpp @@ -21,12 +21,11 @@ namespace operators { template <> bool SigmoidKernel::Init(SigmoidParam *param) { - auto input = const_cast(param->InputX()); - auto input_ptr = input->data(); paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::SIGMOID; - int16_t leaky_relu_negative_slope = - fpga::fp32_2_fp16(input->scale[0] / 127.0); + int16_t leaky_relu_negative_slope = 0; + auto input = const_cast(param->InputX()); + auto input_ptr = input->data(); auto out = param->Out(); fpga::format_ofm(out); diff --git a/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp index b7615a8891b8292dd4d65c15955a0ee640c2f770..843f249c683717789999db733a04b3da0198bdcb 100755 --- a/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp @@ -81,6 +81,7 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { auto w = 1; auto c = 1; if (dims.size() == 4) { + n = dims[0]; h = dims[1]; w = dims[2]; c = dims[3]; @@ -90,6 +91,7 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { h = 1; } } else if (dims.size() == 2) { + n = dims[0]; c = dims[1]; } if ((c == 2) && (in_x->type() == type_id())) {