From a7cd737fcc05c00a94f3f6044c52d1666ea9cabf Mon Sep 17 00:00:00 2001 From: qnqinan Date: Thu, 21 Mar 2019 16:06:04 +0800 Subject: [PATCH] fix the dealign bug in fetch op FPGA track --- src/operators/kernel/fpga/V1/fetch_kernel.cpp | 45 ++++++++++++++----- src/operators/op_param.h | 1 + 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/operators/kernel/fpga/V1/fetch_kernel.cpp b/src/operators/kernel/fpga/V1/fetch_kernel.cpp index 6fbd81ae7f..a0bd58160a 100644 --- a/src/operators/kernel/fpga/V1/fetch_kernel.cpp +++ b/src/operators/kernel/fpga/V1/fetch_kernel.cpp @@ -27,7 +27,23 @@ bool FetchKernel::Init(FetchParam *param) { output->init(typeid(float)); output->Resize(input->dims()); fpga::format_fp32_ofm(output); - + int outC = 1; + int outH = 1; + int outW = 1; + if(output->dims().size() == 4){ + outC = output->dims()[1]; + outH = output->dims()[2]; + outW = output->dims()[3]; + }else{//2 + outC = output->dims()[1]; + } + int unalignedCW = outC * outW; + int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT); + if(alignedCW != unalignedCW){ + param->aligned_out.Resize(input->dims()); + param->aligned_out.mutable_data(input->dims()); + fpga::fpga_flush(param->aligned_out.data(), outH*unalignedCW*sizeof(float)); + } fpga::BypassArgs args = {fpga::DATA_TYPE_FP16}; args.input_data_type = fpga::DATA_TYPE_FP16; @@ -82,19 +98,26 @@ void FetchKernel::Compute(const FetchParam &param) { } fpga::PerformBypass(args); - auto outC = output->dims()[1]; - auto outH = output->dims()[2]; - auto outW = output->dims()[3]; + int outC = 1; + int outH = 1; + int outW = 1; + if(output->dims().size() == 4){ + outC = output->dims()[1]; + outH = output->dims()[2]; + outW = output->dims()[3]; + }else{//2 + outC = output->dims()[1]; + } fpga::fpga_invalidate(param.fpga_bypass_args.output.address, output->fpga_data_num * sizeof(float)); - - if (output->fpga_data_num != product(input->dims())) { - float *data_tmp = - reinterpret_cast<float *>(malloc(outC * outH * outW * 
sizeof(float))); - dealign(outdata_ptr, data_tmp, outC, outH, outW); - memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float)); - free(data_tmp); + int unalignedCW = outC * outW; + int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT); + if(unalignedCW != alignedCW){ + auto aligned_ptr = const_cast<float *>(param.aligned_out.data()); + dealign(outdata_ptr, aligned_ptr, outC, outH, outW); + memcpy(outdata_ptr, aligned_ptr, outC * outH * outW * sizeof(float)); + fpga::fpga_flush(outdata_ptr, outC * outH * outW * sizeof(float)); + } } template class FetchKernel; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 645c288a35..8679174e4c 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1270,6 +1270,7 @@ class FetchParam : public OpParam { public: fpga::BypassArgs fpga_bypass_args; + Tensor aligned_out; #endif }; -- GitLab