From 0719bf082f931fdea384f03d8a6025e63ca074f8 Mon Sep 17 00:00:00 2001
From: zhangyang0701
Date: Mon, 11 Feb 2019 17:53:48 +0800
Subject: [PATCH] add reshape2 for FPGA track

---
 src/fpga/V1/api.cpp                            |  7 ++
 src/fpga/V1/api.h                              |  1 +
 .../kernel/fpga/V1/reshape2_kernel.cpp         | 74 +++++++++++++++++--
 3 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/src/fpga/V1/api.cpp b/src/fpga/V1/api.cpp
index b462cc5230..acb48aca70 100644
--- a/src/fpga/V1/api.cpp
+++ b/src/fpga/V1/api.cpp
@@ -37,6 +37,13 @@ void format_image(framework::Tensor *image_tensor) {
   }
 }
 
+void format_ofm(framework::Tensor *ofm_tensor) {
+  if (ofm_tensor->type() == typeid(float)) {
+    format_fp32_ofm(ofm_tensor);
+  } else {
+    format_fp16_ofm(ofm_tensor);
+  }
+}
 void format_fp16_ofm(framework::Tensor *ofm_tensor) {
   auto dims = ofm_tensor->dims();
   size_t memory_size = 0;
diff --git a/src/fpga/V1/api.h b/src/fpga/V1/api.h
index 05a30ddce4..33a5d3d33f 100644
--- a/src/fpga/V1/api.h
+++ b/src/fpga/V1/api.h
@@ -23,6 +23,7 @@ namespace paddle_mobile {
 namespace fpga {
 
 void format_image(framework::Tensor* image_tensor);
+void format_ofm(framework::Tensor* ofm_tensor);
 void format_fp16_ofm(framework::Tensor* ofm_tensor);  // only allocate memory
 void format_fp16_ofm(framework::Tensor* ofm_tensor, framework::DDim dims);
 void format_fp32_ofm(framework::Tensor* ofm_tensor);
diff --git a/src/operators/kernel/fpga/V1/reshape2_kernel.cpp b/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
index e92be9124f..9e5ce02658 100644
--- a/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
+++ b/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
@@ -25,7 +25,6 @@ bool Reshape2Kernel<FPGA, float>::Init(Reshape2Param<FPGA> *param) {
   auto input = const_cast<LoDTensor *>(param->InputX());
   auto output = param->Out();
   auto shape = param->Shape();
-  output->ShareDataWith(*input);
 
   auto num_in = framework::product(input->dims());
   auto num_shape = framework::product(framework::make_ddim(shape));
@@ -38,22 +37,79 @@ bool Reshape2Kernel<FPGA, float>::Init(Reshape2Param<FPGA> *param) {
     }
   }
   output->Resize(framework::make_ddim(shape));
+  output->set_type(input->type());
+  fpga::format_ofm(output);
   DLOG << "input: " << input;
   DLOG << "output: " << output;
 
   return true;
 }
 
+void reshape(LoDTensor *input, LoDTensor *output) {
+  // Subscript r means "after reshape"
+  // TODO(zhangyang): verify this function
+
+  float *input_ptr_f, *output_ptr_f;
+  half *input_ptr_h, *output_ptr_h;
+  bool is_float = false;
+
+  if (input->type() == typeid(float)) {
+    input_ptr_f = input->data<float>();
+    output_ptr_f = output->data<float>();
+    is_float = true;
+
+  } else {
+    input_ptr_h = input->data<half>();
+    output_ptr_h = output->data<half>();
+  }
+
+  auto C = static_cast<int>(input->dims()[1]);
+  auto H = static_cast<int>(input->dims()[2]);
+  auto W = static_cast<int>(input->dims()[3]);
+  auto Cr = static_cast<int>(output->dims()[1]);
+  auto Hr = static_cast<int>(output->dims()[2]);
+  auto Wr = static_cast<int>(output->dims()[3]);
+  PADDLE_MOBILE_ENFORCE(C * H * W == Cr * Hr * Wr, "Dims don't match");
+  auto WC = W * C;
+  auto WC_align = fpga::align_to_x(WC, IMAGE_ALIGNMENT);
+  auto HW = H * W;
+  auto WCr = Wr * Cr;
+  auto WCr_align = fpga::align_to_x(WCr, IMAGE_ALIGNMENT);
+  auto HWr = Hr * Wr;
+
+  int offset_align = 0;
+  int offset_r = 0, offset_align_r = 0;
+  int cr = 0, hr = 0, wr = 0;
+
+  for (int h = 0; h < H; h++) {
+    int offset0 = h * WC_align;
+    for (int w = 0; w < W; w++) {
+      int offset1 = w * C + offset0;
+      for (int c = 0; c < C; c++) {
+        offset_align = offset1 + c;
+        offset_r = c * HW + h * W + w;
+        cr = offset_r / HWr;
+        hr = offset_r % HWr / Wr;
+        wr = offset_r % Wr;
+        offset_align_r = hr * WCr_align + wr * Cr + cr;
+        // DLOG << "hwc " << h << " " << w << " " << c;
+        // DLOG << "hrwrcr " << hr << " " << wr << " " << cr;
+        if (is_float) {
+          output_ptr_f[offset_align_r] = input_ptr_f[offset_align];
+        } else {
+          output_ptr_h[offset_align_r] = input_ptr_h[offset_align];
+        }
+      }
+    }
+  }
+}
+
 template <>
 void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
   auto input = const_cast<LoDTensor *>(param.InputX());
   auto output = param.Out();
   auto shape = param.Shape();
 
-  if (output->type() != typeid(half)) {
-    DLOG << "wrong type";
-  }
-
   auto num_in = framework::product(input->dims());
   auto num_shape = framework::product(framework::make_ddim(shape));
   PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
@@ -65,10 +121,12 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
     }
   }
   output->Resize(framework::make_ddim(shape));
 
-  if (output->type() != typeid(half)) {
-    DLOG << "wrong type";
-    DLOG << output;
+  if (output->dims() == input->dims()) {
+    DLOG << "No need to reshape";
+    return;
   }
+
+  reshape(input, output);
   //
 }
--
GitLab
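
A quick way to sanity-check the index arithmetic in reshape() is to run the same mapping on a small host-side tensor. The standalone C++ sketch below covers the float path only; align_to_x is re-implemented locally as round-up-to-a-multiple, and the alignment value of 16 is an assumption standing in for IMAGE_ALIGNMENT, so treat it as an illustration rather than the kernel itself. It fills a 4x2x6 image with each element's logical CHW index and asserts that every element lands at the aligned position the 8x2x3 output layout expects.

// Standalone sketch of the index mapping used by reshape() above.
// Assumptions (not taken from the patch): alignment is 16 and
// align_to_x rounds up to a multiple of x.
#include <cassert>
#include <cstdio>
#include <vector>

static int align_to_x(int n, int x) { return (n + x - 1) / x * x; }

// Copy a C*H*W image stored as H rows of align(W*C) floats into the
// layout of a Cr*Hr*Wr image stored as Hr rows of align(Wr*Cr) floats.
void reshape_aligned(const std::vector<float> &in, std::vector<float> &out,
                     int C, int H, int W, int Cr, int Hr, int Wr,
                     int alignment = 16) {
  assert(C * H * W == Cr * Hr * Wr);
  int WC_align = align_to_x(W * C, alignment);
  int WCr_align = align_to_x(Wr * Cr, alignment);
  int HW = H * W, HWr = Hr * Wr;
  out.assign(static_cast<size_t>(Hr) * WCr_align, 0.f);
  for (int h = 0; h < H; h++) {
    for (int w = 0; w < W; w++) {
      for (int c = 0; c < C; c++) {
        int src = h * WC_align + w * C + c;  // aligned HWC source offset
        int flat = c * HW + h * W + w;       // logical CHW offset
        int cr = flat / HWr;                 // re-split as Cr*Hr*Wr
        int hr = flat % HWr / Wr;
        int wr = flat % Wr;
        out[hr * WCr_align + wr * Cr + cr] = in[src];
      }
    }
  }
}

int main() {
  // Reshape a 4x2x6 image into 8x2x3: same 48 elements, new dims.
  const int C = 4, H = 2, W = 6, Cr = 8, Hr = 2, Wr = 3;
  const int WC_align = align_to_x(W * C, 16);
  std::vector<float> in(H * WC_align, 0.f);
  for (int h = 0; h < H; h++)
    for (int w = 0; w < W; w++)
      for (int c = 0; c < C; c++)
        in[h * WC_align + w * C + c] =
            static_cast<float>(c * H * W + h * W + w);  // CHW index as value
  std::vector<float> out;
  reshape_aligned(in, out, C, H, W, Cr, Hr, Wr);
  // Element with logical CHW index i must now sit at the aligned
  // position derived from (i / HWr, i % HWr / Wr, i % Wr).
  for (int i = 0; i < C * H * W; i++) {
    int cr = i / (Hr * Wr), hr = i % (Hr * Wr) / Wr, wr = i % Wr;
    assert(out[hr * align_to_x(Wr * Cr, 16) + wr * Cr + cr] == i);
  }
  std::printf("reshape mapping verified\n");
  return 0;
}

Because each H row is padded out to the alignment boundary in both layouts, the flat buffers differ even when the element count matches, so the reshape cannot be a plain memcpy; that is also why the patch drops the old output->ShareDataWith(*input) and allocates a separately formatted output via fpga::format_ofm().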