From 3d7530f4e73d56bc9392de92dac356314c21b114 Mon Sep 17 00:00:00 2001 From: jameswu2014 <545426914@qq.com> Date: Wed, 21 Aug 2019 20:25:58 -0700 Subject: [PATCH] reshape2 ddim bug --- .../kernel/fpga/V1/reshape2_kernel.cpp | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp index 647ecb5a65..e41061d4d1 100644 --- a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp @@ -92,6 +92,31 @@ void reshape(LoDTensor *input, LoDTensor *output) { fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half)); } +static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){ + int input_dims_size = input_dims.size(); + int output_dims_size = output_dims.size(); + bool dims_flag2 = true; + auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims; + int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size; + for(int i = 0; i < temp_dims.size(); ++i){ + if(i < short_dims){ + if(input_dims[i] != output_dims[i]){ + dims_flag2 = false; + break; + } + } + else{ + if(temp_dims[i] != 1){ + dims_flag2 = false; + break; + } + } + } + return dims_flag2; + + +} + template <> void Reshape2Kernel::Compute(const Reshape2Param ¶m) { auto input = const_cast(param.InputX()); @@ -109,7 +134,19 @@ void Reshape2Kernel::Compute(const Reshape2Param ¶m) { } } output->Resize(framework::make_ddim(shape)); - if (output->dims() == input->dims()) { + + auto input_dims = input->dims(); + auto output_dims = output->dims(); + + bool dims_flags = input_dims == output_dims; + bool dims_flag2 = true; + + if(!dims_flags){ + dims_flag2 = reshape2_judge(input_dims, output_dims); + } + + + if (dims_flags || dims_flag2 ) { DLOG << "No need to reshape"; output->ShareDataWith(*input); framework::LoD lod = input->lod(); -- GitLab