提交 bad1792c 编写于 作者: qnqinan's avatar qnqinan

fix the dealign bug in fetch op FPGA track

上级 32d30e5d
......@@ -27,7 +27,23 @@ bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) {
output->init(typeid(float));
output->Resize(input->dims());
fpga::format_fp32_ofm(output);
int outC = 1;
int outH = 1;
int outW = 1;
if(output->dims().size() == 4){
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
}else{//2
outC = output->dims()[1];
}
int unalignedCW = outC * outW;
int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
if(alignedCW != unalignedCW){
param->aligned_out.Resize(input->dims());
param->aligned_out.mutable_data<float>(input->dims());
fpga::fpga_flush(param->aligned_out.data<float>(), outH*unalignedCW*sizeof(float));
}
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_data_type = fpga::DATA_TYPE_FP16;
......@@ -82,19 +98,26 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
}
fpga::PerformBypass(args);
auto outC = output->dims()[1];
auto outH = output->dims()[2];
auto outW = output->dims()[3];
int outC = 1;
int outH = 1;
int outW = 1;
if(output->dims().size() == 4){
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
}else{//2
outC = output->dims()[1];
}
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
output->fpga_data_num * sizeof(float));
if (output->fpga_data_num != product(input->dims())) {
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
free(data_tmp);
int unalignedCW = outC * outW;
int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
if(unalignedCW != alignedCW){
auto aligned_ptr = const_cast<float*>(param.aligned_out.data<float>());
dealign(outdata_ptr, aligned_ptr, outC, outH, outW);
memcpy(outdata_ptr, aligned_ptr, outC * outH * outW * sizeof(float));
fpga::fpga_flush(outdata_ptr, outC * outH * outW * sizeof(float));
}
}
template class FetchKernel<FPGA, float>;
......
......@@ -1270,6 +1270,7 @@ class FetchParam : public OpParam {
public:
fpga::BypassArgs fpga_bypass_args;
Tensor aligned_out;
#endif
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册