diff --git a/src/operators/kernel/fpga/V1/fetch_kernel.cpp b/src/operators/kernel/fpga/V1/fetch_kernel.cpp index 6fbd81ae7f527b6983e27d482498cb43f1ef93a4..fad1e77643a017659bea3c27d4475aea2c00787d 100644 --- a/src/operators/kernel/fpga/V1/fetch_kernel.cpp +++ b/src/operators/kernel/fpga/V1/fetch_kernel.cpp @@ -27,7 +27,24 @@ bool FetchKernel::Init(FetchParam *param) { output->init(typeid(float)); output->Resize(input->dims()); fpga::format_fp32_ofm(output); - + int outC = 1; + int outH = 1; + int outW = 1; + if (output->dims().size() == 4) { + outC = output->dims()[1]; + outH = output->dims()[2]; + outW = output->dims()[3]; + } else { // 2 + outC = output->dims()[1]; + } + int unalignedCW = outC * outW; + int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT); + if (alignedCW != unalignedCW) { + param->aligned_out.Resize(input->dims()); + param->aligned_out.mutable_data(input->dims()); + fpga::fpga_flush(param->aligned_out.data(), + outH * unalignedCW * sizeof(float)); + } fpga::BypassArgs args = {fpga::DATA_TYPE_FP16}; args.input_data_type = fpga::DATA_TYPE_FP16; @@ -82,19 +99,26 @@ void FetchKernel::Compute(const FetchParam ¶m) { } fpga::PerformBypass(args); - auto outC = output->dims()[1]; - auto outH = output->dims()[2]; - auto outW = output->dims()[3]; + int outC = 1; + int outH = 1; + int outW = 1; + if (output->dims().size() == 4) { + outC = output->dims()[1]; + outH = output->dims()[2]; + outW = output->dims()[3]; + } else { // 2 + outC = output->dims()[1]; + } fpga::fpga_invalidate(param.fpga_bypass_args.output.address, output->fpga_data_num * sizeof(float)); - - if (output->fpga_data_num != product(input->dims())) { - float *data_tmp = - reinterpret_cast(malloc(outC * outH * outW * sizeof(float))); - dealign(outdata_ptr, data_tmp, outC, outH, outW); - memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float)); - free(data_tmp); + int unalignedCW = outC * outW; + int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT); + if (unalignedCW != alignedCW) { + auto aligned_ptr = const_cast(param.aligned_out.data()); + dealign(outdata_ptr, aligned_ptr, outC, outH, outW); + memcpy(outdata_ptr, aligned_ptr, outC * outH * outW * sizeof(float)); + fpga::fpga_flush(outdata_ptr, outC * outH * outW * sizeof(float)); } } template class FetchKernel; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 645c288a35408d99537f68e7da7f7b3e9b546409..8679174e4cd5e3efe3abb0d8a7eff0a0c6290516 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1270,6 +1270,7 @@ class FetchParam : public OpParam { public: fpga::BypassArgs fpga_bypass_args; + Tensor aligned_out; #endif };