未验证 提交 a2f7e643 编写于 作者: Z zhangyang0701 提交者: GitHub

Merge branch 'develop' into develop

......@@ -27,7 +27,24 @@ bool FetchKernel<FPGA, float>::Init(FetchParam<FPGA> *param) {
output->init(typeid(float));
output->Resize(input->dims());
fpga::format_fp32_ofm(output);
int outC = 1;
int outH = 1;
int outW = 1;
if (output->dims().size() == 4) {
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
} else { // 2
outC = output->dims()[1];
}
int unalignedCW = outC * outW;
int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
if (alignedCW != unalignedCW) {
param->aligned_out.Resize(input->dims());
param->aligned_out.mutable_data<float>(input->dims());
fpga::fpga_flush(param->aligned_out.data<float>(),
outH * unalignedCW * sizeof(float));
}
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_data_type = fpga::DATA_TYPE_FP16;
......@@ -82,19 +99,26 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
}
fpga::PerformBypass(args);
auto outC = output->dims()[1];
auto outH = output->dims()[2];
auto outW = output->dims()[3];
int outC = 1;
int outH = 1;
int outW = 1;
if (output->dims().size() == 4) {
outC = output->dims()[1];
outH = output->dims()[2];
outW = output->dims()[3];
} else { // 2
outC = output->dims()[1];
}
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
output->fpga_data_num * sizeof(float));
if (output->fpga_data_num != product(input->dims())) {
float *data_tmp =
reinterpret_cast<float *>(malloc(outC * outH * outW * sizeof(float)));
dealign(outdata_ptr, data_tmp, outC, outH, outW);
memcpy(outdata_ptr, data_tmp, outC * outH * outW * sizeof(float));
free(data_tmp);
int unalignedCW = outC * outW;
int alignedCW = fpga::align_to_x(unalignedCW, IMAGE_ALIGNMENT);
if (unalignedCW != alignedCW) {
auto aligned_ptr = const_cast<float *>(param.aligned_out.data<float>());
dealign(outdata_ptr, aligned_ptr, outC, outH, outW);
memcpy(outdata_ptr, aligned_ptr, outC * outH * outW * sizeof(float));
fpga::fpga_flush(outdata_ptr, outC * outH * outW * sizeof(float));
}
}
template class FetchKernel<FPGA, float>;
......
......@@ -1270,6 +1270,7 @@ class FetchParam : public OpParam {
public:
fpga::BypassArgs fpga_bypass_args;
Tensor aligned_out;
#endif
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册