diff --git a/mobile/src/operators/kernel/fpga/V1/.fetch_kernel.cpp.swp b/mobile/src/operators/kernel/fpga/V1/.fetch_kernel.cpp.swp
deleted file mode 100644
index a49df219da80b5418ef845304c9d3619a269a5f3..0000000000000000000000000000000000000000
Binary files a/mobile/src/operators/kernel/fpga/V1/.fetch_kernel.cpp.swp and /dev/null differ
diff --git a/mobile/src/operators/kernel/fpga/V1/fetch_kernel.cpp b/mobile/src/operators/kernel/fpga/V1/fetch_kernel.cpp
index 89cfe8f2635c4dc3ea4ebaf8a8466cbb84c723c5..4309146afd98e124f2dbe5ecb14a505db69fce1e 100644
--- a/mobile/src/operators/kernel/fpga/V1/fetch_kernel.cpp
+++ b/mobile/src/operators/kernel/fpga/V1/fetch_kernel.cpp
@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
   auto output = &param.Out()->at(col);
   if (input->type() == type_id<float>()) {
     output->ShareDataWith(*input);
-    // output dims equal to input dim
+    // output dims == input dim
     output->Resize(input->dims());
     return;
   }
diff --git a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
index e41061d4d1b25ebc379b1eeca0990d6073f4b6b0..8697739b0a20378df53636bdeedc5edf3c877ed8 100644
--- a/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
+++ b/mobile/src/operators/kernel/fpga/V1/reshape2_kernel.cpp
@@ -92,69 +92,67 @@ void reshape(LoDTensor *input, LoDTensor *output) {
   fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
 }
 
-static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){
-    int input_dims_size = input_dims.size();
-    int output_dims_size = output_dims.size();
-    bool dims_flag2 = true;
-    auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims;
-    int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size;
-    for(int i = 0; i < temp_dims.size(); ++i){
-        if(i < short_dims){
-            if(input_dims[i] != output_dims[i]){
-                dims_flag2 = false;
-                break;
-            }
-        }
-        else{
-            if(temp_dims[i] != 1){
-                dims_flag2 = false;
-                break;
-            }
-        }
-    }
-    return dims_flag2;
-
-
+static inline bool reshape2_judge(const framework::DDim input_dims,
+                                  const framework::DDim output_dims) {
+  int input_dims_size = input_dims.size();
+  int output_dims_size = output_dims.size();
+  bool dims_flag2 = true;
+  auto temp_dims = input_dims_size > output_dims_size ?
+      input_dims : output_dims;
+  int short_dims = input_dims_size > output_dims_size ?
+      output_dims_size : input_dims_size;
+  for (int i = 0; i < temp_dims.size(); ++i) {
+    if (i < short_dims) {
+      if (input_dims[i] != output_dims[i]) {
+        dims_flag2 = false;
+        break;
+      }
+    } else {
+      if (temp_dims[i] != 1) {
+        dims_flag2 = false;
+        break;
+      }
+    }
+  }
+  return dims_flag2;
 }
 
 template <>
 void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
-  auto input = const_cast<LoDTensor *>(param.InputX());
-  auto output = param.Out();
-  auto shape = param.Shape();
-
-  auto num_in = framework::product(input->dims());
-  auto num_shape = framework::product(framework::make_ddim(shape));
-  PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
-
-  for (int i = 0; i < shape.size(); i++) {
-    if (shape[i] == -1) {
-      shape[i] = static_cast<int>(-num_in / num_shape);
-      break;
-    }
+  auto input = const_cast<LoDTensor *>(param.InputX());
+  auto output = param.Out();
+  auto shape = param.Shape();
+
+  auto num_in = framework::product(input->dims());
+  auto num_shape = framework::product(framework::make_ddim(shape));
+  PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
+
+  for (int i = 0; i < shape.size(); i++) {
+    if (shape[i] == -1) {
+      shape[i] = static_cast<int>(-num_in / num_shape);
+      break;
+    }
+  }
   }
-  output->Resize(framework::make_ddim(shape));
+  output->Resize(framework::make_ddim(shape));
 
-  auto input_dims = input->dims();
-  auto output_dims = output->dims();
+  auto input_dims = input->dims();
+  auto output_dims = output->dims();
 
-  bool dims_flags = input_dims == output_dims;
-  bool dims_flag2 = true;
+  bool dims_flags = input_dims == output_dims;
+  bool dims_flag2 = true;
 
-  if(!dims_flags){
-    dims_flag2 = reshape2_judge(input_dims, output_dims);
-  }
-
-
-  if (dims_flags || dims_flag2 ) {
-    DLOG << "No need to reshape";
-    output->ShareDataWith(*input);
-    framework::LoD lod = input->lod();
-    output->set_lod(lod);
-    return;
-  }
+  if (!dims_flags) {
+    dims_flag2 = reshape2_judge(input_dims, output_dims);
+  }
+  if (dims_flags || dims_flag2) {
+    DLOG << "No need to reshape";
+    output->ShareDataWith(*input);
+    framework::LoD lod = input->lod();
+    output->set_lod(lod);
+    return;
+  }
 
-  reshape(input, output);
+  reshape(input, output);  //
 }
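
Reviewer note: the new reshape2_judge helper widens the "No need to reshape" fast path so it also covers output shapes that differ from the input only by trailing singleton dimensions: the two DDims must agree on their common prefix, and every extra dimension of the longer shape must equal 1. Below is a minimal standalone sketch of the same check, so it can be tried outside the paddle-mobile tree; shapes_equivalent and the use of std::vector<int64_t> in place of framework::DDim are illustrative stand-ins, not part of this patch.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone restatement of the reshape2_judge logic in this diff:
// true  -> shapes are equivalent, Reshape2 can alias the input buffer
// false -> a real FPGA reshape (layout conversion) is required
static bool shapes_equivalent(const std::vector<int64_t> &in,
                              const std::vector<int64_t> &out) {
  const auto &longer = in.size() > out.size() ? in : out;
  const size_t short_size = in.size() > out.size() ? out.size() : in.size();
  for (size_t i = 0; i < longer.size(); ++i) {
    if (i < short_size) {
      if (in[i] != out[i]) return false;  // common prefix must match
    } else if (longer[i] != 1) {
      return false;  // every extra dim of the longer shape must be 1
    }
  }
  return true;
}

int main() {
  // {1, 1000, 1, 1} -> {1, 1000}: only trailing 1s dropped, no data movement
  std::cout << shapes_equivalent({1, 1000, 1, 1}, {1, 1000}) << "\n";  // 1
  // {1, 2, 3} -> {1, 3, 2}: same element count, but the layout really changes
  std::cout << shapes_equivalent({1, 2, 3}, {1, 3, 2}) << "\n";  // 0
}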