提交 bad1792c 编写于 作者: qnqinan's avatar qnqinan

fix the dealign bug in fetch op FPGA track

上级 32d30e5d
...@@ -27,7 +27,23 @@ bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) { ...@@ -27,7 +27,23 @@ bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) {
output->init(typeid(float)); output->init(typeid(float));
output->Resize(input->dims()); output->Resize(input->dims());
fpga::format_fp32_ofm(output); fpga::format_fp32_ofm(output);
int outC = 1;
int outH = 1;
int outW = 1;
if(output->dims().size() == 4){
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
}else{//2
outC = output->dims()[1];
}
int unalignedCW = outC * outW;
int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
if(alignedCW != unalignedCW){
param->aligned_out.Resize(input->dims());
param->aligned_out.mutable_data<float>(input->dims());
fpga::fpga_flush(param->aligned_out.data<float>(), outH*unalignedCW*sizeof(float));
}
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16}; fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_data_type = fpga::DATA_TYPE_FP16; args.input_data_type = fpga::DATA_TYPE_FP16;
...@@ -82,19 +98,26 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) { ...@@ -82,19 +98,26 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
} }
fpga::PerformBypass(args); fpga::PerformBypass(args);
auto outC = output->dims()[1]; int outC = 1;
auto outH = output->dims()[2]; int outH = 1;
auto outW = output->dims()[3]; int outW = 1;
if(output->dims().size() == 4){
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
}else{//2
outC = output->dims()[1];
}
fpga::fpga_invalidate(param.fpga_bypass_args.output.address, fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
output->fpga_data_num * sizeof(float)); output->fpga_data_num * sizeof(float));
int unalignedCW = outC * outW;
if (output->fpga_data_num != product(input->dims())) { int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
float *data_tmp = if(unalignedCW != alignedCW){
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float))); auto aligned_ptr = const_cast<float*>(param.aligned_out.data<float>());
dealign(outdata_ptr, data_tmp, outC, outH, outW); dealign(outdata_ptr, aligned_ptr, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float)); memcpy(outdata_ptr, aligned_ptr, outC * outH * outW * sizeof(float));
free(data_tmp); fpga::fpga_flush(outdata_ptr, outC * outH * outW * sizeof(float));
} }
} }
template class FetchKernel<FPGA, float>; template class FetchKernel<FPGA, float>;
......
...@@ -1270,6 +1270,7 @@ class FetchParam : public OpParam { ...@@ -1270,6 +1270,7 @@ class FetchParam : public OpParam {
public: public:
fpga::BypassArgs fpga_bypass_args; fpga::BypassArgs fpga_bypass_args;
Tensor aligned_out;
#endif #endif
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册