diff --git a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
index 647ecb5a6501371c74c8762cf81cee206f1dca68..e41061d4d1b25ebc379b1eeca0990d6073f4b6b0 100644
--- a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
+++ b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
@@ -92,6 +92,26 @@ void reshape(LoDTensor *input, LoDTensor *output) {
   fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
 }
 
+// Returns true when reshaping input_dims to output_dims is a layout no-op:
+// the shared leading dimensions are equal and every extra trailing
+// dimension of the longer shape is 1 (e.g. {N, C, 1, 1} -> {N, C}).
+static inline bool reshape2_judge(const framework::DDim &input_dims,
+                                  const framework::DDim &output_dims) {
+  int input_size = input_dims.size();
+  int output_size = output_dims.size();
+  const auto &long_dims = input_size > output_size ? input_dims : output_dims;
+  int short_size = input_size > output_size ? output_size : input_size;
+  for (int i = 0; i < long_dims.size(); ++i) {
+    if (i < short_size) {
+      if (input_dims[i] != output_dims[i]) {
+        return false;
+      }
+    } else if (long_dims[i] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
 template <>
 void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
   auto input = const_cast<LoDTensor *>(param.InputX());
@@ -109,7 +129,16 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
     }
   }
   output->Resize(framework::make_ddim(shape));
-  if (output->dims() == input->dims()) {
+
+  auto input_dims = input->dims();
+  auto output_dims = output->dims();
+
+  // Skip the FPGA re-layout when the shapes are identical or differ only
+  // by trailing 1s: the underlying data layout is unchanged either way.
+  bool no_reshape_needed = input_dims == output_dims ||
+                           reshape2_judge(input_dims, output_dims);
+
+  if (no_reshape_needed) {
     DLOG << "No need to reshape";
     output->ShareDataWith(*input);
     framework::LoD lod = input->lod();