From c754a38f774fc36bf49a97612c0cb9dc17c5a317 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Tue, 31 Mar 2020 09:43:27 +0800 Subject: [PATCH] [operator] add InferShapeImpl method (#3294) --- lite/core/op_lite.cc | 55 ++++ lite/core/op_lite.h | 20 +- lite/core/program.cc | 3 +- lite/operators/activation_grad_ops.cc | 2 +- lite/operators/activation_grad_ops.h | 2 +- lite/operators/activation_ops.cc | 2 +- lite/operators/activation_ops.h | 2 +- lite/operators/affine_channel_op.cc | 2 +- lite/operators/affine_channel_op.h | 2 +- lite/operators/anchor_generator_op.cc | 2 +- lite/operators/anchor_generator_op.h | 2 +- lite/operators/argmax_op.cc | 2 +- lite/operators/argmax_op.h | 2 +- lite/operators/assign_op.cc | 2 +- lite/operators/assign_op.h | 2 +- lite/operators/assign_value_op.cc | 2 +- lite/operators/assign_value_op.h | 2 +- lite/operators/attention_padding_mask_op.cc | 2 +- lite/operators/attention_padding_mask_op.h | 2 +- lite/operators/axpy_op.cc | 2 +- lite/operators/axpy_op.h | 2 +- lite/operators/batch_norm_op.cc | 2 +- lite/operators/batch_norm_op.h | 2 +- lite/operators/beam_search_decode_op.cc | 2 +- lite/operators/beam_search_decode_op.h | 2 +- lite/operators/beam_search_op.cc | 2 +- lite/operators/beam_search_op.h | 2 +- lite/operators/box_clip_op.cc | 2 +- lite/operators/box_clip_op.h | 2 +- lite/operators/box_coder_op.cc | 2 +- lite/operators/box_coder_op.h | 2 +- lite/operators/calib_op.cc | 2 +- lite/operators/calib_op.h | 2 +- lite/operators/cast_op.cc | 2 +- lite/operators/cast_op.h | 2 +- lite/operators/collect_fpn_proposals_op.cc | 2 +- lite/operators/collect_fpn_proposals_op.h | 2 +- lite/operators/compare_op.cc | 2 +- lite/operators/compare_op.h | 2 +- lite/operators/concat_op.cc | 2 +- lite/operators/concat_op.h | 2 +- lite/operators/conditional_block_op.cc | 2 +- lite/operators/conditional_block_op.h | 2 +- lite/operators/conv_op.cc | 30 +- lite/operators/conv_op.h | 4 +- lite/operators/conv_transpose_op.cc | 2 +- 
lite/operators/conv_transpose_op.h | 2 +- lite/operators/crf_decoding_op.cc | 2 +- lite/operators/crf_decoding_op.h | 2 +- lite/operators/crop_op.cc | 2 +- lite/operators/crop_op.h | 2 +- lite/operators/decode_bboxes_op.cc | 2 +- lite/operators/decode_bboxes_op.h | 2 +- lite/operators/density_prior_box_op.cc | 2 +- lite/operators/density_prior_box_op.h | 2 +- lite/operators/distribute_fpn_proposals_op.cc | 2 +- lite/operators/distribute_fpn_proposals_op.h | 2 +- lite/operators/dropout_op.cc | 2 +- lite/operators/dropout_op.h | 2 +- lite/operators/elementwise_grad_ops.cc | 2 +- lite/operators/elementwise_grad_ops.h | 2 +- lite/operators/elementwise_ops.cc | 35 +- lite/operators/elementwise_ops.h | 5 +- lite/operators/expand_op.cc | 2 +- lite/operators/expand_op.h | 2 +- .../fake_channel_wise_dequantize_max_abs.h | 2 +- lite/operators/fake_dequantize_max_abs.h | 2 +- ...e_quantize_dequantize_moving_avg_max_abs.h | 2 +- .../fake_quantize_moving_avg_max_abs.h | 2 +- lite/operators/fake_quantize_range_abs_max.h | 2 +- lite/operators/fc_op.cc | 29 +- lite/operators/fc_op.h | 3 +- lite/operators/feed_op.cc | 2 +- lite/operators/fetch_op.cc | 2 +- .../fill_constant_batch_size_like_op.cc | 2 +- .../fill_constant_batch_size_like_op.h | 2 +- lite/operators/fill_constant_op.cc | 2 +- lite/operators/fill_constant_op.h | 2 +- lite/operators/flatten_op.cc | 6 +- lite/operators/flatten_op.h | 4 +- .../fusion_elementwise_activation_ops.cc | 4 +- .../fusion_elementwise_activation_ops.h | 4 +- lite/operators/gather_op.cc | 2 +- lite/operators/gather_op.h | 2 +- lite/operators/generate_proposals_op.cc | 2 +- lite/operators/generate_proposals_op.h | 2 +- lite/operators/grid_sampler_op.cc | 2 +- lite/operators/grid_sampler_op.h | 2 +- lite/operators/gru_op.cc | 2 +- lite/operators/gru_op.h | 2 +- lite/operators/gru_unit_op.cc | 2 +- lite/operators/gru_unit_op.h | 2 +- lite/operators/im2sequence_op.cc | 2 +- lite/operators/im2sequence_op.h | 2 +- lite/operators/increment_op.cc | 2 +- 
lite/operators/increment_op.h | 2 +- lite/operators/instance_norm_op.cc | 2 +- lite/operators/instance_norm_op.h | 2 +- lite/operators/interpolate_op.cc | 2 +- lite/operators/interpolate_op.h | 2 +- lite/operators/io_copy_op.cc | 2 +- lite/operators/io_copy_op.h | 2 +- lite/operators/is_empty_op.cc | 2 +- lite/operators/is_empty_op.h | 2 +- lite/operators/layer_norm_op.cc | 2 +- lite/operators/layer_norm_op.h | 2 +- lite/operators/layout_op.cc | 2 +- lite/operators/layout_op.h | 2 +- lite/operators/lod_reset_op.cc | 2 +- lite/operators/lod_reset_op.h | 2 +- lite/operators/logical_op.cc | 4 +- lite/operators/logical_op.h | 4 +- lite/operators/lookup_table_dequant_op.cc | 2 +- lite/operators/lookup_table_dequant_op.h | 2 +- lite/operators/lookup_table_op.cc | 2 +- lite/operators/lookup_table_op.h | 2 +- lite/operators/lookup_table_v2_op.cc | 2 +- lite/operators/lookup_table_v2_op.h | 2 +- lite/operators/lrn_op.cc | 2 +- lite/operators/lrn_op.h | 2 +- lite/operators/lstm_op.cc | 2 +- lite/operators/lstm_op.h | 2 +- lite/operators/match_matrix_tensor_op.cc | 2 +- lite/operators/match_matrix_tensor_op.h | 2 +- lite/operators/matmul_op.cc | 2 +- lite/operators/matmul_op.h | 2 +- lite/operators/mean_grad_op.cc | 2 +- lite/operators/mean_grad_op.h | 2 +- lite/operators/mean_op.cc | 2 +- lite/operators/mean_op.h | 2 +- lite/operators/merge_lod_tensor_op.cc | 2 +- lite/operators/merge_lod_tensor_op.h | 2 +- lite/operators/mul_grad_op.cc | 2 +- lite/operators/mul_grad_op.h | 2 +- lite/operators/mul_op.cc | 2 +- lite/operators/mul_op.h | 2 +- lite/operators/multiclass_nms_op.cc | 2 +- lite/operators/multiclass_nms_op.h | 2 +- lite/operators/negative_op.cc | 2 +- lite/operators/negative_op.h | 2 +- lite/operators/norm_op.cc | 2 +- lite/operators/norm_op.h | 2 +- lite/operators/op_params.h | 304 +++++++++++------- lite/operators/pad2d_op.cc | 2 +- lite/operators/pad2d_op.h | 2 +- lite/operators/pool_op.cc | 2 +- lite/operators/pool_op.h | 2 +- lite/operators/power_op.cc | 2 +- 
lite/operators/power_op.h | 2 +- lite/operators/prior_box_op.cc | 2 +- lite/operators/prior_box_op.h | 2 +- lite/operators/range_op.cc | 2 +- lite/operators/range_op.h | 2 +- lite/operators/read_from_array_op.cc | 2 +- lite/operators/read_from_array_op.h | 2 +- lite/operators/reduce_max_op.cc | 2 +- lite/operators/reduce_max_op.h | 2 +- lite/operators/reduce_mean_op.cc | 2 +- lite/operators/reduce_mean_op.h | 2 +- lite/operators/reduce_ops.cc | 2 +- lite/operators/reduce_ops.h | 2 +- lite/operators/reduce_prod_op.cc | 2 +- lite/operators/reduce_prod_op.h | 2 +- lite/operators/relu_op.cc | 2 +- lite/operators/relu_op.h | 2 +- lite/operators/reshape_op.cc | 6 +- lite/operators/reshape_op.h | 4 +- lite/operators/roi_align_op.cc | 2 +- lite/operators/roi_align_op.h | 2 +- lite/operators/scale_op.cc | 2 +- lite/operators/scale_op.h | 2 +- lite/operators/search_aligned_mat_mul_op.cc | 2 +- lite/operators/search_aligned_mat_mul_op.h | 2 +- lite/operators/search_fc_op.cc | 2 +- lite/operators/search_fc_op.h | 2 +- lite/operators/search_grnn_op.cc | 2 +- lite/operators/search_grnn_op.h | 2 +- lite/operators/search_group_padding_op.cc | 2 +- lite/operators/search_group_padding_op.h | 2 +- lite/operators/search_seq_depadding_op.cc | 2 +- lite/operators/search_seq_depadding_op.h | 2 +- lite/operators/search_seq_fc_op.cc | 2 +- lite/operators/search_seq_fc_op.h | 2 +- lite/operators/search_seq_softmax_op.cc | 2 +- lite/operators/search_seq_softmax_op.h | 2 +- lite/operators/sequence_arithmetic_op.cc | 2 +- lite/operators/sequence_arithmetic_op.h | 2 +- lite/operators/sequence_concat_op.cc | 2 +- lite/operators/sequence_concat_op.h | 2 +- lite/operators/sequence_conv_op.cc | 2 +- lite/operators/sequence_conv_op.h | 2 +- lite/operators/sequence_expand_as_op.cc | 2 +- lite/operators/sequence_expand_as_op.h | 2 +- lite/operators/sequence_expand_op.cc | 2 +- lite/operators/sequence_expand_op.h | 2 +- lite/operators/sequence_pool_concat_op.cc | 2 +- 
lite/operators/sequence_pool_concat_op.h | 2 +- lite/operators/sequence_pool_op.cc | 2 +- lite/operators/sequence_pool_op.h | 2 +- lite/operators/sequence_reshape_op.cc | 2 +- lite/operators/sequence_reshape_op.h | 2 +- lite/operators/sequence_reverse_op.cc | 2 +- lite/operators/sequence_reverse_op.h | 2 +- lite/operators/sequence_softmax_op.cc | 2 +- lite/operators/sequence_softmax_op.h | 2 +- .../operators/sequence_topk_avg_pooling_op.cc | 2 +- lite/operators/sequence_topk_avg_pooling_op.h | 2 +- lite/operators/sgd_op.cc | 2 +- lite/operators/sgd_op.h | 2 +- lite/operators/shape_op.cc | 2 +- lite/operators/shape_op.h | 2 +- lite/operators/shuffle_channel_op.cc | 2 +- lite/operators/shuffle_channel_op.h | 2 +- lite/operators/slice_op.cc | 2 +- lite/operators/slice_op.h | 2 +- lite/operators/softmax_op.cc | 30 +- lite/operators/softmax_op.h | 3 +- lite/operators/split_lod_tensor_op.cc | 2 +- lite/operators/split_lod_tensor_op.h | 2 +- lite/operators/split_op.cc | 2 +- lite/operators/split_op.h | 2 +- lite/operators/squeeze_op.cc | 6 +- lite/operators/squeeze_op.h | 4 +- lite/operators/stack_op.cc | 2 +- lite/operators/stack_op.h | 2 +- lite/operators/subgraph_op.cc | 2 +- lite/operators/subgraph_op.h | 2 +- lite/operators/topk_op.cc | 2 +- lite/operators/topk_op.h | 2 +- lite/operators/transpose_op.cc | 4 +- lite/operators/transpose_op.h | 4 +- lite/operators/uniform_random_op.cc | 2 +- lite/operators/uniform_random_op.h | 2 +- lite/operators/unsqueeze_op.cc | 6 +- lite/operators/unsqueeze_op.h | 4 +- lite/operators/var_conv_2d_op.cc | 2 +- lite/operators/var_conv_2d_op.h | 2 +- lite/operators/while_op.cc | 2 +- lite/operators/while_op.h | 2 +- lite/operators/write_to_array_op.cc | 2 +- lite/operators/write_to_array_op.h | 2 +- lite/operators/yolo_box_op.cc | 2 +- lite/operators/yolo_box_op.h | 2 +- 243 files changed, 517 insertions(+), 502 deletions(-) diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index c76e369466..a9ccd1b9ae 100644 --- 
a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -22,6 +22,61 @@ namespace paddle { namespace lite { +bool OpLite::InferShape() { + // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ + // InferShapeWithCache will be applied. + if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) { + return this->InferShapeWithCache(); + } else { + // otherwise, InferShapeImpl is applied directly. + return this->InferShapeImpl(); + } +} +bool OpLite::InferShapeWithCache() { + // 1. Get vector of current input tensors + auto *current_inputs = param_.input_tensor_ptrs(); + // 2. Get hash value of current inputs shape and lod + size_t new_hash = 0; + for (auto iter = current_inputs->begin(); iter != current_inputs->end(); + iter++) { + // combine dims value into new_hash value. + auto &element_dims = (*iter)->dims(); + for (int i = 0; i < element_dims.size(); i++) { + new_hash = + lite::hash_combine(new_hash, static_cast(element_dims[i])); + } + // combine lod value into new_hash value. + auto &emement_lods = (*iter)->lod(); + for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end(); + lod_iter++) { + for (int i = 0; i < lod_iter->size(); i++) { + new_hash = + lite::hash_combine(new_hash, static_cast(lod_iter->at(i))); + } + } + } + // 3. infer shapes of output tensors + if (new_hash == io_shape_lod_hash_ && new_hash != 0) { + // if current hash value is consistent with io_shape_lod_hash_, + // previous outputs shape and lod are reused. + auto *current_outputs = param_.output_tensor_ptrs(); + for (int i = 0; i < current_outputs->size(); i++) { + current_outputs->at(i)->Resize(last_output_shapes[i]); + current_outputs->at(i)->set_lod(last_output_lods[i]); + } + } else { + // otherwise, current hash value is changed, InferShapeImpl will apply. 
+ io_shape_lod_hash_ = new_hash; + this->InferShapeImpl(); + auto *current_outputs = param_.output_tensor_ptrs(); + for (int i = 0; i < current_outputs->size(); i++) { + last_output_shapes[i] = current_outputs->at(i)->dims(); + last_output_lods[i] = current_outputs->at(i)->lod(); + } + } + return true; +} + std::vector> OpLite::CreateKernels( const std::vector &places, const std::string &kernel_type) { std::vector> kernels; diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 77d8091b4b..4c6c66be7e 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -24,6 +25,7 @@ #include "lite/core/kernel.h" #include "lite/core/scope.h" #include "lite/model_parser/cpp/op_desc.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { @@ -64,8 +66,8 @@ class OpLite : public Registry { // Check the shape. virtual bool CheckShape() const { return true; } // Inference the outputs' shape. - virtual bool InferShape() const { return true; } - virtual bool SmartInferShape() { return this->InferShape(); } + virtual bool InferShapeImpl() const { return true; } + virtual bool InferShape(); // Run this operator. virtual bool Run(); // Indicate whether the Op runs only once or not @@ -151,10 +153,16 @@ class OpLite : public Registry { std::vector valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; std::unique_ptr op_info_; - std::vector last_input_shapes; - std::vector last_output_shapes; - std::vector>> last_output_lods; - std::vector>> last_input_lods; + + std::vector last_output_shapes{}; + std::vector>> last_output_lods{}; + size_t io_shape_lod_hash_{}; + mutable operators::ParamBase param_; + + private: + // Infer Shape according to memory, if current input shapes are consistent + // with that of previous inputs, output shapes of last time will be reused. 
+ bool InferShapeWithCache(); }; /* diff --git a/lite/core/program.cc b/lite/core/program.cc index 580389fbad..7284c3983c 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -286,8 +286,7 @@ void Instruction::Run() { return; } - // op_->InferShape(); - op_->SmartInferShape(); + op_->InferShape(); kernel_->Launch(); has_run_ = true; } diff --git a/lite/operators/activation_grad_ops.cc b/lite/operators/activation_grad_ops.cc index 9a37a5f0a1..b31163e5dc 100644 --- a/lite/operators/activation_grad_ops.cc +++ b/lite/operators/activation_grad_ops.cc @@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const { return true; } -bool ActivationGradOp::InferShape() const { +bool ActivationGradOp::InferShapeImpl() const { param_.X_grad->Resize(param_.Out_grad->dims()); return true; } diff --git a/lite/operators/activation_grad_ops.h b/lite/operators/activation_grad_ops.h index 5421b3247f..cf928cfe1b 100644 --- a/lite/operators/activation_grad_ops.h +++ b/lite/operators/activation_grad_ops.h @@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index f7a326358b..abaaa1a705 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const { return true; } -bool ActivationOp::InferShape() const { +bool ActivationOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); auto out_lod = param_.Out->mutable_lod(); *out_lod = param_.X->lod(); diff --git a/lite/operators/activation_ops.h b/lite/operators/activation_ops.h index 34099ab0fd..8f81b12af0 100644 --- a/lite/operators/activation_ops.h +++ b/lite/operators/activation_ops.h @@ -26,7 +26,7 @@ class ActivationOp : public OpLite { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/affine_channel_op.cc b/lite/operators/affine_channel_op.cc index c4945ababd..447079deb3 100644 --- a/lite/operators/affine_channel_op.cc +++ b/lite/operators/affine_channel_op.cc @@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const { return true; } -bool AffineChannelOpLite::InferShape() const { +bool AffineChannelOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); param_.Out->Resize(x_dims); return true; diff --git a/lite/operators/affine_channel_op.h b/lite/operators/affine_channel_op.h index 85a043bdc8..5a3d9d6625 100644 --- a/lite/operators/affine_channel_op.h +++ b/lite/operators/affine_channel_op.h @@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/anchor_generator_op.cc b/lite/operators/anchor_generator_op.cc index 8daa54905f..e57a4b2df8 100644 --- a/lite/operators/anchor_generator_op.cc +++ b/lite/operators/anchor_generator_op.cc @@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const { return true; } -bool AnchorGeneratorOpLite::InferShape() const { +bool AnchorGeneratorOpLite::InferShapeImpl() const { auto input_dims = param_.Input->dims(); size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size(); std::vector output_shape( diff --git a/lite/operators/anchor_generator_op.h b/lite/operators/anchor_generator_op.h index 46e5e0fac2..2ff3422824 100644 --- a/lite/operators/anchor_generator_op.h +++ b/lite/operators/anchor_generator_op.h @@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool 
AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/argmax_op.cc b/lite/operators/argmax_op.cc index 772cc44607..b733998ae5 100644 --- a/lite/operators/argmax_op.cc +++ b/lite/operators/argmax_op.cc @@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const { return true; } -bool ArgmaxOpLite::InferShape() const { +bool ArgmaxOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); int x_rank = x_dims.size(); int axis = param_.Axis; diff --git a/lite/operators/argmax_op.h b/lite/operators/argmax_op.h index a5accc97e3..e6944507cf 100644 --- a/lite/operators/argmax_op.h +++ b/lite/operators/argmax_op.h @@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/assign_op.cc b/lite/operators/assign_op.cc index 8510b7e8b7..25e8539d2e 100644 --- a/lite/operators/assign_op.cc +++ b/lite/operators/assign_op.cc @@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const { return true; } -bool AssignOpLite::InferShape() const { +bool AssignOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/assign_op.h b/lite/operators/assign_op.h index 555356c365..9e7039bb5b 100644 --- a/lite/operators/assign_op.h +++ b/lite/operators/assign_op.h @@ -30,7 +30,7 @@ class AssignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/assign_value_op.cc b/lite/operators/assign_value_op.cc index 046c522228..ff5b55735f 100644 --- a/lite/operators/assign_value_op.cc +++ b/lite/operators/assign_value_op.cc @@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const { 
return true; } -bool AssignValueOpLite::InferShape() const { +bool AssignValueOpLite::InferShapeImpl() const { std::vector shape = param_.shape; std::vector out_shape; for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); diff --git a/lite/operators/assign_value_op.h b/lite/operators/assign_value_op.h index 7bf2206159..030da04818 100644 --- a/lite/operators/assign_value_op.h +++ b/lite/operators/assign_value_op.h @@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/attention_padding_mask_op.cc b/lite/operators/attention_padding_mask_op.cc index a88df0e7a9..2f3a0cd265 100644 --- a/lite/operators/attention_padding_mask_op.cc +++ b/lite/operators/attention_padding_mask_op.cc @@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const { return true; } -bool AttentionPaddingMaskOp::InferShape() const { +bool AttentionPaddingMaskOp::InferShapeImpl() const { auto src_len = param_.X->lod()[0][1]; CHECK_EQ(src_len, param_.X->dims()[1]) << "Mismatch source length, expect: " << src_len diff --git a/lite/operators/attention_padding_mask_op.h b/lite/operators/attention_padding_mask_op.h index 894d68f622..6a2443fc67 100644 --- a/lite/operators/attention_padding_mask_op.h +++ b/lite/operators/attention_padding_mask_op.h @@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/axpy_op.cc b/lite/operators/axpy_op.cc index 60f302862a..c1c6304c31 100644 --- a/lite/operators/axpy_op.cc +++ b/lite/operators/axpy_op.cc @@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const { return true; } -bool AxpyOpLite::InferShape() const { +bool 
AxpyOpLite::InferShapeImpl() const { auto dims = param_.Bias->dims(); // Set output dims diff --git a/lite/operators/axpy_op.h b/lite/operators/axpy_op.h index 1fa8540743..e9d9f44ca5 100644 --- a/lite/operators/axpy_op.h +++ b/lite/operators/axpy_op.h @@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index eca7fa6001..67e037fba3 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const { return true; } -bool BatchNormOp::InferShape() const { +bool BatchNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t channel_size = 0; switch (param_.data_layout) { diff --git a/lite/operators/batch_norm_op.h b/lite/operators/batch_norm_op.h index 21dbf9a28a..9598763713 100644 --- a/lite/operators/batch_norm_op.h +++ b/lite/operators/batch_norm_op.h @@ -30,7 +30,7 @@ class BatchNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_decode_op.cc b/lite/operators/beam_search_decode_op.cc index 52888d8a99..444c9d6a11 100644 --- a/lite/operators/beam_search_decode_op.cc +++ b/lite/operators/beam_search_decode_op.cc @@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const { return true; } -bool BeamSearchDecodeOpLite::InferShape() const { return true; } +bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; } bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/beam_search_decode_op.h b/lite/operators/beam_search_decode_op.h index 
9d324d2bf0..38bf9929ab 100644 --- a/lite/operators/beam_search_decode_op.h +++ b/lite/operators/beam_search_decode_op.h @@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_op.cc b/lite/operators/beam_search_op.cc index c998e002ee..ea777ad533 100644 --- a/lite/operators/beam_search_op.cc +++ b/lite/operators/beam_search_op.cc @@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const { return true; } -bool BeamSearchOp::InferShape() const { return true; } +bool BeamSearchOp::InferShapeImpl() const { return true; } bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front()); diff --git a/lite/operators/beam_search_op.h b/lite/operators/beam_search_op.h index 42a6058de1..7e325cb556 100644 --- a/lite/operators/beam_search_op.h +++ b/lite/operators/beam_search_op.h @@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_clip_op.cc b/lite/operators/box_clip_op.cc index 6bd93c6ea4..08ba49bd9a 100644 --- a/lite/operators/box_clip_op.cc +++ b/lite/operators/box_clip_op.cc @@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const { return true; } -bool BoxClipOpLite::InferShape() const { +bool BoxClipOpLite::InferShapeImpl() const { auto* input = param_.Input; auto* output = param_.Output; output->Resize(input->dims()); diff --git a/lite/operators/box_clip_op.h b/lite/operators/box_clip_op.h index c7e07b1015..0aae2112ec 100644 --- a/lite/operators/box_clip_op.h +++ b/lite/operators/box_clip_op.h @@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite { 
bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_coder_op.cc b/lite/operators/box_coder_op.cc index c86f494fc4..3133176b35 100644 --- a/lite/operators/box_coder_op.cc +++ b/lite/operators/box_coder_op.cc @@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const { return true; } -bool BoxCoderOpLite::InferShape() const { +bool BoxCoderOpLite::InferShapeImpl() const { auto prior_box_dims = param_.prior_box->dims(); auto target_box_dims = param_.target_box->dims(); std::string code_type = param_.code_type; diff --git a/lite/operators/box_coder_op.h b/lite/operators/box_coder_op.h index 61d54fd484..51e86423e3 100644 --- a/lite/operators/box_coder_op.h +++ b/lite/operators/box_coder_op.h @@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/calib_op.cc b/lite/operators/calib_op.cc index da00f01c32..8da8747f8c 100644 --- a/lite/operators/calib_op.cc +++ b/lite/operators/calib_op.cc @@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const { CHECK_OR_FALSE(param_.output); return true; } -bool CalibOpLite::InferShape() const { +bool CalibOpLite::InferShapeImpl() const { param_.output->Resize(param_.input->dims()); return true; } diff --git a/lite/operators/calib_op.h b/lite/operators/calib_op.h index d575766c10..94240880f5 100644 --- a/lite/operators/calib_op.h +++ b/lite/operators/calib_op.h @@ -42,7 +42,7 @@ class CalibOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope); diff --git a/lite/operators/cast_op.cc b/lite/operators/cast_op.cc index 
9ece0a45a3..da12e2afde 100644 --- a/lite/operators/cast_op.cc +++ b/lite/operators/cast_op.cc @@ -25,7 +25,7 @@ bool CastOp::CheckShape() const { return true; } -bool CastOp::InferShape() const { +bool CastOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto out_dims = param_.X->dims(); diff --git a/lite/operators/cast_op.h b/lite/operators/cast_op.h index 2f5f57f127..e045ef89f7 100644 --- a/lite/operators/cast_op.h +++ b/lite/operators/cast_op.h @@ -30,7 +30,7 @@ class CastOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/collect_fpn_proposals_op.cc b/lite/operators/collect_fpn_proposals_op.cc index 4731d4bf81..27dd9a50b6 100644 --- a/lite/operators/collect_fpn_proposals_op.cc +++ b/lite/operators/collect_fpn_proposals_op.cc @@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { return true; } -bool CollectFpnProposalsOpLite::InferShape() const { +bool CollectFpnProposalsOpLite::InferShapeImpl() const { param_.fpn_rois->Resize({param_.post_nms_topN, 4}); return true; diff --git a/lite/operators/collect_fpn_proposals_op.h b/lite/operators/collect_fpn_proposals_op.h index 1ae7bb269f..b3104e81d5 100644 --- a/lite/operators/collect_fpn_proposals_op.h +++ b/lite/operators/collect_fpn_proposals_op.h @@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/compare_op.cc b/lite/operators/compare_op.cc index aa500ba35c..f458eae71e 100644 --- a/lite/operators/compare_op.cc +++ b/lite/operators/compare_op.cc @@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const { return true; } -bool CompareOp::InferShape() const { 
+bool CompareOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/compare_op.h b/lite/operators/compare_op.h index 7ca21caaa1..c94cf88516 100644 --- a/lite/operators/compare_op.h +++ b/lite/operators/compare_op.h @@ -30,7 +30,7 @@ class CompareOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index b2f7438b64..c15bf29289 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const { return true; } -bool ConcatOpLite::InferShape() const { +bool ConcatOpLite::InferShapeImpl() const { const std::vector &inputs = param_.x; const size_t n = inputs.size(); CHECK_GT_OR_FALSE(n, 0); diff --git a/lite/operators/concat_op.h b/lite/operators/concat_op.h index acc41de9b3..2ac1572c83 100644 --- a/lite/operators/concat_op.h +++ b/lite/operators/concat_op.h @@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conditional_block_op.cc b/lite/operators/conditional_block_op.cc index c79c4e20a2..e3678e92c9 100644 --- a/lite/operators/conditional_block_op.cc +++ b/lite/operators/conditional_block_op.cc @@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const { return true; } -bool ConditionalBlockOpLite::InferShape() const { return true; } +bool ConditionalBlockOpLite::InferShapeImpl() const { return true; } bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/conditional_block_op.h 
b/lite/operators/conditional_block_op.h index 5518c255c5..1815731c8d 100644 --- a/lite/operators/conditional_block_op.h +++ b/lite/operators/conditional_block_op.h @@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 70ad3a32a8..38c59a0290 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector* paddings, } } -bool ConvOpLite::SmartInferShape() { - if (!last_input_shapes.empty()) { - if (last_input_shapes[0] == param_.x->dims() && - last_input_lods[0] == param_.x->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool ConvOpLite::InferShape() const { +bool ConvOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 3379fb4095..eab17fe6db 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -34,9 +34,7 @@ class ConvOpLite : public OpLite { explicit ConvOpLite(const std::string& type) : OpLite(type) {} bool CheckShape() const override; - - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; // TODO(Superjomn) 
replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index a84b975492..511a5157ad 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size, return output_size; } -bool ConvTransposeOpLite::InferShape() const { +bool ConvTransposeOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_transpose_op.h b/lite/operators/conv_transpose_op.h index fb25c022f9..891ece4f05 100644 --- a/lite/operators/conv_transpose_op.h +++ b/lite/operators/conv_transpose_op.h @@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/crf_decoding_op.cc b/lite/operators/crf_decoding_op.cc index 1b0a27ab4a..b1af573518 100644 --- a/lite/operators/crf_decoding_op.cc +++ b/lite/operators/crf_decoding_op.cc @@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const { return true; } -bool CrfDecodingOpLite::InferShape() const { +bool CrfDecodingOpLite::InferShapeImpl() const { auto emission_dims = param_.emission->dims(); if (param_.length == nullptr) { param_.viterbi_path->Resize({emission_dims[0], 1}); diff --git a/lite/operators/crf_decoding_op.h b/lite/operators/crf_decoding_op.h index 6aaf338ec2..4bc50410ab 100644 --- a/lite/operators/crf_decoding_op.h +++ b/lite/operators/crf_decoding_op.h @@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff 
--git a/lite/operators/crop_op.cc b/lite/operators/crop_op.cc index 1a27cfb34d..4905d92e58 100644 --- a/lite/operators/crop_op.cc +++ b/lite/operators/crop_op.cc @@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const { return true; } -bool CropOpLite::InferShape() const { +bool CropOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); lite::DDim output_shape(x_dims); diff --git a/lite/operators/crop_op.h b/lite/operators/crop_op.h index f21278e891..bd3d0e71d8 100644 --- a/lite/operators/crop_op.h +++ b/lite/operators/crop_op.h @@ -30,7 +30,7 @@ class CropOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/decode_bboxes_op.cc b/lite/operators/decode_bboxes_op.cc index e22adf1774..1903267c3a 100644 --- a/lite/operators/decode_bboxes_op.cc +++ b/lite/operators/decode_bboxes_op.cc @@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const { return true; } -bool DecodeBboxesOpLite::InferShape() const { +bool DecodeBboxesOpLite::InferShapeImpl() const { param_.bbox_data->Resize(param_.loc_data->dims()); return true; } diff --git a/lite/operators/decode_bboxes_op.h b/lite/operators/decode_bboxes_op.h index c463992c8d..8848a1c26c 100644 --- a/lite/operators/decode_bboxes_op.h +++ b/lite/operators/decode_bboxes_op.h @@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc index 86830df2f1..5ac3eef63b 100644 --- a/lite/operators/density_prior_box_op.cc +++ b/lite/operators/density_prior_box_op.cc @@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const { return true; } -bool 
DensityPriorBoxOpLite::InferShape() const { return true; } +bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; } bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { diff --git a/lite/operators/density_prior_box_op.h b/lite/operators/density_prior_box_op.h index bad55ad3b7..d84b20557f 100644 --- a/lite/operators/density_prior_box_op.h +++ b/lite/operators/density_prior_box_op.h @@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/distribute_fpn_proposals_op.cc b/lite/operators/distribute_fpn_proposals_op.cc index 5d6a0fca92..a23c5e1ffb 100644 --- a/lite/operators/distribute_fpn_proposals_op.cc +++ b/lite/operators/distribute_fpn_proposals_op.cc @@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const { return true; } -bool DistributeFpnProposalsOpLite::InferShape() const { +bool DistributeFpnProposalsOpLite::InferShapeImpl() const { int num_out_rois = param_.max_level - param_.min_level + 1; for (int i = 0; i < num_out_rois; i++) { param_.multi_fpn_rois[i]->Resize({-1, 4}); diff --git a/lite/operators/distribute_fpn_proposals_op.h b/lite/operators/distribute_fpn_proposals_op.h index 2390e32932..22ab2006e0 100644 --- a/lite/operators/distribute_fpn_proposals_op.h +++ b/lite/operators/distribute_fpn_proposals_op.h @@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/dropout_op.cc b/lite/operators/dropout_op.cc index 03047de3b3..858cc6d919 100644 --- a/lite/operators/dropout_op.cc +++ b/lite/operators/dropout_op.cc @@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() 
const { return true; } -bool DropoutOp::InferShape() const { +bool DropoutOp::InferShapeImpl() const { const auto x_dims = param_.x->dims(); param_.output->Resize(x_dims); if (param_.is_test == false) { diff --git a/lite/operators/dropout_op.h b/lite/operators/dropout_op.h index 97e17e350c..bdf0e1d904 100644 --- a/lite/operators/dropout_op.h +++ b/lite/operators/dropout_op.h @@ -28,7 +28,7 @@ class DropoutOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. diff --git a/lite/operators/elementwise_grad_ops.cc b/lite/operators/elementwise_grad_ops.cc index 9d964bf9e3..730785ba6e 100644 --- a/lite/operators/elementwise_grad_ops.cc +++ b/lite/operators/elementwise_grad_ops.cc @@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const { return true; } -bool ElementwiseGradOp::InferShape() const { +bool ElementwiseGradOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (param_.XGrad) { diff --git a/lite/operators/elementwise_grad_ops.h b/lite/operators/elementwise_grad_ops.h index c45d581936..ca8a324134 100644 --- a/lite/operators/elementwise_grad_ops.h +++ b/lite/operators/elementwise_grad_ops.h @@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index 044126b3c2..f4debc39a0 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool ElementwiseOp::SmartInferShape() { - if (!last_input_shapes.empty()) { - if 
(last_input_shapes[0] == param_.X->dims() && - last_input_shapes[1] == param_.Y->dims() && - last_input_lods[0] == param_.X->lod() && - last_input_lods[1] == param_.Y->lod()) { - param_.Out->Resize(last_output_shapes[0]); - param_.Out->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.X->dims()); - last_input_lods.push_back(param_.X->lod()); - last_input_shapes.push_back(param_.Y->dims()); - last_input_lods.push_back(param_.Y->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.Out->dims()); - last_output_lods.push_back(param_.Out->lod()); - return true; -} -bool ElementwiseOp::InferShape() const { +bool ElementwiseOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (x_dim == y_dim) { @@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { // return true; //} -// bool ElementwiseGradExplicitOp::InferShape() const { +// bool ElementwiseGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/elementwise_ops.h b/lite/operators/elementwise_ops.h index 9d6e5781b9..0f1b682fa5 100644 --- a/lite/operators/elementwise_ops.h +++ b/lite/operators/elementwise_ops.h @@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const 
override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/expand_op.cc b/lite/operators/expand_op.cc index 656e8babc0..8e40a3b236 100644 --- a/lite/operators/expand_op.cc +++ b/lite/operators/expand_op.cc @@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const { return true; } -bool ExpandOpLite::InferShape() const { +bool ExpandOpLite::InferShapeImpl() const { DDim out_dims(param_.X->dims()); for (size_t i = 0; i < param_.expand_times.size(); ++i) { out_dims[i] *= param_.expand_times[i]; diff --git a/lite/operators/expand_op.h b/lite/operators/expand_op.h index ce5dcda9e8..1312df8e83 100644 --- a/lite/operators/expand_op.h +++ b/lite/operators/expand_op.h @@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fake_channel_wise_dequantize_max_abs.h b/lite/operators/fake_channel_wise_dequantize_max_abs.h index 43afb7791f..e26d5dda52 100644 --- a/lite/operators/fake_channel_wise_dequantize_max_abs.h +++ b/lite/operators/fake_channel_wise_dequantize_max_abs.h @@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_dequantize_max_abs.h b/lite/operators/fake_dequantize_max_abs.h index bc266327eb..c4bb19c048 100644 --- a/lite/operators/fake_dequantize_max_abs.h +++ b/lite/operators/fake_dequantize_max_abs.h @@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool 
InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h index 8efa46c415..be7ec60e0e 100644 --- a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_moving_avg_max_abs.h index adc62a480d..5726231f31 100644 --- a/lite/operators/fake_quantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h index f68d1e20f6..14f823ece2 100644 --- a/lite/operators/fake_quantize_range_abs_max.h +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = 
op_desc.Input("X").front(); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 345fc0d605..d58a9e5b88 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const { return true; } -bool FcOpLite::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (last_input_shapes[0] == param_.input->dims() && - last_input_lods[0] == param_.input->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.input->dims()); - last_input_lods.push_back(param_.input->lod()); - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool FcOpLite::InferShape() const { +bool FcOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); const auto& w_dims = param_.w->dims(); int in_num_col_dims = param_.in_num_col_dims; diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index f5dc302e27..2e6a3ad59a 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h @@ -35,8 +35,7 @@ class FcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/feed_op.cc b/lite/operators/feed_op.cc index 8a0c75f62b..c429d1f574 100644 --- a/lite/operators/feed_op.cc +++ b/lite/operators/feed_op.cc @@ -29,7 +29,7 @@ class FeedOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return 
true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/fetch_op.cc b/lite/operators/fetch_op.cc index d50c0db340..9db5fb418d 100644 --- a/lite/operators/fetch_op.cc +++ b/lite/operators/fetch_op.cc @@ -29,7 +29,7 @@ class FetchOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: diff --git a/lite/operators/fill_constant_batch_size_like_op.cc b/lite/operators/fill_constant_batch_size_like_op.cc index 7df3a6aa9e..5b0ebb38e7 100644 --- a/lite/operators/fill_constant_batch_size_like_op.cc +++ b/lite/operators/fill_constant_batch_size_like_op.cc @@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const { return true; } -bool FillConstantBatchSizeLikeOp::InferShape() const { +bool FillConstantBatchSizeLikeOp::InferShapeImpl() const { std::vector output_dim{param_.shape.begin(), param_.shape.end()}; if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) { output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1; diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h index 33cc45779f..3c576ab282 100644 --- a/lite/operators/fill_constant_batch_size_like_op.h +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index 698b787f46..565c4bbd16 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const { return true; } -bool 
FillConstantOp::InferShape() const { +bool FillConstantOp::InferShapeImpl() const { std::vector out_shape; auto shape_tensor = param_.shape_tensor; auto shape_tensor_list = param_.shape_tensor_list; diff --git a/lite/operators/fill_constant_op.h b/lite/operators/fill_constant_op.h index aa2fea5a66..3c0500898b 100644 --- a/lite/operators/fill_constant_op.h +++ b/lite/operators/fill_constant_op.h @@ -31,7 +31,7 @@ class FillConstantOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc index 6deab45023..b270dbf52f 100644 --- a/lite/operators/flatten_op.cc +++ b/lite/operators/flatten_op.cc @@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const { return true; } -bool FlattenOp::InferShape() const { +bool FlattenOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto out_lod = param_.output->mutable_lod(); @@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const { return true; } -bool Flatten2Op::InferShape() const { - FlattenOp::InferShape(); +bool Flatten2Op::InferShapeImpl() const { + FlattenOp::InferShapeImpl(); auto x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1, 0); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h index 61680fd390..78b803d765 100644 --- a/lite/operators/flatten_op.h +++ b/lite/operators/flatten_op.h @@ -30,7 +30,7 @@ class FlattenOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, 
lite::Scope *scope) override; diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc index 244394b95a..dfe3bda6c6 100644 --- a/lite/operators/fusion_elementwise_activation_ops.cc +++ b/lite/operators/fusion_elementwise_activation_ops.cc @@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { return true; } -bool FusionElementwiseActivationOp::InferShape() const { +bool FusionElementwiseActivationOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); param_.Out->Resize(param_.X->dims()); return true; @@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc, // return true; // } -// bool FusionElementwiseActivationGradExplicitOp::InferShape() const { +// bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/fusion_elementwise_activation_ops.h b/lite/operators/fusion_elementwise_activation_ops.h index db521284f0..738c216822 100644 --- a/lite/operators/fusion_elementwise_activation_ops.h +++ b/lite/operators/fusion_elementwise_activation_ops.h @@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/gather_op.cc b/lite/operators/gather_op.cc index 858dad8e4c..670cd61c8e 100644 --- a/lite/operators/gather_op.cc +++ b/lite/operators/gather_op.cc @@ -26,7 +26,7 @@ bool 
GatherOp::CheckShape() const { return true; } -bool GatherOp::InferShape() const { +bool GatherOp::InferShapeImpl() const { auto index_dims = param_.Index->dims(); CHECK(index_dims.size() == 1 || (index_dims.size() == 2 && index_dims[1] == 1)) diff --git a/lite/operators/gather_op.h b/lite/operators/gather_op.h index 58d5a30ffb..d2072c3a6d 100644 --- a/lite/operators/gather_op.h +++ b/lite/operators/gather_op.h @@ -30,7 +30,7 @@ class GatherOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/generate_proposals_op.cc b/lite/operators/generate_proposals_op.cc index a29ef65e97..48e709c348 100644 --- a/lite/operators/generate_proposals_op.cc +++ b/lite/operators/generate_proposals_op.cc @@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const { return true; } -bool GenerateProposalsOpLite::InferShape() const { +bool GenerateProposalsOpLite::InferShapeImpl() const { param_.RpnRois->Resize(std::vector({-1, 4})); param_.RpnRoiProbs->Resize(std::vector({-1, 1})); return true; diff --git a/lite/operators/generate_proposals_op.h b/lite/operators/generate_proposals_op.h index 502bcca1a3..35dee1966b 100644 --- a/lite/operators/generate_proposals_op.h +++ b/lite/operators/generate_proposals_op.h @@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/grid_sampler_op.cc b/lite/operators/grid_sampler_op.cc index 2b13d17da7..97e2b36a6b 100644 --- a/lite/operators/grid_sampler_op.cc +++ b/lite/operators/grid_sampler_op.cc @@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const { return true; } -bool GridSamplerOp::InferShape() const { +bool GridSamplerOp::InferShapeImpl() const 
{ auto x_dims = param_.x->dims(); param_.out->Resize(x_dims); return true; diff --git a/lite/operators/grid_sampler_op.h b/lite/operators/grid_sampler_op.h index 035e1b8345..2fba4fe693 100644 --- a/lite/operators/grid_sampler_op.h +++ b/lite/operators/grid_sampler_op.h @@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_op.cc b/lite/operators/gru_op.cc index eb97d65a1a..862a1ff98f 100644 --- a/lite/operators/gru_op.cc +++ b/lite/operators/gru_op.cc @@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const { return true; } -bool GRUOpLite::InferShape() const { +bool GRUOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); const auto& weight_dims = param_.weight->dims(); int frame_size = weight_dims[0]; diff --git a/lite/operators/gru_op.h b/lite/operators/gru_op.h index c43f32f0cd..34f87fa793 100644 --- a/lite/operators/gru_op.h +++ b/lite/operators/gru_op.h @@ -30,7 +30,7 @@ class GRUOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_unit_op.cc b/lite/operators/gru_unit_op.cc index ed33507fc3..ad025fbbc1 100644 --- a/lite/operators/gru_unit_op.cc +++ b/lite/operators/gru_unit_op.cc @@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const { return true; } -bool GRUUnitOpLite::InferShape() const { +bool GRUUnitOpLite::InferShapeImpl() const { auto input_dims = param_.input->dims(); auto hidden_prev_dims = param_.hidden_prev->dims(); auto weight_dims = param_.weight->dims(); diff --git a/lite/operators/gru_unit_op.h b/lite/operators/gru_unit_op.h index 301a7e7323..2785e60e95 100644 --- a/lite/operators/gru_unit_op.h +++ 
b/lite/operators/gru_unit_op.h @@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/im2sequence_op.cc b/lite/operators/im2sequence_op.cc index 40ab2106af..ae7b102946 100644 --- a/lite/operators/im2sequence_op.cc +++ b/lite/operators/im2sequence_op.cc @@ -26,7 +26,7 @@ inline int Im2SeqOutputSize( } bool Im2SequenceOp::CheckShape() const { return true; } -bool Im2SequenceOp::InferShape() const { +bool Im2SequenceOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/im2sequence_op.h b/lite/operators/im2sequence_op.h index 83a347c913..62525baaf0 100644 --- a/lite/operators/im2sequence_op.h +++ b/lite/operators/im2sequence_op.h @@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/increment_op.cc b/lite/operators/increment_op.cc index c1928ccbd4..9b34e4f73b 100644 --- a/lite/operators/increment_op.cc +++ b/lite/operators/increment_op.cc @@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const { return true; } -bool IncrementOp::InferShape() const { +bool IncrementOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/increment_op.h b/lite/operators/increment_op.h index f180d527c3..d4e6fd6b1f 100644 --- a/lite/operators/increment_op.h +++ b/lite/operators/increment_op.h @@ -30,7 +30,7 @@ class IncrementOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/instance_norm_op.cc b/lite/operators/instance_norm_op.cc index 510402ba1f..5f685ccfc5 100644 --- a/lite/operators/instance_norm_op.cc +++ b/lite/operators/instance_norm_op.cc @@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const { return true; } -bool InstanceNormOp::InferShape() const { +bool InstanceNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t batch_size = x_dims[0]; int64_t channel_size = x_dims[1]; diff --git a/lite/operators/instance_norm_op.h b/lite/operators/instance_norm_op.h index d128345805..94a1f69fa4 100644 --- a/lite/operators/instance_norm_op.h +++ b/lite/operators/instance_norm_op.h @@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc index 1bfb20df4e..0ef22e4290 100644 --- a/lite/operators/interpolate_op.cc +++ b/lite/operators/interpolate_op.cc @@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const { return true; } -bool InterpolateOp::InferShape() const { +bool InterpolateOp::InferShapeImpl() const { auto X = param_.X; int n = X->dims()[0]; diff --git a/lite/operators/interpolate_op.h b/lite/operators/interpolate_op.h index 5fcf4ef594..2bc9389648 100644 --- a/lite/operators/interpolate_op.h +++ b/lite/operators/interpolate_op.h @@ -31,7 +31,7 @@ class InterpolateOp : public 
OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/io_copy_op.cc b/lite/operators/io_copy_op.cc index 7df636d7b2..05b2d3d800 100644 --- a/lite/operators/io_copy_op.cc +++ b/lite/operators/io_copy_op.cc @@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool IoCopyOp::InferShape() const { +bool IoCopyOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/io_copy_op.h b/lite/operators/io_copy_op.h index 8d6d69d63e..d6922b667d 100644 --- a/lite/operators/io_copy_op.h +++ b/lite/operators/io_copy_op.h @@ -24,7 +24,7 @@ class IoCopyOp : public OpLite { public: explicit IoCopyOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/is_empty_op.cc b/lite/operators/is_empty_op.cc index ed4c69e64e..a62470e4bb 100644 --- a/lite/operators/is_empty_op.cc +++ b/lite/operators/is_empty_op.cc @@ -21,7 +21,7 @@ namespace operators { bool IsEmptyOp::CheckShape() const { return true; } -bool IsEmptyOp::InferShape() const { return true; } +bool IsEmptyOp::InferShapeImpl() const { return true; } bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = diff --git a/lite/operators/is_empty_op.h b/lite/operators/is_empty_op.h index 5bfa0905c7..14c0830c23 100644 --- a/lite/operators/is_empty_op.h +++ b/lite/operators/is_empty_op.h @@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layer_norm_op.cc 
b/lite/operators/layer_norm_op.cc index 18ea6cbf28..2f50d232e3 100644 --- a/lite/operators/layer_norm_op.cc +++ b/lite/operators/layer_norm_op.cc @@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const { return true; } -bool LayerNormOp::InferShape() const { +bool LayerNormOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); param_.Y->Resize(out_dims); auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[0]; diff --git a/lite/operators/layer_norm_op.h b/lite/operators/layer_norm_op.h index 297f6bdd40..6e15d2f599 100644 --- a/lite/operators/layer_norm_op.h +++ b/lite/operators/layer_norm_op.h @@ -30,7 +30,7 @@ class LayerNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layout_op.cc b/lite/operators/layout_op.cc index 0127256804..d71dab6870 100644 --- a/lite/operators/layout_op.cc +++ b/lite/operators/layout_op.cc @@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool LayoutOp::InferShape() const { +bool LayoutOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/layout_op.h b/lite/operators/layout_op.h index 216d571d7c..f51768863b 100644 --- a/lite/operators/layout_op.h +++ b/lite/operators/layout_op.h @@ -24,7 +24,7 @@ class LayoutOp : public OpLite { public: explicit LayoutOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/lod_reset_op.cc b/lite/operators/lod_reset_op.cc index 1754e709ff..c30c78bbc6 100644 --- a/lite/operators/lod_reset_op.cc +++ b/lite/operators/lod_reset_op.cc @@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const { return true; } -bool 
LodResetOp::InferShape() const { +bool LodResetOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. param_.Out->Resize(param_.X->dims()); diff --git a/lite/operators/lod_reset_op.h b/lite/operators/lod_reset_op.h index 4e048a9a69..8ca2bc5780 100644 --- a/lite/operators/lod_reset_op.h +++ b/lite/operators/lod_reset_op.h @@ -30,7 +30,7 @@ class LodResetOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/logical_op.cc b/lite/operators/logical_op.cc index 8af982ad53..2dd5b79828 100644 --- a/lite/operators/logical_op.cc +++ b/lite/operators/logical_op.cc @@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const { return true; } -bool BinaryLogicalOp::InferShape() const { +bool BinaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); @@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const { return true; } -bool UnaryLogicalOp::InferShape() const { +bool UnaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto input_dims = param_.X->dims(); diff --git a/lite/operators/logical_op.h b/lite/operators/logical_op.h index a0fc1d68a6..e784d4d99b 100644 --- a/lite/operators/logical_op.h +++ b/lite/operators/logical_op.h @@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_dequant_op.cc b/lite/operators/lookup_table_dequant_op.cc index b81043bfbf..844544dfad 100644 --- a/lite/operators/lookup_table_dequant_op.cc +++ b/lite/operators/lookup_table_dequant_op.cc @@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const { return true; } -bool LookupTableDequantOpLite::InferShape() const { +bool LookupTableDequantOpLite::InferShapeImpl() const { const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_dequant_op.h b/lite/operators/lookup_table_dequant_op.h index 3a9683d5ca..a094cac9a4 100644 --- a/lite/operators/lookup_table_dequant_op.h +++ b/lite/operators/lookup_table_dequant_op.h @@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc index df066435a8..9bc22080bf 100644 --- a/lite/operators/lookup_table_op.cc +++ b/lite/operators/lookup_table_op.cc @@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const { return true; } -bool LookupTableOpLite::InferShape() 
const { +bool LookupTableOpLite::InferShapeImpl() const { const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_op.h b/lite/operators/lookup_table_op.h index 2701af9840..91ef77cfa1 100644 --- a/lite/operators/lookup_table_op.h +++ b/lite/operators/lookup_table_op.h @@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_v2_op.cc b/lite/operators/lookup_table_v2_op.cc index df642e6191..8c76090df3 100644 --- a/lite/operators/lookup_table_v2_op.cc +++ b/lite/operators/lookup_table_v2_op.cc @@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const { return true; } -bool LookupTableV2OpLite::InferShape() const { +bool LookupTableV2OpLite::InferShapeImpl() const { auto table_dims = param_.W->dims(); auto ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_v2_op.h b/lite/operators/lookup_table_v2_op.h index dabff3f0ca..b0b8829fe6 100644 --- a/lite/operators/lookup_table_v2_op.h +++ b/lite/operators/lookup_table_v2_op.h @@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lrn_op.cc b/lite/operators/lrn_op.cc index aff3e5af55..dcaffe1aa7 100644 --- a/lite/operators/lrn_op.cc +++ b/lite/operators/lrn_op.cc @@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const { return true; } -bool LrnOpLite::InferShape() const { +bool LrnOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/lrn_op.h b/lite/operators/lrn_op.h index a569a77fb4..13dfdefdc6 100644 --- a/lite/operators/lrn_op.h +++ 
b/lite/operators/lrn_op.h @@ -28,7 +28,7 @@ class LrnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lstm_op.cc b/lite/operators/lstm_op.cc index 36a0d2f53c..d9b6ebfc32 100644 --- a/lite/operators/lstm_op.cc +++ b/lite/operators/lstm_op.cc @@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const { return true; } -bool LstmOp::InferShape() const { +bool LstmOp::InferShapeImpl() const { auto in_dims = param_.Input->dims(); if (param_.H0) { CHECK(param_.C0) << "lstm must has H0 and C0 in the same time"; diff --git a/lite/operators/lstm_op.h b/lite/operators/lstm_op.h index 221bd5c379..38bef385da 100644 --- a/lite/operators/lstm_op.h +++ b/lite/operators/lstm_op.h @@ -30,7 +30,7 @@ class LstmOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/match_matrix_tensor_op.cc b/lite/operators/match_matrix_tensor_op.cc index a8095a94bf..1cc751109f 100644 --- a/lite/operators/match_matrix_tensor_op.cc +++ b/lite/operators/match_matrix_tensor_op.cc @@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const { return true; } -bool MatchMatrixTensorOpLite::InferShape() const { +bool MatchMatrixTensorOpLite::InferShapeImpl() const { const Tensor* x = param_.x; const Tensor* y = param_.y; DDim x_dims = param_.x->dims(); diff --git a/lite/operators/match_matrix_tensor_op.h b/lite/operators/match_matrix_tensor_op.h index 404183ea5b..f1070a81b4 100644 --- a/lite/operators/match_matrix_tensor_op.h +++ b/lite/operators/match_matrix_tensor_op.h @@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool 
AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc index 286ade7b21..1cdcdfa167 100644 --- a/lite/operators/matmul_op.cc +++ b/lite/operators/matmul_op.cc @@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const { return true; } -bool MatMulOpLite::InferShape() const { +bool MatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); bool x_transpose = param_.transpose_X; diff --git a/lite/operators/matmul_op.h b/lite/operators/matmul_op.h index 0aa47c89dd..acb9d512f7 100644 --- a/lite/operators/matmul_op.h +++ b/lite/operators/matmul_op.h @@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/mean_grad_op.cc b/lite/operators/mean_grad_op.cc index fd17cac14f..55e374735e 100644 --- a/lite/operators/mean_grad_op.cc +++ b/lite/operators/mean_grad_op.cc @@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const { return true; } -bool MeanGradOp::InferShape() const { +bool MeanGradOp::InferShapeImpl() const { param_.X_grad->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/mean_grad_op.h b/lite/operators/mean_grad_op.h index 1bd604518b..488581a71b 100644 --- a/lite/operators/mean_grad_op.h +++ b/lite/operators/mean_grad_op.h @@ -27,7 +27,7 @@ class MeanGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/mean_op.cc b/lite/operators/mean_op.cc index 618e9001db..9a66d4fbda 100644 --- a/lite/operators/mean_op.cc +++ b/lite/operators/mean_op.cc @@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const { return true; } -bool 
MeanOp::InferShape() const { +bool MeanOp::InferShapeImpl() const { param_.Out->Resize(std::vector{1}); return true; } diff --git a/lite/operators/mean_op.h b/lite/operators/mean_op.h index 8526842f93..c4dff93ce7 100644 --- a/lite/operators/mean_op.h +++ b/lite/operators/mean_op.h @@ -27,7 +27,7 @@ class MeanOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/merge_lod_tensor_op.cc b/lite/operators/merge_lod_tensor_op.cc index 4258715b1d..704b5cad6f 100644 --- a/lite/operators/merge_lod_tensor_op.cc +++ b/lite/operators/merge_lod_tensor_op.cc @@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const { return true; } -bool MergeLodTensorOpLite::InferShape() const { +bool MergeLodTensorOpLite::InferShapeImpl() const { auto dims = param_.in_true->dims(); param_.out->Resize(dims); return true; diff --git a/lite/operators/merge_lod_tensor_op.h b/lite/operators/merge_lod_tensor_op.h index 788a345168..ec986fac19 100644 --- a/lite/operators/merge_lod_tensor_op.h +++ b/lite/operators/merge_lod_tensor_op.h @@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/mul_grad_op.cc b/lite/operators/mul_grad_op.cc index 8215521637..51e1fb310c 100644 --- a/lite/operators/mul_grad_op.cc +++ b/lite/operators/mul_grad_op.cc @@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const { return true; } -bool MulGradOpLite::InferShape() const { +bool MulGradOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); if (param_.x_grad) { diff --git a/lite/operators/mul_grad_op.h b/lite/operators/mul_grad_op.h index ef61f54f9b..869aa60c62 100644 
--- a/lite/operators/mul_grad_op.h +++ b/lite/operators/mul_grad_op.h @@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/mul_op.cc b/lite/operators/mul_op.cc index c870abdc89..8641a041e3 100644 --- a/lite/operators/mul_op.cc +++ b/lite/operators/mul_op.cc @@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const { return true; } -bool MulOpLite::InferShape() const { +bool MulOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); diff --git a/lite/operators/mul_op.h b/lite/operators/mul_op.h index caf7bf6ae9..10a2e2efaa 100644 --- a/lite/operators/mul_op.h +++ b/lite/operators/mul_op.h @@ -33,7 +33,7 @@ class MulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. 
diff --git a/lite/operators/multiclass_nms_op.cc b/lite/operators/multiclass_nms_op.cc index 9ec79f8b57..3102030e4b 100644 --- a/lite/operators/multiclass_nms_op.cc +++ b/lite/operators/multiclass_nms_op.cc @@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const { return true; } -bool MulticlassNmsOpLite::InferShape() const { +bool MulticlassNmsOpLite::InferShapeImpl() const { auto box_dims = param_.bboxes->dims(); auto score_dims = param_.scores->dims(); auto score_size = score_dims.size(); diff --git a/lite/operators/multiclass_nms_op.h b/lite/operators/multiclass_nms_op.h index 7be0d17d74..f74479f3c9 100644 --- a/lite/operators/multiclass_nms_op.h +++ b/lite/operators/multiclass_nms_op.h @@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/negative_op.cc b/lite/operators/negative_op.cc index 4db1dd4fee..2b98f0a90a 100644 --- a/lite/operators/negative_op.cc +++ b/lite/operators/negative_op.cc @@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const { return true; } -bool NegativeOpLite::InferShape() const { +bool NegativeOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/negative_op.h b/lite/operators/negative_op.h index 83f1008c96..04ec925325 100644 --- a/lite/operators/negative_op.h +++ b/lite/operators/negative_op.h @@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/norm_op.cc b/lite/operators/norm_op.cc index dff26966d4..0513e5c942 100644 --- a/lite/operators/norm_op.cc +++ b/lite/operators/norm_op.cc @@ -25,7 +25,7 
@@ bool NormOp::CheckShape() const { return true; } -bool NormOp::InferShape() const { +bool NormOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto out_dims = param_.X->dims(); diff --git a/lite/operators/norm_op.h b/lite/operators/norm_op.h index ae4594ed02..5c69d959be 100644 --- a/lite/operators/norm_op.h +++ b/lite/operators/norm_op.h @@ -30,7 +30,7 @@ class NormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 36d3b42c6b..1e221a602a 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -24,6 +24,7 @@ #include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/desc_apis.h" #include "lite/utils/all.h" +#include "lite/utils/variant.h" /* * This file contains all the argument parameter data structure for operators. 
*/ @@ -32,6 +33,16 @@ namespace paddle { namespace lite { namespace operators { +struct ParamBase { + public: + const std::vector* input_tensor_ptrs() const { return nullptr; } + std::vector* output_tensor_ptrs() { return nullptr; } + + protected: + std::shared_ptr> input_tensor_ptrs_cache_{nullptr}; + std::shared_ptr> output_tensor_ptrs_cache_{nullptr}; +}; + using param_t = Any; #define WITH_INT8_CONFIG \ bool enable_int8{false}; \ @@ -41,38 +52,38 @@ using param_t = Any; int bit_length{8}; /// ----------------------- Functional operators ------------------------------ -struct FeedParam { +struct FeedParam : ParamBase { std::vector* feed_list{}; lite::Tensor* out{}; int col; }; -struct FetchParam { +struct FetchParam : ParamBase { const lite::Tensor* input{}; std::vector* fetch_list{}; int col; }; // Helper op for lite framework -struct IoCopyParam { +struct IoCopyParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct LayoutParam { +struct LayoutParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct CalibParam { +struct CalibParam : ParamBase { const lite::Tensor* input{}; lite::Tensor* output{}; float scale; }; -struct SubgraphParam { +struct SubgraphParam : ParamBase { std::vector input_names{}; std::vector output_names{}; std::vector input_data_names{}; @@ -84,7 +95,7 @@ struct SubgraphParam { /// -------------------------- NN operators ------------------------------------ -struct FcParam { +struct FcParam : ParamBase { lite::Tensor* input{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* bias{nullptr}; @@ -95,9 +106,24 @@ struct FcParam { bool padding_weights{false}; // for int8 WITH_INT8_CONFIG -}; - -struct SearchSeqFcParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new 
std::vector({input})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct SearchSeqFcParam : ParamBase { lite::Tensor* x{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* b{nullptr}; @@ -106,7 +132,7 @@ struct SearchSeqFcParam { }; // For Interpolate Op -struct InterpolateParam { +struct InterpolateParam : ParamBase { lite::Tensor* X{}; lite::Tensor* OutSize{}; lite::Tensor* Out{}; @@ -123,7 +149,7 @@ struct InterpolateParam { }; // For Mul Op -struct MulParam { +struct MulParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; lite::Tensor* output{}; @@ -134,7 +160,7 @@ struct MulParam { WITH_INT8_CONFIG }; -struct MulGradParam { +struct MulGradParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const lite::Tensor* output_grad{}; @@ -146,7 +172,7 @@ struct MulGradParam { }; // For ReduceMean Op -struct ReduceMeanParam { +struct ReduceMeanParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; @@ -155,7 +181,7 @@ struct ReduceMeanParam { }; // For Stack Op -struct StackParam { +struct StackParam : ParamBase { std::vector X; lite::Tensor* Out{}; @@ -163,7 +189,7 @@ struct StackParam { }; // For Power Op -struct PowerParam { +struct PowerParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -172,7 +198,7 @@ struct PowerParam { float power{}; }; -struct ShuffleChannelParam { +struct ShuffleChannelParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -180,7 +206,7 @@ struct ShuffleChannelParam { }; // For Yolobox -struct YoloBoxParam { +struct YoloBoxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ImgSize{}; lite::Tensor* Boxes{}; @@ -193,7 +219,7 @@ struct YoloBoxParam { }; // For Scale Op -struct ScaleParam { +struct ScaleParam : ParamBase { 
lite::Tensor* x{}; lite::Tensor* output{}; @@ -203,14 +229,29 @@ struct ScaleParam { }; // For Softmax op -struct SoftmaxParam { +struct SoftmaxParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int axis{-1}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Reshape and Reshape2 Op -struct ReshapeParam { +struct ReshapeParam : ParamBase { const lite::Tensor* x{}; std::vector shape_tensor_vct{}; const lite::Tensor* shape_tensor{}; @@ -222,7 +263,7 @@ struct ReshapeParam { }; // For Concat op -struct ConcatParam { +struct ConcatParam : ParamBase { std::vector x{}; lite::Tensor* output{}; int axis{0}; @@ -230,7 +271,7 @@ struct ConcatParam { }; /// ----------------------- activation operators ---------------------- -struct ActivationParam { +struct ActivationParam : ParamBase { const lite::Tensor* X{}; float Leaky_relu_alpha{0}; // leaky_relu param float Relu_clipped_coef{6}; // relu_clipped param @@ -245,7 +286,7 @@ struct ActivationParam { lite_api::ActivationType active_type; }; -struct ActivationGradParam { +struct ActivationGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out{}; // for backward @@ -254,7 +295,7 @@ struct ActivationGradParam { }; // For Convolution op -struct ConvParam { +struct ConvParam : ParamBase { lite::Tensor* x{}; lite::Tensor* filter{}; lite::Tensor* bias{nullptr}; @@ -294,10 +335,26 @@ struct ConvParam { std::vector output_size; // for int8 WITH_INT8_CONFIG + + 
/////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For BatchNorm op -struct BatchNormParam { +struct BatchNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* bias{}; lite::Tensor* scale{}; @@ -316,7 +373,7 @@ struct BatchNormParam { }; // For Pooling op -struct PoolParam { +struct PoolParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::string pooling_type{""}; @@ -340,7 +397,7 @@ struct PoolParam { }; // For Dropout op -struct DropoutParam { +struct DropoutParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* mask{}; @@ -352,7 +409,7 @@ struct DropoutParam { }; // For Split op -struct SplitParam { +struct SplitParam : ParamBase { lite::Tensor* x{}; std::vector output{}; lite::Tensor* axis_tensor; @@ -364,7 +421,7 @@ struct SplitParam { }; // For Transpose op -struct TransposeParam { +struct TransposeParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* xshape{}; @@ -375,7 +432,7 @@ struct TransposeParam { }; /// ----------------------- element wise operators ---------------------- -struct ElementwiseParam { +struct ElementwiseParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -384,9 +441,24 @@ struct ElementwiseParam { WITH_INT8_CONFIG float x_input_scale{1.0}; float y_input_scale{1.0}; -}; - -struct ElementwiseGradParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const 
std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct ElementwiseGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; const lite::Tensor* OutGrad{}; @@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam { }; /// ----------------------- mean operators ---------------------- -struct MeanParam { +struct MeanParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct MeanGradParam { +struct MeanGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out_grad{}; // for backward @@ -417,7 +489,7 @@ struct MeanGradParam { }; /// ----------------------- fill_constant operators ---------------------- -struct FillConstantParam { +struct FillConstantParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; std::vector shape{}; lite::Tensor* shape_tensor{nullptr}; @@ -429,7 +501,7 @@ struct FillConstantParam { lite::Tensor* out{}; }; -struct FillConstantBatchSizeLikeParam { +struct FillConstantBatchSizeLikeParam : ParamBase { const lite::Tensor* input{nullptr}; lite::Tensor* out{nullptr}; @@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam { }; // -struct FakeQuantizeMovingAvgMaxAbsParam { +struct FakeQuantizeMovingAvgMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; const lite::Tensor* in_accum{}; @@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam { float moving_rate{0.9}; }; -struct FakeDequantizeMaxAbsParam { +struct FakeDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; lite::Tensor* out{}; float 
max_range; }; -struct FakeChannelWiseDequantizeMaxAbsParam { +struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; std::vector scale_tensors{}; lite::Tensor* out{}; @@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam { }; /// ----------------------- sgd operators ---------------------- -struct SGDParam { +struct SGDParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; const lite::Tensor* Param{}; @@ -482,7 +554,7 @@ struct SGDParam { }; /// ----------------------- uniform_random operators ---------------------- -struct UniformRandomParam { +struct UniformRandomParam : ParamBase { std::vector shape{}; float min{-1.0f}; float max{1.0f}; @@ -491,12 +563,12 @@ struct UniformRandomParam { lite::Tensor* Out{}; }; /// ----------------------- negative operators -------------- -struct NegativeParam { +struct NegativeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- pad2d operators ---------------------- -struct Pad2dParam { +struct Pad2dParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector paddings{0, 0, 0, 0}; @@ -506,7 +578,7 @@ struct Pad2dParam { }; /// ----------------------- Crop operators ---------------------- -struct CropParam { +struct CropParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector offsets; @@ -514,21 +586,21 @@ struct CropParam { }; ///----------------------- argmax operators ---------------------- -struct ArgmaxParam { +struct ArgmaxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; int Axis{0}; }; ///----------------------- axpy operators ---------------------- -struct AxpyParam { +struct AxpyParam : ParamBase { lite::Tensor* Scale{}; lite::Tensor* X{}; lite::Tensor* Bias{}; lite::Tensor* Out{}; }; /// ----------------------- GRU unit operators ----------------------f -struct GRUUnitParam { +struct GRUUnitParam : ParamBase { enum ActType { identity, sigmoid, tanh, relu }; const 
lite::Tensor* input{nullptr}; const lite::Tensor* hidden_prev{nullptr}; @@ -544,7 +616,7 @@ struct GRUUnitParam { }; /// ------------------------------ lrn operators ------------------------------ -struct LrnParam { +struct LrnParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; int n{5}; @@ -555,7 +627,7 @@ struct LrnParam { }; /// ----------------------- decode_bboxes operators ---------------------- -struct DecodeBboxesParam { +struct DecodeBboxesParam : ParamBase { const lite::Tensor* loc_data{}; const lite::Tensor* prior_data{}; lite::Tensor* bbox_data{}; @@ -571,7 +643,7 @@ struct DecodeBboxesParam { }; /// ----------------------- box_coder operators ---------------------- -struct BoxCoderParam { +struct BoxCoderParam : ParamBase { const lite::Tensor* prior_box{}; const lite::Tensor* prior_box_var{}; const lite::Tensor* target_box{}; @@ -584,7 +656,7 @@ struct BoxCoderParam { }; /// ----------------------- multiclass_nms operators ---------------------- -struct MulticlassNmsParam { +struct MulticlassNmsParam : ParamBase { const lite::Tensor* bboxes{}; const lite::Tensor* scores{}; lite::Tensor* out{}; @@ -599,7 +671,7 @@ struct MulticlassNmsParam { }; /// ----------------------- priorbox operators ---------------------- -struct PriorBoxParam { +struct PriorBoxParam : ParamBase { lite::Tensor* input{}; lite::Tensor* image{}; lite::Tensor* boxes{}; @@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam { std::vector density_sizes; }; /// ----------------------- GRU operators ----------------------f -struct GRUParam { +struct GRUParam : ParamBase { const lite::Tensor* input{nullptr}; const lite::Tensor* h0{nullptr}; const lite::Tensor* weight{nullptr}; @@ -645,7 +717,7 @@ struct GRUParam { }; /// ----------------------- BeamSearchDecode operators ----------------------f -struct BeamSearchDecodeParam { +struct BeamSearchDecodeParam : ParamBase { std::vector* ids{nullptr}; std::vector* scores{nullptr}; lite::Tensor* 
sentence_ids{nullptr}; @@ -655,21 +727,21 @@ struct BeamSearchDecodeParam { }; /// ----------------------- LookupTable operators ----------------------f -struct LookupTableParam { +struct LookupTableParam : ParamBase { const lite::Tensor* W{nullptr}; const lite::Tensor* Ids{nullptr}; lite::Tensor* Out{nullptr}; int64_t padding_idx{-1}; }; -struct LookupTableDequantParam { +struct LookupTableDequantParam : ParamBase { lite::Tensor* W{nullptr}; lite::Tensor* Ids{nullptr}; lite::Tensor* Out{nullptr}; int64_t padding_idx{-1}; }; -struct Im2SequenceParam { +struct Im2SequenceParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -679,19 +751,19 @@ struct Im2SequenceParam { std::vector out_strides{1, 1}; }; -struct SequenceSoftmaxParam { +struct SequenceSoftmaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct NormParam { +struct NormParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Norm{}; int axis{1}; float epsilon{1e-10}; }; -struct LayerNormParam { +struct LayerNormParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -702,13 +774,13 @@ struct LayerNormParam { float epsilon{1e-5}; }; -struct LogicalParam { +struct LogicalParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; }; -struct CompareParam { +struct CompareParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; bool force_cpu{0}; @@ -716,7 +788,7 @@ struct CompareParam { lite::Tensor* Out{}; }; -struct WhileParam { +struct WhileParam : ParamBase { Scope* scope{}; Tensor* cond{}; cpp::BlockDesc* sub_block{}; @@ -724,32 +796,32 @@ struct WhileParam { std::vector outs{}; }; -struct TopkParam { +struct TopkParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Indices{}; int K{1}; }; -struct IncrementParam { +struct IncrementParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; float 
step{1}; }; -struct WriteToArrayParam { +struct WriteToArrayParam : ParamBase { const lite::Tensor* X{nullptr}; const lite::Tensor* I{nullptr}; std::vector* Out{nullptr}; }; -struct ReadFromArrayParam { +struct ReadFromArrayParam : ParamBase { const std::vector* X{nullptr}; const lite::Tensor* I{nullptr}; lite::Tensor* Out{nullptr}; }; -struct BeamSearchParam { +struct BeamSearchParam : ParamBase { const lite::Tensor* pre_ids{}; const lite::Tensor* pre_scores{}; const lite::Tensor* ids{}; @@ -763,7 +835,7 @@ struct BeamSearchParam { bool is_accumulated; }; -struct SequencePoolParam { +struct SequencePoolParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::string pool_type{"AVERAGE"}; @@ -773,7 +845,7 @@ struct SequencePoolParam { #endif }; -struct SequenceConvParam { +struct SequenceConvParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Filter{}; lite::Tensor* Out{}; @@ -782,13 +854,13 @@ struct SequenceConvParam { int contextLength; }; -struct SequencePoolConcatParam { +struct SequencePoolConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; std::vector pool_type{}; }; -struct SearchGroupPaddingParam { +struct SearchGroupPaddingParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out_emb_padding{}; lite::Tensor* out_new{}; @@ -796,36 +868,36 @@ struct SearchGroupPaddingParam { int pad_id; }; -struct SequenceReshapeParam { +struct SequenceReshapeParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int new_dim; }; -struct SequenceExpandParam { +struct SequenceExpandParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; int ref_level{-1}; }; -struct SequenceExpandAsParam { +struct SequenceExpandAsParam : ParamBase { const lite::Tensor* x{nullptr}; const lite::Tensor* y{nullptr}; lite::Tensor* out{nullptr}; }; -struct SequenceReverseParam { +struct SequenceReverseParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct SequenceConcatParam { +struct 
SequenceConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; }; -struct AttentionPaddingMaskParam { +struct AttentionPaddingMaskParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int pad_id; @@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam { lite::Tensor* pad_begin{}; }; -struct SequenceArithmeticParam { +struct SequenceArithmeticParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int op_type{1}; lite::Tensor* Out{}; }; -struct ReduceMaxParam { +struct ReduceMaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector dim{}; bool keep_dim{false}; }; -struct LodResetParam { +struct LodResetParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -856,12 +928,12 @@ struct LodResetParam { bool append; }; -struct IsEmptyParam { +struct IsEmptyParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct ReduceParam { +struct ReduceParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::vector dim{0}; @@ -869,7 +941,7 @@ struct ReduceParam { bool reduce_all{false}; }; -struct VarConv2DParam { +struct VarConv2DParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -888,19 +960,19 @@ struct VarConv2DParam { }; /// ----------------------- shape operators ---------------------- -struct ShapeParam { +struct ShapeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct CastParam { +struct CastParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; int out_dtype{2}; int in_dtype{2}; }; -struct SliceParam { +struct SliceParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector axes{}; @@ -914,7 +986,7 @@ struct SliceParam { lite::Tensor* EndsTensor{nullptr}; }; -struct AffineChannelParam { +struct AffineChannelParam : ParamBase { const lite::Tensor* X{}; // X is 4D tensor const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -922,7 +994,7 
@@ struct AffineChannelParam { lite::Tensor* Out{}; }; -struct AnchorGeneratorParam { +struct AnchorGeneratorParam : ParamBase { const lite::Tensor* Input{}; std::vector anchor_sizes{}; std::vector aspect_ratios{}; @@ -934,7 +1006,7 @@ struct AnchorGeneratorParam { lite::Tensor* Variances{}; }; -struct GenerateProposalsParam { +struct GenerateProposalsParam : ParamBase { // inputs const lite::Tensor* Scores{}; const lite::Tensor* BboxDeltas{}; @@ -954,14 +1026,14 @@ struct GenerateProposalsParam { lite::Tensor* RpnRoiProbs{}; }; /// ----------------------- squeeze operators ---------------------- -struct SqueezeParam { +struct SqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; std::vector axes{}; }; -struct UnsqueezeParam { +struct UnsqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; @@ -971,14 +1043,14 @@ struct UnsqueezeParam { }; /// ----------------------- expand operators ---------------------- -struct ExpandParam { +struct ExpandParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector expand_times{}; }; /// ----------------------- matmul operators ---------------------- -struct MatMulParam { +struct MatMulParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -987,20 +1059,20 @@ struct MatMulParam { float alpha{1.0f}; }; -struct GatherParam { +struct GatherParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Index{}; lite::Tensor* Out{}; }; /// ----------------------- assign operators ----------------------- -struct AssignParam { +struct AssignParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- roi_align operators ----------------------- -struct RoiAlignParam { +struct RoiAlignParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ROIs{}; lite::Tensor* Out{}; @@ -1011,13 +1083,13 @@ struct RoiAlignParam { }; /// ----------------------- box_clip operators 
----------------------- -struct BoxClipParam { +struct BoxClipParam : ParamBase { const lite::Tensor* Input{}; const lite::Tensor* ImInfo{}; lite::Tensor* Output{}; }; -struct RangeParam { +struct RangeParam : ParamBase { const lite::Tensor* Start; const lite::Tensor* End; const lite::Tensor* Step; @@ -1025,7 +1097,7 @@ struct RangeParam { }; /// ----------------------- assign_value operators ----------------------- -struct AssignValueParam { +struct AssignValueParam : ParamBase { std::vector shape{}; int dtype{}; std::vector fp32_values{}; @@ -1034,7 +1106,7 @@ struct AssignValueParam { }; /// --------------- sequence_topk_avg_pooling operators ------------------ -struct SequenceTopkAvgPoolingParam { +struct SequenceTopkAvgPoolingParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam { }; /// --------------- search_fc operators ------------------ -struct SearchFcParam { +struct SearchFcParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* W{}; const lite::Tensor* b{}; @@ -1053,7 +1125,7 @@ struct SearchFcParam { int out_size{}; }; /// --------------------- match_matrix_tensor operators -------------------- -struct MatchMatrixTensorParam { +struct MatchMatrixTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const lite::Tensor* w{}; @@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam { }; /// --------------------- search_seq_depadding operators -------------------- -struct SearchSeqDepaddingParam { +struct SearchSeqDepaddingParam : ParamBase { const lite::Tensor* pad{}; const lite::Tensor* src{}; lite::Tensor* out{}; }; /// --------------------- search_grnn operators -------------------- -struct SearchGrnnParam { +struct SearchGrnnParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* wi{}; const lite::Tensor* wh{}; @@ -1084,7 +1156,7 @@ struct SearchGrnnParam { lite::Tensor* layout_input{}; }; -struct 
SplitLodTensorParam { +struct SplitLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; lite::Tensor* out_true{}; @@ -1092,7 +1164,7 @@ struct SplitLodTensorParam { int level{}; }; -struct MergeLodTensorParam { +struct MergeLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; const lite::Tensor* in_true{}; @@ -1101,7 +1173,7 @@ struct MergeLodTensorParam { int level{}; }; -struct ConditionalBlockParam { +struct ConditionalBlockParam : ParamBase { const lite::Tensor* cond{}; std::vector x{}; std::vector outs{}; @@ -1110,14 +1182,14 @@ struct ConditionalBlockParam { bool is_scalar_condition{}; }; -struct CollectFpnProposalsParam { +struct CollectFpnProposalsParam : ParamBase { std::vector multi_level_rois{}; std::vector multi_level_scores{}; lite::Tensor* fpn_rois{}; int post_nms_topN{}; }; -struct DistributeFpnProposalsParam { +struct DistributeFpnProposalsParam : ParamBase { const lite::Tensor* fpn_rois{}; std::vector multi_fpn_rois{}; lite::Tensor* restore_index{}; @@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam { }; /// --------------------- instance_norm operators -------------------- -struct InstanceNormParam { +struct InstanceNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* bias{}; @@ -1138,12 +1210,12 @@ struct InstanceNormParam { float epsilon; }; /// --------------------- grid sampler operators -------------------- -struct GridSamplerParam { +struct GridSamplerParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* grid{}; }; -struct LstmParam { +struct LstmParam : ParamBase { lite::Tensor* Input{}; lite::Tensor* Weight{}; lite::Tensor* Bias{}; @@ -1160,7 +1232,7 @@ struct LstmParam { std::string candidate_activation; }; -struct CrfDecodingParam { +struct CrfDecodingParam : ParamBase { lite::Tensor* emission{}; lite::Tensor* transition{}; lite::Tensor* label{}; diff --git a/lite/operators/pad2d_op.cc b/lite/operators/pad2d_op.cc index 
ff522b94b9..7af657c888 100644 --- a/lite/operators/pad2d_op.cc +++ b/lite/operators/pad2d_op.cc @@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const { return true; } -bool Pad2dOpLite::InferShape() const { +bool Pad2dOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); int out_h = x_dims[2] + param_.paddings[0] + param_.paddings[1]; diff --git a/lite/operators/pad2d_op.h b/lite/operators/pad2d_op.h index c51a76a7ae..c6d2e56548 100644 --- a/lite/operators/pad2d_op.h +++ b/lite/operators/pad2d_op.h @@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/pool_op.cc b/lite/operators/pool_op.cc index c6f6eed28f..5fb990928e 100644 --- a/lite/operators/pool_op.cc +++ b/lite/operators/pool_op.cc @@ -60,7 +60,7 @@ int PoolOutputSize(int input_size, return output_size; } -bool PoolOpLite::InferShape() const { +bool PoolOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); std::vector& ksize = param_.ksize; // dynamic update 4-pad diff --git a/lite/operators/pool_op.h b/lite/operators/pool_op.h index c44875ff95..3fcf37e634 100644 --- a/lite/operators/pool_op.h +++ b/lite/operators/pool_op.h @@ -37,7 +37,7 @@ class PoolOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; // TODO(Superjomn) replace framework::OpDesc with a lite one. 
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { diff --git a/lite/operators/power_op.cc b/lite/operators/power_op.cc index 578d95ad53..83c9edfaca 100644 --- a/lite/operators/power_op.cc +++ b/lite/operators/power_op.cc @@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const { return true; } -bool PowerOp::InferShape() const { +bool PowerOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/power_op.h b/lite/operators/power_op.h index a6d43f4394..e89dfa7b8f 100644 --- a/lite/operators/power_op.h +++ b/lite/operators/power_op.h @@ -31,7 +31,7 @@ class PowerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index c4717c8185..f1b715a46e 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const { return true; } -bool PriorBoxOpLite::InferShape() const { return true; } +bool PriorBoxOpLite::InferShapeImpl() const { return true; } bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { auto input = opdesc.Input("Input").front(); diff --git a/lite/operators/prior_box_op.h b/lite/operators/prior_box_op.h index a393e80315..1348b7cc73 100644 --- a/lite/operators/prior_box_op.h +++ b/lite/operators/prior_box_op.h @@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/range_op.cc b/lite/operators/range_op.cc index a179d8ffe7..19f474ba43 100644 --- a/lite/operators/range_op.cc +++ b/lite/operators/range_op.cc @@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, 
int64_t* size) { : std::ceil(std::abs((end - start) / step)); } -bool RangeOpLite::InferShape() const { +bool RangeOpLite::InferShapeImpl() const { int start = param_.Start->data()[0]; int end = param_.End->data()[0]; int step = param_.Step->data()[0]; diff --git a/lite/operators/range_op.h b/lite/operators/range_op.h index a1c7d4d4cc..982ef5abf2 100644 --- a/lite/operators/range_op.h +++ b/lite/operators/range_op.h @@ -29,7 +29,7 @@ class RangeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/read_from_array_op.cc b/lite/operators/read_from_array_op.cc index 930eff1ff5..495fd752c9 100644 --- a/lite/operators/read_from_array_op.cc +++ b/lite/operators/read_from_array_op.cc @@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const { return true; } -bool ReadFromArrayOp::InferShape() const { +bool ReadFromArrayOp::InferShapeImpl() const { int id = param_.I->data()[0]; auto out_dims = (*param_.X)[id].dims(); param_.Out->Resize(out_dims); diff --git a/lite/operators/read_from_array_op.h b/lite/operators/read_from_array_op.h index 5c7ba1468f..299a3abaed 100644 --- a/lite/operators/read_from_array_op.h +++ b/lite/operators/read_from_array_op.h @@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_max_op.cc b/lite/operators/reduce_max_op.cc index d7d90ee1f4..ba48acd11f 100644 --- a/lite/operators/reduce_max_op.cc +++ b/lite/operators/reduce_max_op.cc @@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const { return true; } -bool ReduceMaxOp::InferShape() const { +bool ReduceMaxOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool 
reduce_all = false; diff --git a/lite/operators/reduce_max_op.h b/lite/operators/reduce_max_op.h index 60e263f1b9..54b136a757 100644 --- a/lite/operators/reduce_max_op.h +++ b/lite/operators/reduce_max_op.h @@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite { ReduceMaxOp() {} explicit ReduceMaxOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_mean_op.cc b/lite/operators/reduce_mean_op.cc index bce31c315c..c5baca5e87 100644 --- a/lite/operators/reduce_mean_op.cc +++ b/lite/operators/reduce_mean_op.cc @@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const { return true; } -bool ReduceMeanOp::InferShape() const { +bool ReduceMeanOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool reduce_all = false; diff --git a/lite/operators/reduce_mean_op.h b/lite/operators/reduce_mean_op.h index e701a1132a..43fe955690 100644 --- a/lite/operators/reduce_mean_op.h +++ b/lite/operators/reduce_mean_op.h @@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite { ReduceMeanOp() {} explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc index e2cc56b416..1af6daf8c7 100644 --- a/lite/operators/reduce_ops.cc +++ b/lite/operators/reduce_ops.cc @@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const { return true; } -bool ReduceOp::InferShape() const { +bool ReduceOp::InferShapeImpl() const { const auto &x_dims = 
param_.x->dims(); auto x_rank = x_dims.size(); auto dims = param_.dim; diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h index 0063aba1fa..d4fdbd1135 100644 --- a/lite/operators/reduce_ops.h +++ b/lite/operators/reduce_ops.h @@ -30,7 +30,7 @@ class ReduceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_prod_op.cc b/lite/operators/reduce_prod_op.cc index 90da13c864..5a6194b36b 100644 --- a/lite/operators/reduce_prod_op.cc +++ b/lite/operators/reduce_prod_op.cc @@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const { return true; } -bool ReduceProdOpLite::InferShape() const { +bool ReduceProdOpLite::InferShapeImpl() const { auto x = param_.x; auto out = param_.output; std::vector dim = param_.dim; diff --git a/lite/operators/reduce_prod_op.h b/lite/operators/reduce_prod_op.h index 5f7a6dcdf9..d8bb1400b9 100644 --- a/lite/operators/reduce_prod_op.h +++ b/lite/operators/reduce_prod_op.h @@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/relu_op.cc b/lite/operators/relu_op.cc index 9fa3ac8f30..e5f51676c6 100644 --- a/lite/operators/relu_op.cc +++ b/lite/operators/relu_op.cc @@ -20,7 +20,7 @@ namespace lite { namespace operators { bool ReluOp::CheckShape() const { return true; } -bool ReluOp::InferShape() const { +bool ReluOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
diff --git a/lite/operators/relu_op.h b/lite/operators/relu_op.h index 23ca7ff16b..7577f2ffba 100644 --- a/lite/operators/relu_op.h +++ b/lite/operators/relu_op.h @@ -30,7 +30,7 @@ class ReluOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc index 655ac58bdc..5c55eb4aa5 100644 --- a/lite/operators/reshape_op.cc +++ b/lite/operators/reshape_op.cc @@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const { return true; } -bool ReshapeOp::InferShape() const { +bool ReshapeOp::InferShapeImpl() const { const auto &shape_tensor_vct = param_.shape_tensor_vct; auto *shape_tensor = param_.shape_tensor; const auto &shape_vct = param_.shape_vct; @@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const { return true; } -bool Reshape2Op::InferShape() const { - ReshapeOp::InferShape(); +bool Reshape2Op::InferShapeImpl() const { + ReshapeOp::InferShapeImpl(); const auto &x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; diff --git a/lite/operators/reshape_op.h b/lite/operators/reshape_op.h index 1df49fb5f4..9dc302ec97 100644 --- a/lite/operators/reshape_op.h +++ b/lite/operators/reshape_op.h @@ -30,7 +30,7 @@ class ReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/roi_align_op.cc b/lite/operators/roi_align_op.cc index 2f65c0197e..001934dcf8 100644 --- a/lite/operators/roi_align_op.cc +++ 
b/lite/operators/roi_align_op.cc @@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const { return true; } -bool RoiAlignOpLite::InferShape() const { +bool RoiAlignOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); auto rois_dims = param_.ROIs->dims(); diff --git a/lite/operators/roi_align_op.h b/lite/operators/roi_align_op.h index f3dd1a47f5..65cc72534a 100644 --- a/lite/operators/roi_align_op.h +++ b/lite/operators/roi_align_op.h @@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/scale_op.cc b/lite/operators/scale_op.cc index 1398ea4811..3236277187 100644 --- a/lite/operators/scale_op.cc +++ b/lite/operators/scale_op.cc @@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const { return true; } -bool ScaleOp::InferShape() const { +bool ScaleOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/scale_op.h b/lite/operators/scale_op.h index 684da4ed47..38970bfcfd 100644 --- a/lite/operators/scale_op.h +++ b/lite/operators/scale_op.h @@ -30,7 +30,7 @@ class ScaleOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_aligned_mat_mul_op.cc b/lite/operators/search_aligned_mat_mul_op.cc index 43a276e3c7..65ccbc2b79 100644 --- a/lite/operators/search_aligned_mat_mul_op.cc +++ b/lite/operators/search_aligned_mat_mul_op.cc @@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const { return true; } -bool SearchAlignedMatMulOpLite::InferShape() const { +bool SearchAlignedMatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); const auto& x_lod = 
param_.X->lod(); diff --git a/lite/operators/search_aligned_mat_mul_op.h b/lite/operators/search_aligned_mat_mul_op.h index 7321b7e9d1..8242e06d01 100644 --- a/lite/operators/search_aligned_mat_mul_op.h +++ b/lite/operators/search_aligned_mat_mul_op.h @@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_fc_op.cc b/lite/operators/search_fc_op.cc index 2e77e36162..3c64f24e48 100644 --- a/lite/operators/search_fc_op.cc +++ b/lite/operators/search_fc_op.cc @@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const { return true; } -bool SearchFcOpLite::InferShape() const { +bool SearchFcOpLite::InferShapeImpl() const { auto out_size = param_.out_size; lite::DDim dims(std::vector({-1, out_size})); param_.Out->Resize(dims); diff --git a/lite/operators/search_fc_op.h b/lite/operators/search_fc_op.h index a871cadd33..235c24c57f 100644 --- a/lite/operators/search_fc_op.h +++ b/lite/operators/search_fc_op.h @@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_grnn_op.cc b/lite/operators/search_grnn_op.cc index b56ae820bf..1ced477c10 100644 --- a/lite/operators/search_grnn_op.cc +++ b/lite/operators/search_grnn_op.cc @@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const { return true; } -bool SearchGrnnOpLite::InferShape() const { +bool SearchGrnnOpLite::InferShapeImpl() const { const auto& x_dims = param_.x->dims(); const auto& x_lod = param_.x->lod(); CHECK_OR_FALSE(!x_lod.empty()); diff --git a/lite/operators/search_grnn_op.h b/lite/operators/search_grnn_op.h index 670af8a6c9..de4b1d8a5c 100644 --- 
a/lite/operators/search_grnn_op.h +++ b/lite/operators/search_grnn_op.h @@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_group_padding_op.cc b/lite/operators/search_group_padding_op.cc index 5ba4dde275..b97c710109 100644 --- a/lite/operators/search_group_padding_op.cc +++ b/lite/operators/search_group_padding_op.cc @@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const { return true; } -bool SearchGroupPaddingOp::InferShape() const { +bool SearchGroupPaddingOp::InferShapeImpl() const { std::vector x_dims = param_.x->dims().Vectorize(); param_.out_emb_padding->Resize({-1, x_dims[1]}); diff --git a/lite/operators/search_group_padding_op.h b/lite/operators/search_group_padding_op.h index a8e96c9697..6a93c74101 100644 --- a/lite/operators/search_group_padding_op.h +++ b/lite/operators/search_group_padding_op.h @@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite { SearchGroupPaddingOp() {} explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "search_group_padding"; } diff --git a/lite/operators/search_seq_depadding_op.cc b/lite/operators/search_seq_depadding_op.cc index 12d5123e05..6ad4f1ab17 100644 --- a/lite/operators/search_seq_depadding_op.cc +++ b/lite/operators/search_seq_depadding_op.cc @@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const { return true; } -bool SearchSeqDepaddingOpLite::InferShape() const { +bool SearchSeqDepaddingOpLite::InferShapeImpl() const { DDim pad_dims 
= param_.pad->dims(); param_.out->Resize({-1, pad_dims[1]}); return true; diff --git a/lite/operators/search_seq_depadding_op.h b/lite/operators/search_seq_depadding_op.h index 445d9e0f3b..aa1cc22d4b 100644 --- a/lite/operators/search_seq_depadding_op.h +++ b/lite/operators/search_seq_depadding_op.h @@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_seq_fc_op.cc b/lite/operators/search_seq_fc_op.cc index c5cca5331a..2a4525ac6e 100644 --- a/lite/operators/search_seq_fc_op.cc +++ b/lite/operators/search_seq_fc_op.cc @@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const { return true; } -bool SearchSeqFcOpLite::InferShape() const { +bool SearchSeqFcOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto w_dims = param_.w->dims(); const auto& x_lod = param_.x->lod(); diff --git a/lite/operators/search_seq_fc_op.h b/lite/operators/search_seq_fc_op.h index 3c4f7d82bf..bacafcfe6f 100644 --- a/lite/operators/search_seq_fc_op.h +++ b/lite/operators/search_seq_fc_op.h @@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_seq_softmax_op.cc b/lite/operators/search_seq_softmax_op.cc index 973ffa04c4..9b0550341c 100644 --- a/lite/operators/search_seq_softmax_op.cc +++ b/lite/operators/search_seq_softmax_op.cc @@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const { return true; } -bool SearchSeqSoftmaxOp::InferShape() const { +bool SearchSeqSoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); param_.output->set_lod(param_.x->lod()); return true; diff --git 
a/lite/operators/search_seq_softmax_op.h b/lite/operators/search_seq_softmax_op.h index f97e8ddd3a..dca3619eab 100644 --- a/lite/operators/search_seq_softmax_op.h +++ b/lite/operators/search_seq_softmax_op.h @@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_arithmetic_op.cc b/lite/operators/sequence_arithmetic_op.cc index 29c39ebc23..e17a179a86 100644 --- a/lite/operators/sequence_arithmetic_op.cc +++ b/lite/operators/sequence_arithmetic_op.cc @@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const { return true; } -bool SequenceArithmeticOp::InferShape() const { +bool SequenceArithmeticOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); param_.Out->set_lod(param_.X->lod()); return true; diff --git a/lite/operators/sequence_arithmetic_op.h b/lite/operators/sequence_arithmetic_op.h index 9f844dfbf4..cf9ef1583a 100644 --- a/lite/operators/sequence_arithmetic_op.h +++ b/lite/operators/sequence_arithmetic_op.h @@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_concat_op.cc b/lite/operators/sequence_concat_op.cc index 88afe5e00f..91c70c0d2f 100644 --- a/lite/operators/sequence_concat_op.cc +++ b/lite/operators/sequence_concat_op.cc @@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const { return true; } -bool SequenceConcatOp::InferShape() const { return true; } +bool SequenceConcatOp::InferShapeImpl() const { return true; } bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { diff --git a/lite/operators/sequence_concat_op.h 
b/lite/operators/sequence_concat_op.h index 8cdc07ebca..c7d61db785 100644 --- a/lite/operators/sequence_concat_op.h +++ b/lite/operators/sequence_concat_op.h @@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite { SequenceConcatOp() {} explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_concat"; } diff --git a/lite/operators/sequence_conv_op.cc b/lite/operators/sequence_conv_op.cc index 89596a22c6..681e05c9b6 100644 --- a/lite/operators/sequence_conv_op.cc +++ b/lite/operators/sequence_conv_op.cc @@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const { return true; } -bool SequenceConvOp::InferShape() const { +bool SequenceConvOp::InferShapeImpl() const { const auto *input = param_.X; const auto *filter = param_.Filter; auto in_dims = input->dims(); diff --git a/lite/operators/sequence_conv_op.h b/lite/operators/sequence_conv_op.h index 34d65d3cc9..3ec7ac4d3d 100644 --- a/lite/operators/sequence_conv_op.h +++ b/lite/operators/sequence_conv_op.h @@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite { SequenceConvOp() {} explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_expand_as_op.cc b/lite/operators/sequence_expand_as_op.cc index 22a4743103..02c787b5a5 100644 --- a/lite/operators/sequence_expand_as_op.cc +++ b/lite/operators/sequence_expand_as_op.cc @@ -34,7 +34,7 @@ bool 
SequenceExpandAsOpLite::CheckShape() const { return true; } -bool SequenceExpandAsOpLite::InferShape() const { +bool SequenceExpandAsOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto y_lod = param_.y->lod(); auto out_dims = x_dims; diff --git a/lite/operators/sequence_expand_as_op.h b/lite/operators/sequence_expand_as_op.h index 2eae8a26da..19d6905c1a 100644 --- a/lite/operators/sequence_expand_as_op.h +++ b/lite/operators/sequence_expand_as_op.h @@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_expand_op.cc b/lite/operators/sequence_expand_op.cc index 0a5427a62f..4bb3c66b26 100644 --- a/lite/operators/sequence_expand_op.cc +++ b/lite/operators/sequence_expand_op.cc @@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const { return true; } -bool SequenceExpandOp::InferShape() const { +bool SequenceExpandOp::InferShapeImpl() const { const auto x_lod = param_.X->lod(); auto x_dims = param_.X->dims(); int ref_level = param_.ref_level; diff --git a/lite/operators/sequence_expand_op.h b/lite/operators/sequence_expand_op.h index da4b2fe71e..fffe2110d8 100644 --- a/lite/operators/sequence_expand_op.h +++ b/lite/operators/sequence_expand_op.h @@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_pool_concat_op.cc b/lite/operators/sequence_pool_concat_op.cc index 9ee0d4d596..ce490e8246 100644 --- a/lite/operators/sequence_pool_concat_op.cc +++ b/lite/operators/sequence_pool_concat_op.cc @@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const { return true; } -bool 
SequencePoolConcatOp::InferShape() const { +bool SequencePoolConcatOp::InferShapeImpl() const { int out_dim = 0; for (int i = 0; i < param_.X.size(); ++i) { out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size()); diff --git a/lite/operators/sequence_pool_concat_op.h b/lite/operators/sequence_pool_concat_op.h index 7a70ceaf29..58e6fc18ba 100644 --- a/lite/operators/sequence_pool_concat_op.h +++ b/lite/operators/sequence_pool_concat_op.h @@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite { SequencePoolConcatOp() {} explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_pool_op.cc b/lite/operators/sequence_pool_op.cc index be3726ffe7..6b4f7d8b78 100644 --- a/lite/operators/sequence_pool_op.cc +++ b/lite/operators/sequence_pool_op.cc @@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const { return true; } -bool SequencePoolOp::InferShape() const { +bool SequencePoolOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); out_dims[0] = input->lod()[0].size() - 1; diff --git a/lite/operators/sequence_pool_op.h b/lite/operators/sequence_pool_op.h index 215dd113a3..7b9e36bb5e 100644 --- a/lite/operators/sequence_pool_op.h +++ b/lite/operators/sequence_pool_op.h @@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite { SequencePoolOp() {} explicit SequencePoolOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git 
a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc index c7e86af650..37ebd8a2ba 100644 --- a/lite/operators/sequence_reshape_op.cc +++ b/lite/operators/sequence_reshape_op.cc @@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const { return true; } -bool SequenceReshapeOp::InferShape() const { +bool SequenceReshapeOp::InferShapeImpl() const { int new_dim = param_.new_dim; auto x_numel = param_.x->dims().production(); std::vector out_shape{x_numel / new_dim, diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h index c8378aebc4..4ef395bdaa 100644 --- a/lite/operators/sequence_reshape_op.h +++ b/lite/operators/sequence_reshape_op.h @@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_reverse_op.cc b/lite/operators/sequence_reverse_op.cc index dd8fa2e8fd..19a47cac9d 100644 --- a/lite/operators/sequence_reverse_op.cc +++ b/lite/operators/sequence_reverse_op.cc @@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const { return true; } -bool SequenceReverseOp::InferShape() const { +bool SequenceReverseOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); param_.Out->Resize(out_dims); diff --git a/lite/operators/sequence_reverse_op.h b/lite/operators/sequence_reverse_op.h index 326d0f6892..68d9fdb0f1 100644 --- a/lite/operators/sequence_reverse_op.h +++ b/lite/operators/sequence_reverse_op.h @@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite { SequenceReverseOp() {} explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void 
AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_reverse"; } diff --git a/lite/operators/sequence_softmax_op.cc b/lite/operators/sequence_softmax_op.cc index d106097ed5..eb1821129d 100644 --- a/lite/operators/sequence_softmax_op.cc +++ b/lite/operators/sequence_softmax_op.cc @@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool SequenceSoftmaxOp::InferShape() const { +bool SequenceSoftmaxOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/sequence_softmax_op.h b/lite/operators/sequence_softmax_op.h index 37dfc0d444..5942cb0441 100644 --- a/lite/operators/sequence_softmax_op.h +++ b/lite/operators/sequence_softmax_op.h @@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_topk_avg_pooling_op.cc b/lite/operators/sequence_topk_avg_pooling_op.cc index 6f5cbeeeee..cb6f12c4b3 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.cc +++ b/lite/operators/sequence_topk_avg_pooling_op.cc @@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const { return true; } -bool SequenceTopkAvgPoolingOpLite::InferShape() const { +bool SequenceTopkAvgPoolingOpLite::InferShapeImpl() const { int channel_num = param_.channel_num; std::vector topks = param_.topks; auto row_dim = param_.ROW->dims(); diff --git a/lite/operators/sequence_topk_avg_pooling_op.h b/lite/operators/sequence_topk_avg_pooling_op.h index 1c1cfe3a9c..a619edc908 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.h +++ b/lite/operators/sequence_topk_avg_pooling_op.h @@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite 
{ bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sgd_op.cc b/lite/operators/sgd_op.cc index 6214542595..eb8cb6b724 100644 --- a/lite/operators/sgd_op.cc +++ b/lite/operators/sgd_op.cc @@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const { return true; } -bool SGDOpLite::InferShape() const { +bool SGDOpLite::InferShapeImpl() const { param_.ParamOut->Resize(param_.Param->dims()); return true; } diff --git a/lite/operators/sgd_op.h b/lite/operators/sgd_op.h index 9159bf95a6..6a29c8bfa6 100644 --- a/lite/operators/sgd_op.h +++ b/lite/operators/sgd_op.h @@ -33,7 +33,7 @@ class SGDOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/shape_op.cc b/lite/operators/shape_op.cc index c6d5dc4d01..1661a90926 100644 --- a/lite/operators/shape_op.cc +++ b/lite/operators/shape_op.cc @@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const { return true; } -bool ShapeOpLite::InferShape() const { +bool ShapeOpLite::InferShapeImpl() const { std::vector shape_vec; shape_vec.push_back(static_cast(param_.X->dims().size())); param_.Out->Resize(shape_vec); diff --git a/lite/operators/shape_op.h b/lite/operators/shape_op.h index ada9961c75..6512b8ac02 100644 --- a/lite/operators/shape_op.h +++ b/lite/operators/shape_op.h @@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/shuffle_channel_op.cc b/lite/operators/shuffle_channel_op.cc index 926aa932f3..d45643a3d8 100644 --- a/lite/operators/shuffle_channel_op.cc +++ 
b/lite/operators/shuffle_channel_op.cc @@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const { return true; } -bool ShuffleChannelOpLite::InferShape() const { +bool ShuffleChannelOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/shuffle_channel_op.h b/lite/operators/shuffle_channel_op.h index c48a47f619..7683458981 100644 --- a/lite/operators/shuffle_channel_op.h +++ b/lite/operators/shuffle_channel_op.h @@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/slice_op.cc b/lite/operators/slice_op.cc index bbc3d1429e..cf7d94535c 100644 --- a/lite/operators/slice_op.cc +++ b/lite/operators/slice_op.cc @@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const { return true; } -bool SliceOp::InferShape() const { +bool SliceOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto in_dims = param_.X->dims(); diff --git a/lite/operators/slice_op.h b/lite/operators/slice_op.h index 936a1405f4..ec69f23d8d 100644 --- a/lite/operators/slice_op.h +++ b/lite/operators/slice_op.h @@ -30,7 +30,7 @@ class SliceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index 0989c91397..000953007c 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const { return true; } -bool SoftmaxOp::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (param_.x->dims() == last_input_shapes[0] && - param_.x->lod() == last_input_lods[0]) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - return true; -} - -bool SoftmaxOp::InferShape() const { +bool SoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x->lod(); diff --git a/lite/operators/softmax_op.h b/lite/operators/softmax_op.h index c65d039fda..20dc2f461e 100644 --- a/lite/operators/softmax_op.h +++ b/lite/operators/softmax_op.h @@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const 
override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_lod_tensor_op.cc b/lite/operators/split_lod_tensor_op.cc index 9b665b6026..2900c8165d 100644 --- a/lite/operators/split_lod_tensor_op.cc +++ b/lite/operators/split_lod_tensor_op.cc @@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const { return true; } -bool SplitLodTensorOpLite::InferShape() const { +bool SplitLodTensorOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); param_.out_true->Resize(x_dims); param_.out_false->Resize(x_dims); diff --git a/lite/operators/split_lod_tensor_op.h b/lite/operators/split_lod_tensor_op.h index c7feef4f85..fb7f85de5c 100644 --- a/lite/operators/split_lod_tensor_op.h +++ b/lite/operators/split_lod_tensor_op.h @@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 834d68a315..71deb5631d 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const { return true; } -bool SplitOp::InferShape() const { +bool SplitOp::InferShapeImpl() const { const auto &outs = param_.output; auto in_dims = param_.x->dims(); int axis = param_.axis; diff --git a/lite/operators/split_op.h b/lite/operators/split_op.h index 6619074215..3bb40a8d35 100644 --- a/lite/operators/split_op.h +++ b/lite/operators/split_op.h @@ -30,7 +30,7 @@ class SplitOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index 01f96c28ff..633a6b4d4e 100644 --- a/lite/operators/squeeze_op.cc +++ 
b/lite/operators/squeeze_op.cc @@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const { return true; } -bool SqueezeOp::InferShape() const { +bool SqueezeOp::InferShapeImpl() const { std::vector squeeze_dims = param_.axes; DDim in_dims = param_.X->dims(); DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true); @@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const { return true; } -bool Squeeze2Op::InferShape() const { - SqueezeOp::InferShape(); +bool Squeeze2Op::InferShapeImpl() const { + SqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/squeeze_op.h b/lite/operators/squeeze_op.h index 1a550c5fbe..983e17acf6 100644 --- a/lite/operators/squeeze_op.h +++ b/lite/operators/squeeze_op.h @@ -30,7 +30,7 @@ class SqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/stack_op.cc b/lite/operators/stack_op.cc index 8fdf61e822..0f9ba6662b 100644 --- a/lite/operators/stack_op.cc +++ b/lite/operators/stack_op.cc @@ -32,7 +32,7 @@ bool StackOp::CheckShape() const { return true; } -bool StackOp::InferShape() const { +bool StackOp::InferShapeImpl() const { auto input = param_.X; auto input_dims = input[0]->dims(); int axis = param_.axis; diff --git a/lite/operators/stack_op.h b/lite/operators/stack_op.h index 068d905338..9ce73057a3 100644 --- a/lite/operators/stack_op.h +++ b/lite/operators/stack_op.h @@ -31,7 +31,7 @@ class StackOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool 
InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/subgraph_op.cc b/lite/operators/subgraph_op.cc index 58388669af..9ac07e9633 100644 --- a/lite/operators/subgraph_op.cc +++ b/lite/operators/subgraph_op.cc @@ -22,7 +22,7 @@ namespace operators { bool SubgraphOp::CheckShape() const { return true; } -bool SubgraphOp::InferShape() const { return CheckShape(); /* enrich me */ } +bool SubgraphOp::InferShapeImpl() const { return CheckShape(); /* enrich me */ } bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { param_.input_names = op_desc.Input("Inputs"); diff --git a/lite/operators/subgraph_op.h b/lite/operators/subgraph_op.h index 7f593159c8..edbfb92204 100644 --- a/lite/operators/subgraph_op.h +++ b/lite/operators/subgraph_op.h @@ -35,7 +35,7 @@ class SubgraphOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/topk_op.cc b/lite/operators/topk_op.cc index fbfb825544..4a68cbb474 100644 --- a/lite/operators/topk_op.cc +++ b/lite/operators/topk_op.cc @@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const { return true; } -bool TopkOp::InferShape() const { +bool TopkOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); out_dims[out_dims.size() - 1] = param_.K; auto out = param_.Out; diff --git a/lite/operators/topk_op.h b/lite/operators/topk_op.h index 037fa413ea..d5888e5f18 100644 --- a/lite/operators/topk_op.h +++ b/lite/operators/topk_op.h @@ -30,7 +30,7 @@ class TopkOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index 71086b492b..40780346d0 100644 
--- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const { return true; } -bool TransposeOp::InferShape() const { +bool TransposeOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); @@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const { return true; } -bool Transpose2Op::InferShape() const { +bool Transpose2Op::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); diff --git a/lite/operators/transpose_op.h b/lite/operators/transpose_op.h index ce352a7d82..39b75b96d8 100644 --- a/lite/operators/transpose_op.h +++ b/lite/operators/transpose_op.h @@ -31,7 +31,7 @@ class TransposeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -50,7 +50,7 @@ class Transpose2Op : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/uniform_random_op.cc b/lite/operators/uniform_random_op.cc index 93e74e2b01..512648bfe4 100644 --- a/lite/operators/uniform_random_op.cc +++ b/lite/operators/uniform_random_op.cc @@ -22,7 +22,7 @@ namespace operators { bool UniformRandomOpLite::CheckShape() const { return true; } -bool UniformRandomOpLite::InferShape() const { +bool UniformRandomOpLite::InferShapeImpl() const { param_.Out->Resize(param_.shape); return true; } diff --git a/lite/operators/uniform_random_op.h b/lite/operators/uniform_random_op.h index f7dde8882f..a7890ea3e7 100644 --- a/lite/operators/uniform_random_op.h +++ b/lite/operators/uniform_random_op.h @@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 39b275b7b5..b5ae90248a 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const { return true; } -bool UnsqueezeOp::InferShape() const { +bool UnsqueezeOp::InferShapeImpl() const { std::vector final_axes; auto axes = param_.axes; auto *axes_tensor = param_.axes_tensor; @@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const { return true; } -bool Unsqueeze2Op::InferShape() const { - UnsqueezeOp::InferShape(); +bool Unsqueeze2Op::InferShapeImpl() const { + UnsqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/unsqueeze_op.h b/lite/operators/unsqueeze_op.h index 1e88828c6c..5139b69c63 100644 --- a/lite/operators/unsqueeze_op.h +++ b/lite/operators/unsqueeze_op.h @@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc index 51f43c7099..8cf11f6465 100644 --- a/lite/operators/var_conv_2d_op.cc +++ b/lite/operators/var_conv_2d_op.cc @@ -21,7 +21,7 @@ namespace operators { bool VarConv2dOp::CheckShape() const { return true; } -bool VarConv2dOp::InferShape() const { return true; } +bool VarConv2dOp::InferShapeImpl() const { return true; } bool 
VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = const_cast( diff --git a/lite/operators/var_conv_2d_op.h b/lite/operators/var_conv_2d_op.h index ce6309419c..5fa492d28e 100644 --- a/lite/operators/var_conv_2d_op.h +++ b/lite/operators/var_conv_2d_op.h @@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite { VarConv2dOp() {} explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "var_conv_2d"; } diff --git a/lite/operators/while_op.cc b/lite/operators/while_op.cc index dba266af77..1dcf9553f3 100644 --- a/lite/operators/while_op.cc +++ b/lite/operators/while_op.cc @@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const { return true; } -bool WhileOpLite::InferShape() const { return true; } +bool WhileOpLite::InferShapeImpl() const { return true; } bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("X"); diff --git a/lite/operators/while_op.h b/lite/operators/while_op.h index fcba722dbc..94aec15a6d 100644 --- a/lite/operators/while_op.h +++ b/lite/operators/while_op.h @@ -30,7 +30,7 @@ class WhileOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/write_to_array_op.cc b/lite/operators/write_to_array_op.cc index bf2d9bc4b7..d2cf7b4f94 100644 --- a/lite/operators/write_to_array_op.cc +++ b/lite/operators/write_to_array_op.cc @@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const { return true; } -bool WriteToArrayOp::InferShape() const { +bool WriteToArrayOp::InferShapeImpl() const 
{ int id = param_.I->data()[0]; if (param_.Out->size() < id + 1) { param_.Out->resize(id + 1); diff --git a/lite/operators/write_to_array_op.h b/lite/operators/write_to_array_op.h index 8c987a2450..9460b7e364 100644 --- a/lite/operators/write_to_array_op.h +++ b/lite/operators/write_to_array_op.h @@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/yolo_box_op.cc b/lite/operators/yolo_box_op.cc index c8186d3f31..0a5481a8fb 100644 --- a/lite/operators/yolo_box_op.cc +++ b/lite/operators/yolo_box_op.cc @@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const { return true; } -bool YoloBoxOp::InferShape() const { +bool YoloBoxOp::InferShapeImpl() const { auto* X = param_.X; auto anchors = param_.anchors; int anchor_num = anchors.size() / 2; diff --git a/lite/operators/yolo_box_op.h b/lite/operators/yolo_box_op.h index 2e2ea6d634..85448000f3 100644 --- a/lite/operators/yolo_box_op.h +++ b/lite/operators/yolo_box_op.h @@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; -- GitLab