diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc
index c76e369466a9b998b2ad6fde67b97117649fddc0..a9ccd1b9ae9a5d45f8d0e5638b3aab1d73d1903c 100644
--- a/lite/core/op_lite.cc
+++ b/lite/core/op_lite.cc
@@ -22,6 +22,61 @@ namespace paddle {
 namespace lite {
 
+bool OpLite::InferShape() {
+  // If input_tensor_ptrs and output_tensor_ptrs are overridden in param_,
+  // InferShapeWithCache is applied.
+  if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
+    return this->InferShapeWithCache();
+  } else {
+    // Otherwise, InferShapeImpl is applied directly.
+    return this->InferShapeImpl();
+  }
+}
+bool OpLite::InferShapeWithCache() {
+  // 1. Get the vector of current input tensors.
+  auto *current_inputs = param_.input_tensor_ptrs();
+  // 2. Compute a hash over the shapes and LoDs of the current inputs.
+  size_t new_hash = 0;
+  for (auto iter = current_inputs->begin(); iter != current_inputs->end();
+       iter++) {
+    // Combine the dims values into new_hash.
+    auto &element_dims = (*iter)->dims();
+    for (int i = 0; i < element_dims.size(); i++) {
+      new_hash =
+          lite::hash_combine(new_hash, static_cast<size_t>(element_dims[i]));
+    }
+    // Combine the LoD values into new_hash.
+    auto &element_lods = (*iter)->lod();
+    for (auto lod_iter = element_lods.begin(); lod_iter != element_lods.end();
+         lod_iter++) {
+      for (int i = 0; i < lod_iter->size(); i++) {
+        new_hash =
+            lite::hash_combine(new_hash, static_cast<size_t>(lod_iter->at(i)));
+      }
+    }
+  }
+  // 3. Infer the shapes of the output tensors.
+  if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
+    // The current hash matches io_shape_lod_hash_, so the output shapes
+    // and LoDs inferred on the previous run are reused.
+    auto *current_outputs = param_.output_tensor_ptrs();
+    for (int i = 0; i < current_outputs->size(); i++) {
+      current_outputs->at(i)->Resize(last_output_shapes[i]);
+      current_outputs->at(i)->set_lod(last_output_lods[i]);
+    }
+  } else {
+    // The inputs have changed: run InferShapeImpl and refresh the cache.
+    io_shape_lod_hash_ = new_hash;
+    this->InferShapeImpl();
+    auto *current_outputs = param_.output_tensor_ptrs();
+    // Keep the cache vectors sized to the current number of outputs before
+    // writing into them.
+    last_output_shapes.resize(current_outputs->size());
+    last_output_lods.resize(current_outputs->size());
+    for (int i = 0; i < current_outputs->size(); i++) {
+      last_output_shapes[i] = current_outputs->at(i)->dims();
+      last_output_lods[i] = current_outputs->at(i)->lod();
+    }
+  }
+  return true;
+}
+
 std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
     const std::vector<Place> &places, const std::string &kernel_type) {
   std::vector<std::unique_ptr<KernelBase>> kernels;
diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h
index 77d8091b4b16cfbce2efc3d549f916a9136c61ab..4c6c66be7e41889c116aed023d863df8a4a912c8 100644
--- a/lite/core/op_lite.h
+++ b/lite/core/op_lite.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include 
 #include 
 #include 
 #include 
@@ -24,6 +25,7 @@
 #include "lite/core/kernel.h"
 #include "lite/core/scope.h"
 #include "lite/model_parser/cpp/op_desc.h"
+#include "lite/operators/op_params.h"
 
 namespace paddle {
 namespace lite {
@@ -64,8 +66,8 @@ class OpLite : public Registry {
   // Check the shape.
   virtual bool CheckShape() const { return true; }
   // Inference the outputs' shape.
-  virtual bool InferShape() const { return true; }
-  virtual bool SmartInferShape() { return this->InferShape(); }
+  virtual bool InferShapeImpl() const { return true; }
+  virtual bool InferShape();
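// ---- Aside (not part of the patch) ----------------------------------------
// A minimal standalone sketch of the shape/LoD hashing that the new
// OpLite::InferShapeWithCache() relies on. HashCombine and the plain nested
// vectors below are hypothetical stand-ins for lite::hash_combine, DDim and
// LoD, assuming a boost-style seed mix; the real implementation may differ.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Boost-style seed mixing (assumed comparable to lite::hash_combine).
inline std::size_t HashCombine(std::size_t seed, int64_t value) {
  return seed ^
         (std::hash<int64_t>()(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// One hash over every input's dims and LoD levels; if it equals the cached
// io_shape_lod_hash_, the previously inferred output shapes can be reused.
std::size_t ShapeLodHash(const std::vector<std::vector<int64_t>>& input_dims,
                         const std::vector<std::vector<uint64_t>>& input_lod) {
  std::size_t seed = 0;
  for (const auto& dims : input_dims)
    for (int64_t d : dims) seed = HashCombine(seed, d);
  for (const auto& level : input_lod)
    for (uint64_t v : level) seed = HashCombine(seed, static_cast<int64_t>(v));
  return seed;
}
// ---- End of aside ----------------------------------------------------------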
   // Run this operator.
   virtual bool Run();
   // Indicate whether the Op runs only once or not
@@ -151,10 +153,16 @@ class OpLite : public Registry {
   std::vector<Place> valid_places_;
   Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
   std::unique_ptr<OpInfo> op_info_;
-  std::vector<DDim> last_input_shapes;
-  std::vector<DDim> last_output_shapes;
-  std::vector<std::vector<std::vector<uint64_t>>> last_output_lods;
-  std::vector<std::vector<std::vector<uint64_t>>> last_input_lods;
+
+  std::vector<DDim> last_output_shapes{};
+  std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
+  size_t io_shape_lod_hash_{};
+  mutable operators::ParamBase param_;
+
+ private:
+  // Infer shapes using a cache: if the current input shapes and LoDs match
+  // those of the previous run, the output shapes of the last run are reused.
+  bool InferShapeWithCache();
 };
 
 /*
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 580389fbad54c0de8efd65ef78c9b69fd3e72893..7284c3983cb34a0db2387ece40f6d07b9d9a8511 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -286,8 +286,7 @@ void Instruction::Run() {
     return;
   }
 
-  // op_->InferShape();
-  op_->SmartInferShape();
+  op_->InferShape();
   kernel_->Launch();
   has_run_ = true;
 }
diff --git a/lite/operators/activation_grad_ops.cc b/lite/operators/activation_grad_ops.cc
index 9a37a5f0a178192ead00801632914a8f446f058f..b31163e5dce6d9b77d923ba44ed58952263610a5 100644
--- a/lite/operators/activation_grad_ops.cc
+++ b/lite/operators/activation_grad_ops.cc
@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
   return true;
 }
 
-bool ActivationGradOp::InferShape() const {
+bool ActivationGradOp::InferShapeImpl() const {
   param_.X_grad->Resize(param_.Out_grad->dims());
   return true;
 }
diff --git a/lite/operators/activation_grad_ops.h b/lite/operators/activation_grad_ops.h
index 5421b3247ff844e20931a6a15b85eb7da85e7f69..cf928cfe1bf9945a1dd0474408472759a499b5d7 100644
--- a/lite/operators/activation_grad_ops.h
+++ b/lite/operators/activation_grad_ops.h
@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
 
   bool CheckShape() const override;
 
-  bool InferShape() const override;
+  bool InferShapeImpl() const override;
 
   bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
 
diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc
index f7a326358bb30d747c949d7bacdebb47846562b5..abaaa1a705c2c6995a7d846c1d9add0dab98867b 100644
--- a/lite/operators/activation_ops.cc
+++ b/lite/operators/activation_ops.cc
@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
   return true;
 }
 
-bool ActivationOp::InferShape() const {
+bool ActivationOp::InferShapeImpl() const {
   param_.Out->Resize(param_.X->dims());
   auto out_lod = param_.Out->mutable_lod();
   *out_lod = param_.X->lod();
diff --git a/lite/operators/activation_ops.h b/lite/operators/activation_ops.h
index 34099ab0fdb422f523e383dc0dd286acf24b2731..8f81b12af03052e558e7faa2e813039d4dee8988 100644
--- a/lite/operators/activation_ops.h
+++ b/lite/operators/activation_ops.h
@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
 
   bool CheckShape() const override;
 
-  bool InferShape() const override;
+  bool InferShapeImpl() const override;
 
   bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
 
diff --git a/lite/operators/affine_channel_op.cc b/lite/operators/affine_channel_op.cc
index c4945ababd2fdf3b0f1b25d26eb0f66c8f613b21..447079deb33bdb893b99901d8559d6961489789d 100644
--- a/lite/operators/affine_channel_op.cc
+++ b/lite/operators/affine_channel_op.cc
@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
   return true;
 }
 
-bool AffineChannelOpLite::InferShape() const {
+bool
AffineChannelOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); param_.Out->Resize(x_dims); return true; diff --git a/lite/operators/affine_channel_op.h b/lite/operators/affine_channel_op.h index 85a043bdc8e1c6f41c27b2e57555d3454322f789..5a3d9d66259d477d42ac00e0e1b1a7ba1bf2e862 100644 --- a/lite/operators/affine_channel_op.h +++ b/lite/operators/affine_channel_op.h @@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/anchor_generator_op.cc b/lite/operators/anchor_generator_op.cc index 8daa54905fcf7cf52259840c26198721d6b8f0fa..e57a4b2df8c75afd28506b5e0e2f7b7aa142b838 100644 --- a/lite/operators/anchor_generator_op.cc +++ b/lite/operators/anchor_generator_op.cc @@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const { return true; } -bool AnchorGeneratorOpLite::InferShape() const { +bool AnchorGeneratorOpLite::InferShapeImpl() const { auto input_dims = param_.Input->dims(); size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size(); std::vector output_shape( diff --git a/lite/operators/anchor_generator_op.h b/lite/operators/anchor_generator_op.h index 46e5e0fac243c10b62122327ef06ea166878e54f..2ff3422824c15b54ed1fa3ca9952745d5b1706ac 100644 --- a/lite/operators/anchor_generator_op.h +++ b/lite/operators/anchor_generator_op.h @@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/argmax_op.cc b/lite/operators/argmax_op.cc index 772cc446077e5e896b757051fae9f9b8f59df1d8..b733998ae57785483f539b56dcb47b7b50f04cf0 100644 --- a/lite/operators/argmax_op.cc +++ b/lite/operators/argmax_op.cc @@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const { return true; } -bool ArgmaxOpLite::InferShape() const { +bool ArgmaxOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); int x_rank = x_dims.size(); int axis = param_.Axis; diff --git a/lite/operators/argmax_op.h b/lite/operators/argmax_op.h index a5accc97e3b9f3bb2fbd00f45fd3a45063e5c747..e6944507cf9f6ded86ccbae7c3cec79106e8ba98 100644 --- a/lite/operators/argmax_op.h +++ b/lite/operators/argmax_op.h @@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/assign_op.cc b/lite/operators/assign_op.cc index 8510b7e8b7b8a5732e0e09d3db494ab3eb9f15a8..25e8539d2e55a07a19d707713489d86f84aa64db 100644 --- a/lite/operators/assign_op.cc +++ b/lite/operators/assign_op.cc @@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const { return true; } -bool AssignOpLite::InferShape() const { +bool AssignOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/assign_op.h b/lite/operators/assign_op.h index 555356c3659ff31c84b2630c1f5da6acab003823..9e7039bb5b0088a6bda6acbf2baf7a50444df8b2 100644 --- a/lite/operators/assign_op.h +++ b/lite/operators/assign_op.h @@ -30,7 +30,7 @@ class AssignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + 
bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/assign_value_op.cc b/lite/operators/assign_value_op.cc index 046c5222283fc73bd3af1e53520b1fc5539bcd31..ff5b55735f7b58aa2eaa2274574336dadd8061e6 100644 --- a/lite/operators/assign_value_op.cc +++ b/lite/operators/assign_value_op.cc @@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const { return true; } -bool AssignValueOpLite::InferShape() const { +bool AssignValueOpLite::InferShapeImpl() const { std::vector shape = param_.shape; std::vector out_shape; for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); diff --git a/lite/operators/assign_value_op.h b/lite/operators/assign_value_op.h index 7bf220615935f02051ed606adb894bf9842378f3..030da048184c9862b76f59198574b394457768d5 100644 --- a/lite/operators/assign_value_op.h +++ b/lite/operators/assign_value_op.h @@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/attention_padding_mask_op.cc b/lite/operators/attention_padding_mask_op.cc index a88df0e7a902c6cac63eb77377bb0b49ee30c9b3..2f3a0cd265c56ac24548e23ff3daf09e27e1d800 100644 --- a/lite/operators/attention_padding_mask_op.cc +++ b/lite/operators/attention_padding_mask_op.cc @@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const { return true; } -bool AttentionPaddingMaskOp::InferShape() const { +bool AttentionPaddingMaskOp::InferShapeImpl() const { auto src_len = param_.X->lod()[0][1]; CHECK_EQ(src_len, param_.X->dims()[1]) << "Mismatch source length, expect: " << src_len diff --git a/lite/operators/attention_padding_mask_op.h b/lite/operators/attention_padding_mask_op.h index 894d68f6226720139aee07274d4ac5cf660749f1..6a2443fc6749d4f2066ee761fd194441e2fe46cd 100644 --- a/lite/operators/attention_padding_mask_op.h +++ b/lite/operators/attention_padding_mask_op.h @@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/axpy_op.cc b/lite/operators/axpy_op.cc index 60f302862afa47ca75ae703e7b848bb3a0e7604c..c1c6304c3119f89bdc46400b2478a767c914d001 100644 --- a/lite/operators/axpy_op.cc +++ b/lite/operators/axpy_op.cc @@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const { return true; } -bool AxpyOpLite::InferShape() const { +bool AxpyOpLite::InferShapeImpl() const { auto dims = param_.Bias->dims(); // Set output dims diff --git a/lite/operators/axpy_op.h b/lite/operators/axpy_op.h index 1fa8540743f65db864f33633003b4ed8f6d8cb92..e9d9f44ca5f5843628af998d9140519a3f3a1c29 100644 --- a/lite/operators/axpy_op.h +++ b/lite/operators/axpy_op.h @@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index eca7fa6001dda7835213c60be1d21eedff301ae4..67e037fba349e811f1faf991c84310b11ab7a13c 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const { return true; } -bool 
BatchNormOp::InferShape() const { +bool BatchNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t channel_size = 0; switch (param_.data_layout) { diff --git a/lite/operators/batch_norm_op.h b/lite/operators/batch_norm_op.h index 21dbf9a28a4257acdd80ac6c49d111cdd757b65d..9598763713564192ed4ad0c99200f0fdb1d88d37 100644 --- a/lite/operators/batch_norm_op.h +++ b/lite/operators/batch_norm_op.h @@ -30,7 +30,7 @@ class BatchNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_decode_op.cc b/lite/operators/beam_search_decode_op.cc index 52888d8a99c0f6507862f515c633f04d4fe09c39..444c9d6a11217c3134c3cb1f988c60c4b98d4566 100644 --- a/lite/operators/beam_search_decode_op.cc +++ b/lite/operators/beam_search_decode_op.cc @@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const { return true; } -bool BeamSearchDecodeOpLite::InferShape() const { return true; } +bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; } bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/beam_search_decode_op.h b/lite/operators/beam_search_decode_op.h index 9d324d2bf0974fe5b65711c4ab2dacaf0d0d65d9..38bf9929ab12ba764fcd3fe6cacc7c08f35c15ca 100644 --- a/lite/operators/beam_search_decode_op.h +++ b/lite/operators/beam_search_decode_op.h @@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_op.cc b/lite/operators/beam_search_op.cc index c998e002ee3d6b8f3196fdfa212462dac4da0969..ea777ad53395aba1c7d6c21b07013e374b03c1f4 100644 --- a/lite/operators/beam_search_op.cc +++ b/lite/operators/beam_search_op.cc @@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const { return true; } -bool BeamSearchOp::InferShape() const { return true; } +bool BeamSearchOp::InferShapeImpl() const { return true; } bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front()); diff --git a/lite/operators/beam_search_op.h b/lite/operators/beam_search_op.h index 42a6058de112215f525b51bfff6ff16aae04391d..7e325cb55668a77cf09466e86be220218a49cbee 100644 --- a/lite/operators/beam_search_op.h +++ b/lite/operators/beam_search_op.h @@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_clip_op.cc b/lite/operators/box_clip_op.cc index 6bd93c6ea4e2efc93fdc7e64f1738c2ac3d40997..08ba49bd9ada076c6650249f67af15174491f634 100644 --- a/lite/operators/box_clip_op.cc +++ b/lite/operators/box_clip_op.cc @@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const { return true; } -bool BoxClipOpLite::InferShape() const { +bool BoxClipOpLite::InferShapeImpl() const { auto* input = param_.Input; auto* output = param_.Output; output->Resize(input->dims()); diff --git a/lite/operators/box_clip_op.h b/lite/operators/box_clip_op.h index c7e07b1015c52eb5711638163bda327c11152dd0..0aae2112ec8b91ba63205fadd4123bc3c5fce2fd 100644 --- a/lite/operators/box_clip_op.h +++ 
b/lite/operators/box_clip_op.h @@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_coder_op.cc b/lite/operators/box_coder_op.cc index c86f494fc4f96f688c30027f1d6aa1ee452da8f0..3133176b35ecae49ed9171ef6e8b519c6774ce5d 100644 --- a/lite/operators/box_coder_op.cc +++ b/lite/operators/box_coder_op.cc @@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const { return true; } -bool BoxCoderOpLite::InferShape() const { +bool BoxCoderOpLite::InferShapeImpl() const { auto prior_box_dims = param_.prior_box->dims(); auto target_box_dims = param_.target_box->dims(); std::string code_type = param_.code_type; diff --git a/lite/operators/box_coder_op.h b/lite/operators/box_coder_op.h index 61d54fd484ff377763e00f1d71bff1c0c6f89398..51e86423e39786426d53fe8ced861866bfeb1053 100644 --- a/lite/operators/box_coder_op.h +++ b/lite/operators/box_coder_op.h @@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/calib_op.cc b/lite/operators/calib_op.cc index da00f01c3206c81fb89749432383ea8d99c14dc1..8da8747f8c9df038ee424395fd75a20a718f1970 100644 --- a/lite/operators/calib_op.cc +++ b/lite/operators/calib_op.cc @@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const { CHECK_OR_FALSE(param_.output); return true; } -bool CalibOpLite::InferShape() const { +bool CalibOpLite::InferShapeImpl() const { param_.output->Resize(param_.input->dims()); return true; } diff --git a/lite/operators/calib_op.h b/lite/operators/calib_op.h index d575766c10d1e6cd66bf7f8117315ffe21fe10fe..94240880f55e782f025fe5777eba19e0c96cfbee 100644 --- a/lite/operators/calib_op.h +++ b/lite/operators/calib_op.h @@ -42,7 +42,7 @@ class CalibOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope); diff --git a/lite/operators/cast_op.cc b/lite/operators/cast_op.cc index 9ece0a45a3e997e4d1663755f42f6b42efb86c5d..da12e2afded2c23565080b06409ce35b0535c4ff 100644 --- a/lite/operators/cast_op.cc +++ b/lite/operators/cast_op.cc @@ -25,7 +25,7 @@ bool CastOp::CheckShape() const { return true; } -bool CastOp::InferShape() const { +bool CastOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/cast_op.h b/lite/operators/cast_op.h index 2f5f57f12740d085bda36141299cfbe7c798c378..e045ef89f73d0ac29b0f03e148ad651c1513668f 100644 --- a/lite/operators/cast_op.h +++ b/lite/operators/cast_op.h @@ -30,7 +30,7 @@ class CastOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/collect_fpn_proposals_op.cc b/lite/operators/collect_fpn_proposals_op.cc index 4731d4bf81c241c6733b1403699874c1053d2b7f..27dd9a50b6fb0a9943b7a9d86be390cbc6d406b0 100644 --- a/lite/operators/collect_fpn_proposals_op.cc +++ b/lite/operators/collect_fpn_proposals_op.cc @@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { return true; } -bool CollectFpnProposalsOpLite::InferShape() const { +bool CollectFpnProposalsOpLite::InferShapeImpl() const { param_.fpn_rois->Resize({param_.post_nms_topN, 4}); return true; diff --git a/lite/operators/collect_fpn_proposals_op.h b/lite/operators/collect_fpn_proposals_op.h index 1ae7bb269ff53bb8add92d9afc8d462c45cb5f0b..b3104e81d5ff8d82083a7b37ffd88dd169b840c9 100644 --- a/lite/operators/collect_fpn_proposals_op.h +++ b/lite/operators/collect_fpn_proposals_op.h @@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/compare_op.cc b/lite/operators/compare_op.cc index aa500ba35c37cf8af17091d8d37d8fd8d1a08e0e..f458eae71edea6086e8947ae8881f6f218e49808 100644 --- a/lite/operators/compare_op.cc +++ b/lite/operators/compare_op.cc @@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const { return true; } -bool CompareOp::InferShape() const { +bool CompareOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto input_dims = param_.X->dims(); diff --git a/lite/operators/compare_op.h b/lite/operators/compare_op.h index 7ca21caaa1347f248213b2b43293ca18d514ba9a..c94cf88516af7676f8e524c091713cbaa4dd70ff 100644 --- a/lite/operators/compare_op.h +++ b/lite/operators/compare_op.h @@ -30,7 +30,7 @@ class CompareOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index b2f7438b64aa34787896839f020f0b056e6453fb..c15bf292897006b3c6d5e67bcfaea5d0e590a82d 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const { return true; } -bool ConcatOpLite::InferShape() const { +bool ConcatOpLite::InferShapeImpl() const { const std::vector &inputs = param_.x; const size_t n = inputs.size(); CHECK_GT_OR_FALSE(n, 0); diff --git a/lite/operators/concat_op.h b/lite/operators/concat_op.h index acc41de9b36cf6a808788a4f585e8a9c7f049717..2ac1572c833db217546aaa176640cb5c1022d3bf 100644 --- a/lite/operators/concat_op.h +++ b/lite/operators/concat_op.h @@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conditional_block_op.cc b/lite/operators/conditional_block_op.cc index c79c4e20a29834e858bc670104e2a09e55888c85..e3678e92c9d33be5428c82331ce963f4c6067369 100644 --- a/lite/operators/conditional_block_op.cc +++ b/lite/operators/conditional_block_op.cc @@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const { return true; } -bool ConditionalBlockOpLite::InferShape() const { return true; } +bool ConditionalBlockOpLite::InferShapeImpl() const { return true; } bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/conditional_block_op.h b/lite/operators/conditional_block_op.h index 5518c255c5799aa5b44557a4493275794fd598f5..1815731c8df3ac07bee80aa8e0cc658e752b5c4f 100644 --- a/lite/operators/conditional_block_op.h +++ b/lite/operators/conditional_block_op.h @@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 70ad3a32a83003e449524205a71dcc7536b9a11e..38c59a0290b03031e9cbe013a4a10c14c7ad1743 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector* paddings, } } -bool ConvOpLite::SmartInferShape() { - if (!last_input_shapes.empty()) { - if (last_input_shapes[0] == param_.x->dims() && - last_input_lods[0] == param_.x->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - 
last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool ConvOpLite::InferShape() const { +bool ConvOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 3379fb409529e261f4af38ef2ee3483f17cc8a3b..eab17fe6db0a59a9eb0eea0ab7344758a8232d15 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -34,9 +34,7 @@ class ConvOpLite : public OpLite { explicit ConvOpLite(const std::string& type) : OpLite(type) {} bool CheckShape() const override; - - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index a84b975492040ec0bdc1326f33f8b7edafdea2bb..511a5157ad58e5e2d7bda5c4d0de136c9b3f9590 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size, return output_size; } -bool ConvTransposeOpLite::InferShape() const { +bool ConvTransposeOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_transpose_op.h b/lite/operators/conv_transpose_op.h index fb25c022f974ad195bf72b19cb9b459b2d11d5f2..891ece4f052128c8c236db5650414d6015ea9565 100644 --- a/lite/operators/conv_transpose_op.h +++ b/lite/operators/conv_transpose_op.h @@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/crf_decoding_op.cc b/lite/operators/crf_decoding_op.cc index 1b0a27ab4afdfc165dedc2ccfad492658ec40399..b1af573518bc483b6eaf5e013609583b548fb300 100644 --- a/lite/operators/crf_decoding_op.cc +++ b/lite/operators/crf_decoding_op.cc @@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const { return true; } -bool CrfDecodingOpLite::InferShape() const { +bool CrfDecodingOpLite::InferShapeImpl() const { auto emission_dims = param_.emission->dims(); if (param_.length == nullptr) { param_.viterbi_path->Resize({emission_dims[0], 1}); diff --git a/lite/operators/crf_decoding_op.h b/lite/operators/crf_decoding_op.h index 6aaf338ec240d2caa659785f909d5eee7d249008..4bc50410ab0504b3e25585caba7f8fff823553b0 100644 --- a/lite/operators/crf_decoding_op.h +++ b/lite/operators/crf_decoding_op.h @@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/crop_op.cc b/lite/operators/crop_op.cc index 1a27cfb34d958176c8ad0a6e17d7e17e5287d2d5..4905d92e587ea10783fe7a3cb88b6ee67761c73e 100644 --- a/lite/operators/crop_op.cc +++ b/lite/operators/crop_op.cc @@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const { return true; } -bool CropOpLite::InferShape() const { +bool CropOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); lite::DDim output_shape(x_dims); diff --git a/lite/operators/crop_op.h b/lite/operators/crop_op.h index 
f21278e891d265093c26be1f96e416974af13b2e..bd3d0e71d8780fab16134ba347f3208249403bd7 100644 --- a/lite/operators/crop_op.h +++ b/lite/operators/crop_op.h @@ -30,7 +30,7 @@ class CropOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/decode_bboxes_op.cc b/lite/operators/decode_bboxes_op.cc index e22adf1774427e10e3fa146e388a6ce365f86021..1903267c3aa46e048787f007a5c9cede8c574c5a 100644 --- a/lite/operators/decode_bboxes_op.cc +++ b/lite/operators/decode_bboxes_op.cc @@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const { return true; } -bool DecodeBboxesOpLite::InferShape() const { +bool DecodeBboxesOpLite::InferShapeImpl() const { param_.bbox_data->Resize(param_.loc_data->dims()); return true; } diff --git a/lite/operators/decode_bboxes_op.h b/lite/operators/decode_bboxes_op.h index c463992c8da6b042d5df027b03e64a594ede8a02..8848a1c26cd9363595a3200fc6e2535751f72df0 100644 --- a/lite/operators/decode_bboxes_op.h +++ b/lite/operators/decode_bboxes_op.h @@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc index 86830df2f19b5615e8b9cfb4b3b57eb22000f588..5ac3eef63bb59c80bffaf3bed558b3ac5baf4d61 100644 --- a/lite/operators/density_prior_box_op.cc +++ b/lite/operators/density_prior_box_op.cc @@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const { return true; } -bool DensityPriorBoxOpLite::InferShape() const { return true; } +bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; } bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { diff --git a/lite/operators/density_prior_box_op.h b/lite/operators/density_prior_box_op.h index bad55ad3b7046da45663a2cdd41243ecd5d41cb0..d84b20557fab101ba60f0af58234ffca4e672a57 100644 --- a/lite/operators/density_prior_box_op.h +++ b/lite/operators/density_prior_box_op.h @@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/distribute_fpn_proposals_op.cc b/lite/operators/distribute_fpn_proposals_op.cc index 5d6a0fca923dd38fd456e024ec14ba7c2685163d..a23c5e1ffb50b1d22a42d5e68bd424d078e83110 100644 --- a/lite/operators/distribute_fpn_proposals_op.cc +++ b/lite/operators/distribute_fpn_proposals_op.cc @@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const { return true; } -bool DistributeFpnProposalsOpLite::InferShape() const { +bool DistributeFpnProposalsOpLite::InferShapeImpl() const { int num_out_rois = param_.max_level - param_.min_level + 1; for (int i = 0; i < num_out_rois; i++) { param_.multi_fpn_rois[i]->Resize({-1, 4}); diff --git a/lite/operators/distribute_fpn_proposals_op.h b/lite/operators/distribute_fpn_proposals_op.h index 2390e329329f7406f05ba69b3768556f94a02bec..22ab2006e072ea36037cb05faaca324a7d2922c9 100644 --- a/lite/operators/distribute_fpn_proposals_op.h +++ b/lite/operators/distribute_fpn_proposals_op.h @@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite { bool CheckShape() 
const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/dropout_op.cc b/lite/operators/dropout_op.cc index 03047de3b318ee2221809ee602d94f204568d723..858cc6d9197433985aabfb428993d2fa1333527e 100644 --- a/lite/operators/dropout_op.cc +++ b/lite/operators/dropout_op.cc @@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const { return true; } -bool DropoutOp::InferShape() const { +bool DropoutOp::InferShapeImpl() const { const auto x_dims = param_.x->dims(); param_.output->Resize(x_dims); if (param_.is_test == false) { diff --git a/lite/operators/dropout_op.h b/lite/operators/dropout_op.h index 97e17e350c6a87a82e3cf05635d9575269489d7a..bdf0e1d9046178b48f2b4917840eee6ac8572c5a 100644 --- a/lite/operators/dropout_op.h +++ b/lite/operators/dropout_op.h @@ -28,7 +28,7 @@ class DropoutOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. diff --git a/lite/operators/elementwise_grad_ops.cc b/lite/operators/elementwise_grad_ops.cc index 9d964bf9e36889f2bc72b2656d23bf4022cc121c..730785ba6e6553e6a306f87bdbc63ea5b1017f0a 100644 --- a/lite/operators/elementwise_grad_ops.cc +++ b/lite/operators/elementwise_grad_ops.cc @@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const { return true; } -bool ElementwiseGradOp::InferShape() const { +bool ElementwiseGradOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (param_.XGrad) { diff --git a/lite/operators/elementwise_grad_ops.h b/lite/operators/elementwise_grad_ops.h index c45d581936207f0b37ee70a0505b912d0b509e35..ca8a3241349b4cdc04e4800a0a88b215f586ba72 100644 --- a/lite/operators/elementwise_grad_ops.h +++ b/lite/operators/elementwise_grad_ops.h @@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index 044126b3c22fa853d4908c06c307f32278fa5b9b..f4debc39a0d480f38e6d37e8e60d516def7f0b55 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool ElementwiseOp::SmartInferShape() { - if (!last_input_shapes.empty()) { - if (last_input_shapes[0] == param_.X->dims() && - last_input_shapes[1] == param_.Y->dims() && - last_input_lods[0] == param_.X->lod() && - last_input_lods[1] == param_.Y->lod()) { - param_.Out->Resize(last_output_shapes[0]); - param_.Out->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.X->dims()); - last_input_lods.push_back(param_.X->lod()); - last_input_shapes.push_back(param_.Y->dims()); - last_input_lods.push_back(param_.Y->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.Out->dims()); - last_output_lods.push_back(param_.Out->lod()); - return true; -} -bool ElementwiseOp::InferShape() const { 
+bool ElementwiseOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (x_dim == y_dim) { @@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { // return true; //} -// bool ElementwiseGradExplicitOp::InferShape() const { +// bool ElementwiseGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/elementwise_ops.h b/lite/operators/elementwise_ops.h index 9d6e5781b9754eb22be11da0d7f77b764eb25912..0f1b682fa5f267dd802c5ee0e35aca8f6d68f39c 100644 --- a/lite/operators/elementwise_ops.h +++ b/lite/operators/elementwise_ops.h @@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/expand_op.cc b/lite/operators/expand_op.cc index 656e8babc022e3bb022b3c3b4bb066ea5e5d173c..8e40a3b236609b1e83b5224efb462a1f803764df 100644 --- a/lite/operators/expand_op.cc +++ b/lite/operators/expand_op.cc @@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const { return true; } -bool ExpandOpLite::InferShape() const { +bool ExpandOpLite::InferShapeImpl() const { DDim out_dims(param_.X->dims()); for (size_t i = 0; i < param_.expand_times.size(); ++i) { out_dims[i] *= param_.expand_times[i]; diff --git a/lite/operators/expand_op.h b/lite/operators/expand_op.h index ce5dcda9e80377699b168e6a4970a9bba0cf5039..1312df8e83747107e4c87e856c3b07fc2748d75b 100644 --- a/lite/operators/expand_op.h +++ b/lite/operators/expand_op.h @@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fake_channel_wise_dequantize_max_abs.h b/lite/operators/fake_channel_wise_dequantize_max_abs.h index 43afb7791fe617af0c7ac496cc62a12e6cc548d2..e26d5dda52f8b72d9202067a8782cf1dc10b983e 100644 --- a/lite/operators/fake_channel_wise_dequantize_max_abs.h +++ b/lite/operators/fake_channel_wise_dequantize_max_abs.h @@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_dequantize_max_abs.h b/lite/operators/fake_dequantize_max_abs.h index bc266327ebcb14da01201dcc1825367ff7ecd72e..c4bb19c04872078eb997afca6cd7a3cce6923fde 100644 --- a/lite/operators/fake_dequantize_max_abs.h +++ b/lite/operators/fake_dequantize_max_abs.h @@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x 
= op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h index 8efa46c41501be79ccc69f4cc9f9646c11673d2d..be7ec60e0eab730c2910c3822c976d579b48d6b7 100644 --- a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_moving_avg_max_abs.h index adc62a480d2d2efec54b3822f55a9f66c278e21e..5726231f31eab2012d2cd594c5c26977c71141ff 100644 --- a/lite/operators/fake_quantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h index f68d1e20f6e60bb5aa99a2402ea8c9f88aa18470..14f823ece2ee168ae09bc1db67f3d6a7e8c18d5d 100644 --- a/lite/operators/fake_quantize_range_abs_max.h +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 345fc0d605ccd68e3a6ef72429e20400a772568c..d58a9e5b881048dd47340082fe9c94a618a7a5fb 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const { return true; } -bool FcOpLite::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (last_input_shapes[0] == param_.input->dims() && - last_input_lods[0] == param_.input->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.input->dims()); - last_input_lods.push_back(param_.input->lod()); - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool FcOpLite::InferShape() const { +bool FcOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); const auto& w_dims = param_.w->dims(); int in_num_col_dims = param_.in_num_col_dims; diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index f5dc302e27a220ee1f1e0679cbb3c2ed257747dd..2e6a3ad59a1ca6d2e31f42ceb4b2d1b381c697ee 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h 
@@ -35,8 +35,7 @@ class FcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/feed_op.cc b/lite/operators/feed_op.cc index 8a0c75f62b6bed5767a8cc4b8348b4ca5b59eea5..c429d1f5744e50ff84a0a3d76e2f3e1ba68a0821 100644 --- a/lite/operators/feed_op.cc +++ b/lite/operators/feed_op.cc @@ -29,7 +29,7 @@ class FeedOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/fetch_op.cc b/lite/operators/fetch_op.cc index d50c0db34084bf8a70c9451ba0f0d8960e9d18c9..9db5fb418dab4418a0d6a622f87620c5c2673ecf 100644 --- a/lite/operators/fetch_op.cc +++ b/lite/operators/fetch_op.cc @@ -29,7 +29,7 @@ class FetchOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: diff --git a/lite/operators/fill_constant_batch_size_like_op.cc b/lite/operators/fill_constant_batch_size_like_op.cc index 7df3a6aa9e75ecc3fe88031a544c8e5ed3d1dd02..5b0ebb38e717afea4dabe011c0161248e2113a02 100644 --- a/lite/operators/fill_constant_batch_size_like_op.cc +++ b/lite/operators/fill_constant_batch_size_like_op.cc @@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const { return true; } -bool FillConstantBatchSizeLikeOp::InferShape() const { +bool FillConstantBatchSizeLikeOp::InferShapeImpl() const { std::vector output_dim{param_.shape.begin(), param_.shape.end()}; if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) { output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1; diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h index 33cc45779f6132fbc34b33eb2abbe9ca71418046..3c576ab28222c45aa17ba96f5e3e585624a29c02 100644 --- a/lite/operators/fill_constant_batch_size_like_op.h +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index 698b787f469375831d937fdf16bb58af06288e71..565c4bbd16e01af340e728e28866268c1a845760 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const { return true; } -bool FillConstantOp::InferShape() const { +bool FillConstantOp::InferShapeImpl() const { std::vector out_shape; auto shape_tensor = param_.shape_tensor; auto shape_tensor_list = param_.shape_tensor_list; diff --git a/lite/operators/fill_constant_op.h b/lite/operators/fill_constant_op.h index aa2fea5a665ee9a3c50efa3ec354fe52d9643050..3c0500898bef45efc7a72bc68c82fca9036c63f4 100644 --- a/lite/operators/fill_constant_op.h +++ b/lite/operators/fill_constant_op.h @@ -31,7 +31,7 @@ class FillConstantOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc 
&opdesc, lite::Scope *scope) override; diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc index 6deab45023876b1a5707ef5cea6ec69af3875328..b270dbf52f9a19f574e6f8967ff93e3a013e5737 100644 --- a/lite/operators/flatten_op.cc +++ b/lite/operators/flatten_op.cc @@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const { return true; } -bool FlattenOp::InferShape() const { +bool FlattenOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto out_lod = param_.output->mutable_lod(); @@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const { return true; } -bool Flatten2Op::InferShape() const { - FlattenOp::InferShape(); +bool Flatten2Op::InferShapeImpl() const { + FlattenOp::InferShapeImpl(); auto x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1, 0); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h index 61680fd3903b77f8826cda6f6a242739720155d7..78b803d765c8513ead9bf482bf23914ac4bf3430 100644 --- a/lite/operators/flatten_op.h +++ b/lite/operators/flatten_op.h @@ -30,7 +30,7 @@ class FlattenOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc index 244394b95aafede6956bc548430f5c14f28ae910..dfe3bda6c65a75f8b0f8a080d9dc367fb493e6f2 100644 --- a/lite/operators/fusion_elementwise_activation_ops.cc +++ b/lite/operators/fusion_elementwise_activation_ops.cc @@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { return true; } -bool FusionElementwiseActivationOp::InferShape() const { +bool FusionElementwiseActivationOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); param_.Out->Resize(param_.X->dims()); return true; @@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc, // return true; // } -// bool FusionElementwiseActivationGradExplicitOp::InferShape() const { +// bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/fusion_elementwise_activation_ops.h b/lite/operators/fusion_elementwise_activation_ops.h index db521284f0fc96c542fd5e7104b045f83f837f97..738c2168225d86f4614ba8eaaa6c6354f038116c 100644 --- a/lite/operators/fusion_elementwise_activation_ops.h +++ b/lite/operators/fusion_elementwise_activation_ops.h @@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/gather_op.cc b/lite/operators/gather_op.cc index 
858dad8e4c4b623e8d2499019bba36c7e0373b60..670cd61c8ea5af2f29a908b5d49bedccaff93c0a 100644 --- a/lite/operators/gather_op.cc +++ b/lite/operators/gather_op.cc @@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const { return true; } -bool GatherOp::InferShape() const { +bool GatherOp::InferShapeImpl() const { auto index_dims = param_.Index->dims(); CHECK(index_dims.size() == 1 || (index_dims.size() == 2 && index_dims[1] == 1)) diff --git a/lite/operators/gather_op.h b/lite/operators/gather_op.h index 58d5a30ffbb5f563503c8934d8c9e40bb539d5df..d2072c3a6d6e6e0b100ab3bb9413da8cd4f51f6b 100644 --- a/lite/operators/gather_op.h +++ b/lite/operators/gather_op.h @@ -30,7 +30,7 @@ class GatherOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/generate_proposals_op.cc b/lite/operators/generate_proposals_op.cc index a29ef65e97ccfdaaaf20d6cbbb411fc69cee6f54..48e709c348974dcf1868a7a17425b4168f04b4f6 100644 --- a/lite/operators/generate_proposals_op.cc +++ b/lite/operators/generate_proposals_op.cc @@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const { return true; } -bool GenerateProposalsOpLite::InferShape() const { +bool GenerateProposalsOpLite::InferShapeImpl() const { param_.RpnRois->Resize(std::vector({-1, 4})); param_.RpnRoiProbs->Resize(std::vector({-1, 1})); return true; diff --git a/lite/operators/generate_proposals_op.h b/lite/operators/generate_proposals_op.h index 502bcca1a3276fbbcc2f05bf8b38fcf2d1bbb024..35dee1966bda7cd9e865f42113c7a92061a3782a 100644 --- a/lite/operators/generate_proposals_op.h +++ b/lite/operators/generate_proposals_op.h @@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/grid_sampler_op.cc b/lite/operators/grid_sampler_op.cc index 2b13d17da7c439f582f682a74b1590cda632cf78..97e2b36a6bcd0eb784a39ab4f2a2e0703d7a7c93 100644 --- a/lite/operators/grid_sampler_op.cc +++ b/lite/operators/grid_sampler_op.cc @@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const { return true; } -bool GridSamplerOp::InferShape() const { +bool GridSamplerOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); param_.out->Resize(x_dims); return true; diff --git a/lite/operators/grid_sampler_op.h b/lite/operators/grid_sampler_op.h index 035e1b834510affefacafad763d75d6fbf53aed9..2fba4fe69311c274765e9db4c9b27e137c78a3ee 100644 --- a/lite/operators/grid_sampler_op.h +++ b/lite/operators/grid_sampler_op.h @@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_op.cc b/lite/operators/gru_op.cc index eb97d65a1a213e31b23087d1ca5c8e963ecf9bbb..862a1ff98f699393c9aa91afab978f947cc25187 100644 --- a/lite/operators/gru_op.cc +++ b/lite/operators/gru_op.cc @@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const { return true; } -bool GRUOpLite::InferShape() const { +bool GRUOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); const auto& weight_dims = param_.weight->dims(); int frame_size = weight_dims[0]; diff --git a/lite/operators/gru_op.h b/lite/operators/gru_op.h 
index c43f32f0cd41b8fa9bc8a541c48523a4f120009d..34f87fa79371fc3d798a57b4aae0945a27a692c3 100644 --- a/lite/operators/gru_op.h +++ b/lite/operators/gru_op.h @@ -30,7 +30,7 @@ class GRUOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_unit_op.cc b/lite/operators/gru_unit_op.cc index ed33507fc3fa61fce1e718581309ae37992c0531..ad025fbbc19cf27f053d5cc2bda566f186a72529 100644 --- a/lite/operators/gru_unit_op.cc +++ b/lite/operators/gru_unit_op.cc @@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const { return true; } -bool GRUUnitOpLite::InferShape() const { +bool GRUUnitOpLite::InferShapeImpl() const { auto input_dims = param_.input->dims(); auto hidden_prev_dims = param_.hidden_prev->dims(); auto weight_dims = param_.weight->dims(); diff --git a/lite/operators/gru_unit_op.h b/lite/operators/gru_unit_op.h index 301a7e7323afaea16dce2adcb356a41a8b0b8cac..2785e60e95b0f36cc5bf92714af857ef658d80dc 100644 --- a/lite/operators/gru_unit_op.h +++ b/lite/operators/gru_unit_op.h @@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/im2sequence_op.cc b/lite/operators/im2sequence_op.cc index 40ab2106af85b3386f93385785b65b9293b1c7f9..ae7b1029468ddb9f723de522ce715859d9a08a09 100644 --- a/lite/operators/im2sequence_op.cc +++ b/lite/operators/im2sequence_op.cc @@ -26,7 +26,7 @@ inline int Im2SeqOutputSize( } bool Im2SequenceOp::CheckShape() const { return true; } -bool Im2SequenceOp::InferShape() const { +bool Im2SequenceOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/im2sequence_op.h b/lite/operators/im2sequence_op.h index 83a347c913fd80c3a890053e1e1945b6cf2a7cd4..62525baaf071bb92b79773c248adb4fd1c798d90 100644 --- a/lite/operators/im2sequence_op.h +++ b/lite/operators/im2sequence_op.h @@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/increment_op.cc b/lite/operators/increment_op.cc index c1928ccbd4ca28ad1d1d83d2e232234ca1677aaa..9b34e4f73b8cc0e27cab06547d3fab84c7033b88 100644 --- a/lite/operators/increment_op.cc +++ b/lite/operators/increment_op.cc @@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const { return true; } -bool IncrementOp::InferShape() const { +bool IncrementOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/increment_op.h b/lite/operators/increment_op.h index f180d527c31494dcfb8cb53f005861ae639c9844..d4e6fd6b1ff1aea47df130d510bc84ab0a0b6019 100644 --- a/lite/operators/increment_op.h +++ b/lite/operators/increment_op.h @@ -30,7 +30,7 @@ class IncrementOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/instance_norm_op.cc b/lite/operators/instance_norm_op.cc index 510402ba1fb363f383b3cba8eb322a4ff7975c18..5f685ccfc59a7170a2d29d2b8e561ed933c8517c 100644 --- a/lite/operators/instance_norm_op.cc +++ b/lite/operators/instance_norm_op.cc @@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const { return true; } -bool InstanceNormOp::InferShape() const { +bool InstanceNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t batch_size = x_dims[0]; int64_t channel_size = x_dims[1]; diff --git a/lite/operators/instance_norm_op.h b/lite/operators/instance_norm_op.h index d128345805cf77ac2a4123a8549c92051593fff0..94a1f69fa4433072a986f1d82d5f1b8401a03386 100644 --- a/lite/operators/instance_norm_op.h +++ b/lite/operators/instance_norm_op.h @@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc index 1bfb20df4e4b9762e93b6a39f0d34eb2521acfe0..0ef22e42903842ac41e9aca010f78796b5a32fcc 100644 --- a/lite/operators/interpolate_op.cc +++ b/lite/operators/interpolate_op.cc @@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const { return true; } -bool InterpolateOp::InferShape() const { +bool InterpolateOp::InferShapeImpl() const { auto X = param_.X; int n = X->dims()[0]; diff --git a/lite/operators/interpolate_op.h b/lite/operators/interpolate_op.h index 5fcf4ef594d52a4ac14e5545b195cc51cbf379cf..2bc938964811c57189e45d3b9d892542f9f02e8f 100644 --- a/lite/operators/interpolate_op.h +++ b/lite/operators/interpolate_op.h @@ -31,7 +31,7 @@ class InterpolateOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/io_copy_op.cc b/lite/operators/io_copy_op.cc index 7df636d7b2d877a5539a980080077be785d47505..05b2d3d800d2d2989ae23f9a1ccac57021e82ac1 100644 --- a/lite/operators/io_copy_op.cc +++ b/lite/operators/io_copy_op.cc @@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool IoCopyOp::InferShape() const { +bool IoCopyOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/io_copy_op.h b/lite/operators/io_copy_op.h index 8d6d69d63ed8b7ec289d7935ea28df2482e0cf31..d6922b667d78e3b79a005aae895b9e63dc76fa21 100644 --- a/lite/operators/io_copy_op.h +++ b/lite/operators/io_copy_op.h @@ -24,7 +24,7 @@ class IoCopyOp : public OpLite { public: explicit IoCopyOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/is_empty_op.cc b/lite/operators/is_empty_op.cc 
index ed4c69e64eaae8fdcb8289c5389dcff1df2ea8b5..a62470e4bb7f88d4c441dc8814bba7c4913ab3e4 100644 --- a/lite/operators/is_empty_op.cc +++ b/lite/operators/is_empty_op.cc @@ -21,7 +21,7 @@ namespace operators { bool IsEmptyOp::CheckShape() const { return true; } -bool IsEmptyOp::InferShape() const { return true; } +bool IsEmptyOp::InferShapeImpl() const { return true; } bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = diff --git a/lite/operators/is_empty_op.h b/lite/operators/is_empty_op.h index 5bfa0905c7c57110473fde48d78d17947abbb547..14c0830c233a9ff011b00d130bc36054a7ede57a 100644 --- a/lite/operators/is_empty_op.h +++ b/lite/operators/is_empty_op.h @@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layer_norm_op.cc b/lite/operators/layer_norm_op.cc index 18ea6cbf281846600273d6e7d462ed43f2e45637..2f50d232e3781e44b8203084382c20872094a263 100644 --- a/lite/operators/layer_norm_op.cc +++ b/lite/operators/layer_norm_op.cc @@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const { return true; } -bool LayerNormOp::InferShape() const { +bool LayerNormOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); param_.Y->Resize(out_dims); auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[0]; diff --git a/lite/operators/layer_norm_op.h b/lite/operators/layer_norm_op.h index 297f6bdd402b919b4baa1915135ed909c57cfa0b..6e15d2f599beb14df024f2591b098b128c3af8dd 100644 --- a/lite/operators/layer_norm_op.h +++ b/lite/operators/layer_norm_op.h @@ -30,7 +30,7 @@ class LayerNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layout_op.cc b/lite/operators/layout_op.cc index 01272568045233a90e2aaaffa758e4ce1515d700..d71dab68702ddd53af1540c2a6dce14d43b27e09 100644 --- a/lite/operators/layout_op.cc +++ b/lite/operators/layout_op.cc @@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool LayoutOp::InferShape() const { +bool LayoutOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/layout_op.h b/lite/operators/layout_op.h index 216d571d7c37204ec6ef6c513caba726841bcdf2..f51768863bf2e942262f364c271b902922b39cb1 100644 --- a/lite/operators/layout_op.h +++ b/lite/operators/layout_op.h @@ -24,7 +24,7 @@ class LayoutOp : public OpLite { public: explicit LayoutOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/lod_reset_op.cc b/lite/operators/lod_reset_op.cc index 1754e709ff2439462e8f40d047f5594ed740e07a..c30c78bbc6c1300660c01e6219c9e5113c39a718 100644 --- a/lite/operators/lod_reset_op.cc +++ b/lite/operators/lod_reset_op.cc @@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const { return true; } -bool LodResetOp::InferShape() const { +bool LodResetOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
param_.Out->Resize(param_.X->dims()); diff --git a/lite/operators/lod_reset_op.h b/lite/operators/lod_reset_op.h index 4e048a9a696c3e1e4a366c732bb269134c9d5d06..8ca2bc578099aabfe6c9649d58e9caeabea7870f 100644 --- a/lite/operators/lod_reset_op.h +++ b/lite/operators/lod_reset_op.h @@ -30,7 +30,7 @@ class LodResetOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/logical_op.cc b/lite/operators/logical_op.cc index 8af982ad535192f4897ea70cdb180b230d29dfd6..2dd5b798280ef80a54d557e449beee15959971b8 100644 --- a/lite/operators/logical_op.cc +++ b/lite/operators/logical_op.cc @@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const { return true; } -bool BinaryLogicalOp::InferShape() const { +bool BinaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); @@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const { return true; } -bool UnaryLogicalOp::InferShape() const { +bool UnaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/logical_op.h b/lite/operators/logical_op.h index a0fc1d68a60a0650179f66ca9fd443e96a483c34..e784d4d99b7de29593e411db9b6a888e5bd52e21 100644 --- a/lite/operators/logical_op.h +++ b/lite/operators/logical_op.h @@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_dequant_op.cc b/lite/operators/lookup_table_dequant_op.cc index b81043bfbfeed356e3d67065686057adfadcb25f..844544dfad3c535342169d08159a80484a29643d 100644 --- a/lite/operators/lookup_table_dequant_op.cc +++ b/lite/operators/lookup_table_dequant_op.cc @@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const { return true; } -bool LookupTableDequantOpLite::InferShape() const { +bool LookupTableDequantOpLite::InferShapeImpl() const { const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_dequant_op.h b/lite/operators/lookup_table_dequant_op.h index 3a9683d5ca0d87365cb240b91dccab07cf26ca71..a094cac9a49891294ec71194d39a023867f58052 100644 --- a/lite/operators/lookup_table_dequant_op.h +++ b/lite/operators/lookup_table_dequant_op.h @@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc index df066435a8758e5a75ad1bed78111396d50b44cf..9bc22080bfb6c0ebda28e620dd9b781ec515ecbb 100644 --- a/lite/operators/lookup_table_op.cc +++ b/lite/operators/lookup_table_op.cc @@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const { return true; } -bool LookupTableOpLite::InferShape() const { +bool LookupTableOpLite::InferShapeImpl() const { 
const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_op.h b/lite/operators/lookup_table_op.h index 2701af984088cfda450f98fa5bc432dad7c2bc59..91ef77cfa1852a93d3aa28aceb616eec3306af3a 100644 --- a/lite/operators/lookup_table_op.h +++ b/lite/operators/lookup_table_op.h @@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_v2_op.cc b/lite/operators/lookup_table_v2_op.cc index df642e6191cffb748191b38eb5a6578aac163da4..8c76090df385ca5adf454ac1918c11c8838695f1 100644 --- a/lite/operators/lookup_table_v2_op.cc +++ b/lite/operators/lookup_table_v2_op.cc @@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const { return true; } -bool LookupTableV2OpLite::InferShape() const { +bool LookupTableV2OpLite::InferShapeImpl() const { auto table_dims = param_.W->dims(); auto ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_v2_op.h b/lite/operators/lookup_table_v2_op.h index dabff3f0cac75cb70cde6eb6e95df34dc36901fe..b0b8829fe6aeaf02a445109ea804266758919822 100644 --- a/lite/operators/lookup_table_v2_op.h +++ b/lite/operators/lookup_table_v2_op.h @@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lrn_op.cc b/lite/operators/lrn_op.cc index aff3e5af5566771411acf20736fdbec703f5def9..dcaffe1aa7cbc64c26dd2d56fcaa650e1599eb10 100644 --- a/lite/operators/lrn_op.cc +++ b/lite/operators/lrn_op.cc @@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const { return true; } -bool LrnOpLite::InferShape() const { +bool LrnOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/lrn_op.h b/lite/operators/lrn_op.h index a569a77fb40d7ea60e9e41171e73668e499684a5..13dfdefdc6f28dc289f490340faa14c166485db0 100644 --- a/lite/operators/lrn_op.h +++ b/lite/operators/lrn_op.h @@ -28,7 +28,7 @@ class LrnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lstm_op.cc b/lite/operators/lstm_op.cc index 36a0d2f53c1f30976ad6df811ad352721e3d7ff7..d9b6ebfc321190286d27272ea7b09a2a751cd9f1 100644 --- a/lite/operators/lstm_op.cc +++ b/lite/operators/lstm_op.cc @@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const { return true; } -bool LstmOp::InferShape() const { +bool LstmOp::InferShapeImpl() const { auto in_dims = param_.Input->dims(); if (param_.H0) { CHECK(param_.C0) << "lstm must has H0 and C0 in the same time"; diff --git a/lite/operators/lstm_op.h b/lite/operators/lstm_op.h index 221bd5c37945f4ff65b21a83449937563d9e5944..38bef385da67defa4e3459cfbcb6cbf24e0f2ed9 100644 --- a/lite/operators/lstm_op.h +++ b/lite/operators/lstm_op.h @@ -30,7 +30,7 @@ class LstmOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/match_matrix_tensor_op.cc b/lite/operators/match_matrix_tensor_op.cc index 
a8095a94bf75cd5d6d9087509449c159056ebc28..1cc751109f76a96097d363b493322dde182a715d 100644 --- a/lite/operators/match_matrix_tensor_op.cc +++ b/lite/operators/match_matrix_tensor_op.cc @@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const { return true; } -bool MatchMatrixTensorOpLite::InferShape() const { +bool MatchMatrixTensorOpLite::InferShapeImpl() const { const Tensor* x = param_.x; const Tensor* y = param_.y; DDim x_dims = param_.x->dims(); diff --git a/lite/operators/match_matrix_tensor_op.h b/lite/operators/match_matrix_tensor_op.h index 404183ea5bda3c35ba8b833853bc0005d60b9f7d..f1070a81b471ded59610af1a5bb40e35ccba7aff 100644 --- a/lite/operators/match_matrix_tensor_op.h +++ b/lite/operators/match_matrix_tensor_op.h @@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc index 286ade7b2130ce662eea2b7ba4e142bf489306ca..1cdcdfa16760385db059a4894e35d04bda51a85d 100644 --- a/lite/operators/matmul_op.cc +++ b/lite/operators/matmul_op.cc @@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const { return true; } -bool MatMulOpLite::InferShape() const { +bool MatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); bool x_transpose = param_.transpose_X; diff --git a/lite/operators/matmul_op.h b/lite/operators/matmul_op.h index 0aa47c89dd2227f70e7264c39b13c019d9b00587..acb9d512f7ac50818e9521ca67e04318397dabb0 100644 --- a/lite/operators/matmul_op.h +++ b/lite/operators/matmul_op.h @@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/mean_grad_op.cc b/lite/operators/mean_grad_op.cc index fd17cac14fca153499a52e93f6f09ea44ea9a559..55e374735ea8d861c65f1296968a40a8b5b1f096 100644 --- a/lite/operators/mean_grad_op.cc +++ b/lite/operators/mean_grad_op.cc @@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const { return true; } -bool MeanGradOp::InferShape() const { +bool MeanGradOp::InferShapeImpl() const { param_.X_grad->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/mean_grad_op.h b/lite/operators/mean_grad_op.h index 1bd604518bfc088fc45566e393fd997ae4eed06e..488581a71bb423c09540d17cbb05c170f6f06374 100644 --- a/lite/operators/mean_grad_op.h +++ b/lite/operators/mean_grad_op.h @@ -27,7 +27,7 @@ class MeanGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/mean_op.cc b/lite/operators/mean_op.cc index 618e9001db056b935de6aef8feff9125155d0e1a..9a66d4fbda3116ef7bd751f34f66eefd1f2e6e99 100644 --- a/lite/operators/mean_op.cc +++ b/lite/operators/mean_op.cc @@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const { return true; } -bool MeanOp::InferShape() const { +bool MeanOp::InferShapeImpl() const { param_.Out->Resize(std::vector{1}); return true; } diff --git a/lite/operators/mean_op.h b/lite/operators/mean_op.h index 8526842f93cb1d01debad9c6cb28ec28b98e43e9..c4dff93ce78aa4598bd12fb3181aa5f2bd4820b6 100644 --- a/lite/operators/mean_op.h +++ 
b/lite/operators/mean_op.h @@ -27,7 +27,7 @@ class MeanOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/merge_lod_tensor_op.cc b/lite/operators/merge_lod_tensor_op.cc index 4258715b1d1aa6bf7fac160dcd6fc8ca6dd3754d..704b5cad6fc80bee8bcb5dfd2921c5cf87182ff8 100644 --- a/lite/operators/merge_lod_tensor_op.cc +++ b/lite/operators/merge_lod_tensor_op.cc @@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const { return true; } -bool MergeLodTensorOpLite::InferShape() const { +bool MergeLodTensorOpLite::InferShapeImpl() const { auto dims = param_.in_true->dims(); param_.out->Resize(dims); return true; diff --git a/lite/operators/merge_lod_tensor_op.h b/lite/operators/merge_lod_tensor_op.h index 788a3451685cd0f42b72ee01e93e17da49507957..ec986fac1988efb5efa262c9fc340c6b450f8ddf 100644 --- a/lite/operators/merge_lod_tensor_op.h +++ b/lite/operators/merge_lod_tensor_op.h @@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/mul_grad_op.cc b/lite/operators/mul_grad_op.cc index 8215521637cbc29a4bdcc4b735b9658fc4cc4840..51e1fb310cb12d83dda9436bb73042c7b22fae11 100644 --- a/lite/operators/mul_grad_op.cc +++ b/lite/operators/mul_grad_op.cc @@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const { return true; } -bool MulGradOpLite::InferShape() const { +bool MulGradOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); if (param_.x_grad) { diff --git a/lite/operators/mul_grad_op.h b/lite/operators/mul_grad_op.h index ef61f54f9b88cd691ab98c4d8904b848dcea66b5..869aa60c6232000008cb57d110aa454396b2ff34 100644 --- a/lite/operators/mul_grad_op.h +++ b/lite/operators/mul_grad_op.h @@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/mul_op.cc b/lite/operators/mul_op.cc index c870abdc8989b48d8aa2f14f989ad475c027995e..8641a041e38b7a85ee7f0af8b3536f0b9224b36f 100644 --- a/lite/operators/mul_op.cc +++ b/lite/operators/mul_op.cc @@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const { return true; } -bool MulOpLite::InferShape() const { +bool MulOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); diff --git a/lite/operators/mul_op.h b/lite/operators/mul_op.h index caf7bf6ae902ac4e4f22d4a9aadfa108fa7622da..10a2e2efaa4db0e106e3c56c2f9b1cec9fb55ac4 100644 --- a/lite/operators/mul_op.h +++ b/lite/operators/mul_op.h @@ -33,7 +33,7 @@ class MulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. 
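[Editor's note] The operator-side changes above and below are mechanical: each op keeps its existing shape-propagation body but declares and defines it as InferShapeImpl, overriding the virtual hook instead of the public entry point. A rough sketch of the pattern after the rename is below; DemoOp and DemoParam are illustrative placeholders (not code from this patch), and the snippet assumes the usual Paddle-Lite op headers.

// demo_op.h (sketch)
class DemoOp : public OpLite {
 public:
  explicit DemoOp(const std::string &type) : OpLite(type) {}
  bool CheckShape() const override;
  bool InferShapeImpl() const override;  // was: bool InferShape() const override;
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
  std::string DebugString() const override { return "demo"; }

 private:
  mutable DemoParam param_;  // hypothetical param struct holding X and Out
};

// demo_op.cc (sketch) -- the body itself is untouched by the rename.
bool DemoOp::InferShapeImpl() const {
  CHECK_OR_FALSE(param_.Out);
  param_.Out->Resize(param_.X->dims());  // propagate the input shape unchanged
  return true;
}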
diff --git a/lite/operators/multiclass_nms_op.cc b/lite/operators/multiclass_nms_op.cc index 9ec79f8b57d63f20325bf686c1280522aa4fa80a..3102030e4b2cdc40eae369d2a43e9b94287e1873 100644 --- a/lite/operators/multiclass_nms_op.cc +++ b/lite/operators/multiclass_nms_op.cc @@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const { return true; } -bool MulticlassNmsOpLite::InferShape() const { +bool MulticlassNmsOpLite::InferShapeImpl() const { auto box_dims = param_.bboxes->dims(); auto score_dims = param_.scores->dims(); auto score_size = score_dims.size(); diff --git a/lite/operators/multiclass_nms_op.h b/lite/operators/multiclass_nms_op.h index 7be0d17d7478bdcfb4c4c6b1f22e505fb9da0846..f74479f3c9a42e6f5ec06126fedf91a2e17b6c2f 100644 --- a/lite/operators/multiclass_nms_op.h +++ b/lite/operators/multiclass_nms_op.h @@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/negative_op.cc b/lite/operators/negative_op.cc index 4db1dd4feede42fc4267eb3fc3553c538807f1a8..2b98f0a90af812ac9c524368e41177377f4d69e2 100644 --- a/lite/operators/negative_op.cc +++ b/lite/operators/negative_op.cc @@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const { return true; } -bool NegativeOpLite::InferShape() const { +bool NegativeOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/negative_op.h b/lite/operators/negative_op.h index 83f1008c9630284956347b87151e58f49588b867..04ec92532559c050cc5a9e8ac6bdf9a817e0dc70 100644 --- a/lite/operators/negative_op.h +++ b/lite/operators/negative_op.h @@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/norm_op.cc b/lite/operators/norm_op.cc index dff26966d48889389e2837194c2bc5a96fc960e5..0513e5c942d73397f269f1fe7bb89572a97ae548 100644 --- a/lite/operators/norm_op.cc +++ b/lite/operators/norm_op.cc @@ -25,7 +25,7 @@ bool NormOp::CheckShape() const { return true; } -bool NormOp::InferShape() const { +bool NormOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto out_dims = param_.X->dims(); diff --git a/lite/operators/norm_op.h b/lite/operators/norm_op.h index ae4594ed023d47179a7125bd9183e39f505ae16b..5c69d959be81eaccddc396dadacf920493ef99f5 100644 --- a/lite/operators/norm_op.h +++ b/lite/operators/norm_op.h @@ -30,7 +30,7 @@ class NormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 36d3b42c6b315a3858f475bd5756579137528051..1e221a602a426f3f117c69b9525f2a1d85880ee0 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -24,6 +24,7 @@ #include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/desc_apis.h" #include "lite/utils/all.h" +#include "lite/utils/variant.h" /* * This file contains all the argument parameter data structure for operators. 
*/ @@ -32,6 +33,16 @@ namespace paddle { namespace lite { namespace operators { +struct ParamBase { + public: + const std::vector* input_tensor_ptrs() const { return nullptr; } + std::vector* output_tensor_ptrs() { return nullptr; } + + protected: + std::shared_ptr> input_tensor_ptrs_cache_{nullptr}; + std::shared_ptr> output_tensor_ptrs_cache_{nullptr}; +}; + using param_t = Any; #define WITH_INT8_CONFIG \ bool enable_int8{false}; \ @@ -41,38 +52,38 @@ using param_t = Any; int bit_length{8}; /// ----------------------- Functional operators ------------------------------ -struct FeedParam { +struct FeedParam : ParamBase { std::vector* feed_list{}; lite::Tensor* out{}; int col; }; -struct FetchParam { +struct FetchParam : ParamBase { const lite::Tensor* input{}; std::vector* fetch_list{}; int col; }; // Helper op for lite framework -struct IoCopyParam { +struct IoCopyParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct LayoutParam { +struct LayoutParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct CalibParam { +struct CalibParam : ParamBase { const lite::Tensor* input{}; lite::Tensor* output{}; float scale; }; -struct SubgraphParam { +struct SubgraphParam : ParamBase { std::vector input_names{}; std::vector output_names{}; std::vector input_data_names{}; @@ -84,7 +95,7 @@ struct SubgraphParam { /// -------------------------- NN operators ------------------------------------ -struct FcParam { +struct FcParam : ParamBase { lite::Tensor* input{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* bias{nullptr}; @@ -95,9 +106,24 @@ struct FcParam { bool padding_weights{false}; // for int8 WITH_INT8_CONFIG -}; - -struct SearchSeqFcParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({input})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct SearchSeqFcParam : ParamBase { lite::Tensor* x{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* b{nullptr}; @@ -106,7 +132,7 @@ struct SearchSeqFcParam { }; // For Interpolate Op -struct InterpolateParam { +struct InterpolateParam : ParamBase { lite::Tensor* X{}; lite::Tensor* OutSize{}; lite::Tensor* Out{}; @@ -123,7 +149,7 @@ struct InterpolateParam { }; // For Mul Op -struct MulParam { +struct MulParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; lite::Tensor* output{}; @@ -134,7 +160,7 @@ struct MulParam { WITH_INT8_CONFIG }; -struct MulGradParam { +struct MulGradParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const lite::Tensor* output_grad{}; @@ -146,7 +172,7 @@ struct MulGradParam { }; // For ReduceMean Op -struct ReduceMeanParam { +struct ReduceMeanParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; @@ -155,7 +181,7 @@ struct ReduceMeanParam { }; // For Stack Op -struct StackParam { +struct StackParam : ParamBase { std::vector X; lite::Tensor* Out{}; @@ -163,7 +189,7 @@ struct StackParam { }; // For Power Op -struct PowerParam { +struct PowerParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -172,7 +198,7 @@ struct PowerParam { float power{}; }; 
-struct ShuffleChannelParam { +struct ShuffleChannelParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -180,7 +206,7 @@ struct ShuffleChannelParam { }; // For Yolobox -struct YoloBoxParam { +struct YoloBoxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ImgSize{}; lite::Tensor* Boxes{}; @@ -193,7 +219,7 @@ struct YoloBoxParam { }; // For Scale Op -struct ScaleParam { +struct ScaleParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; @@ -203,14 +229,29 @@ struct ScaleParam { }; // For Softmax op -struct SoftmaxParam { +struct SoftmaxParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int axis{-1}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Reshape and Reshape2 Op -struct ReshapeParam { +struct ReshapeParam : ParamBase { const lite::Tensor* x{}; std::vector shape_tensor_vct{}; const lite::Tensor* shape_tensor{}; @@ -222,7 +263,7 @@ struct ReshapeParam { }; // For Concat op -struct ConcatParam { +struct ConcatParam : ParamBase { std::vector x{}; lite::Tensor* output{}; int axis{0}; @@ -230,7 +271,7 @@ struct ConcatParam { }; /// ----------------------- activation operators ---------------------- -struct ActivationParam { +struct ActivationParam : ParamBase { const lite::Tensor* X{}; float Leaky_relu_alpha{0}; // leaky_relu param float Relu_clipped_coef{6}; // relu_clipped param @@ -245,7 +286,7 @@ struct ActivationParam { lite_api::ActivationType active_type; }; -struct ActivationGradParam { +struct ActivationGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out{}; // for backward @@ -254,7 +295,7 @@ struct ActivationGradParam { }; // For Convolution op -struct ConvParam { +struct ConvParam : ParamBase { lite::Tensor* x{}; lite::Tensor* filter{}; lite::Tensor* bias{nullptr}; @@ -294,10 +335,26 @@ struct ConvParam { std::vector output_size; // for int8 WITH_INT8_CONFIG + + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For BatchNorm op -struct BatchNormParam { +struct BatchNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* bias{}; lite::Tensor* scale{}; @@ -316,7 +373,7 @@ struct BatchNormParam { }; // For Pooling op -struct PoolParam { +struct PoolParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::string pooling_type{""}; @@ -340,7 +397,7 @@ struct PoolParam { }; // For Dropout op -struct DropoutParam { +struct DropoutParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* mask{}; @@ -352,7 +409,7 @@ struct DropoutParam { }; // For Split op -struct SplitParam { +struct SplitParam 
: ParamBase { lite::Tensor* x{}; std::vector output{}; lite::Tensor* axis_tensor; @@ -364,7 +421,7 @@ struct SplitParam { }; // For Transpose op -struct TransposeParam { +struct TransposeParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* xshape{}; @@ -375,7 +432,7 @@ struct TransposeParam { }; /// ----------------------- element wise operators ---------------------- -struct ElementwiseParam { +struct ElementwiseParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -384,9 +441,24 @@ struct ElementwiseParam { WITH_INT8_CONFIG float x_input_scale{1.0}; float y_input_scale{1.0}; -}; - -struct ElementwiseGradParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct ElementwiseGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; const lite::Tensor* OutGrad{}; @@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam { }; /// ----------------------- mean operators ---------------------- -struct MeanParam { +struct MeanParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct MeanGradParam { +struct MeanGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out_grad{}; // for backward @@ -417,7 +489,7 @@ struct MeanGradParam { }; /// ----------------------- fill_constant operators ---------------------- -struct FillConstantParam { +struct FillConstantParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; std::vector shape{}; lite::Tensor* shape_tensor{nullptr}; @@ -429,7 +501,7 @@ struct FillConstantParam { lite::Tensor* out{}; }; -struct FillConstantBatchSizeLikeParam { +struct FillConstantBatchSizeLikeParam : ParamBase { const lite::Tensor* input{nullptr}; lite::Tensor* out{nullptr}; @@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam { }; // -struct FakeQuantizeMovingAvgMaxAbsParam { +struct FakeQuantizeMovingAvgMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; const lite::Tensor* in_accum{}; @@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam { float moving_rate{0.9}; }; -struct FakeDequantizeMaxAbsParam { +struct FakeDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; lite::Tensor* out{}; float max_range; }; -struct FakeChannelWiseDequantizeMaxAbsParam { +struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; std::vector scale_tensors{}; lite::Tensor* out{}; @@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam { }; /// ----------------------- sgd operators ---------------------- -struct SGDParam { +struct SGDParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; const lite::Tensor* Param{}; @@ -482,7 +554,7 @@ struct SGDParam { }; /// ----------------------- uniform_random operators ---------------------- -struct UniformRandomParam { +struct UniformRandomParam : ParamBase { std::vector shape{}; float min{-1.0f}; float max{1.0f}; @@ -491,12 +563,12 @@ struct 
UniformRandomParam { lite::Tensor* Out{}; }; /// ----------------------- negative operators -------------- -struct NegativeParam { +struct NegativeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- pad2d operators ---------------------- -struct Pad2dParam { +struct Pad2dParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector paddings{0, 0, 0, 0}; @@ -506,7 +578,7 @@ struct Pad2dParam { }; /// ----------------------- Crop operators ---------------------- -struct CropParam { +struct CropParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector offsets; @@ -514,21 +586,21 @@ struct CropParam { }; ///----------------------- argmax operators ---------------------- -struct ArgmaxParam { +struct ArgmaxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; int Axis{0}; }; ///----------------------- axpy operators ---------------------- -struct AxpyParam { +struct AxpyParam : ParamBase { lite::Tensor* Scale{}; lite::Tensor* X{}; lite::Tensor* Bias{}; lite::Tensor* Out{}; }; /// ----------------------- GRU unit operators ----------------------f -struct GRUUnitParam { +struct GRUUnitParam : ParamBase { enum ActType { identity, sigmoid, tanh, relu }; const lite::Tensor* input{nullptr}; const lite::Tensor* hidden_prev{nullptr}; @@ -544,7 +616,7 @@ struct GRUUnitParam { }; /// ------------------------------ lrn operators ------------------------------ -struct LrnParam { +struct LrnParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; int n{5}; @@ -555,7 +627,7 @@ struct LrnParam { }; /// ----------------------- decode_bboxes operators ---------------------- -struct DecodeBboxesParam { +struct DecodeBboxesParam : ParamBase { const lite::Tensor* loc_data{}; const lite::Tensor* prior_data{}; lite::Tensor* bbox_data{}; @@ -571,7 +643,7 @@ struct DecodeBboxesParam { }; /// ----------------------- box_coder operators ---------------------- -struct BoxCoderParam { +struct BoxCoderParam : ParamBase { const lite::Tensor* prior_box{}; const lite::Tensor* prior_box_var{}; const lite::Tensor* target_box{}; @@ -584,7 +656,7 @@ struct BoxCoderParam { }; /// ----------------------- multiclass_nms operators ---------------------- -struct MulticlassNmsParam { +struct MulticlassNmsParam : ParamBase { const lite::Tensor* bboxes{}; const lite::Tensor* scores{}; lite::Tensor* out{}; @@ -599,7 +671,7 @@ struct MulticlassNmsParam { }; /// ----------------------- priorbox operators ---------------------- -struct PriorBoxParam { +struct PriorBoxParam : ParamBase { lite::Tensor* input{}; lite::Tensor* image{}; lite::Tensor* boxes{}; @@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam { std::vector density_sizes; }; /// ----------------------- GRU operators ----------------------f -struct GRUParam { +struct GRUParam : ParamBase { const lite::Tensor* input{nullptr}; const lite::Tensor* h0{nullptr}; const lite::Tensor* weight{nullptr}; @@ -645,7 +717,7 @@ struct GRUParam { }; /// ----------------------- BeamSearchDecode operators ----------------------f -struct BeamSearchDecodeParam { +struct BeamSearchDecodeParam : ParamBase { std::vector* ids{nullptr}; std::vector* scores{nullptr}; lite::Tensor* sentence_ids{nullptr}; @@ -655,21 +727,21 @@ struct BeamSearchDecodeParam { }; /// ----------------------- LookupTable operators ----------------------f -struct LookupTableParam { +struct LookupTableParam : ParamBase { const lite::Tensor* W{nullptr}; const lite::Tensor* Ids{nullptr}; lite::Tensor* Out{nullptr}; 
int64_t padding_idx{-1}; }; -struct LookupTableDequantParam { +struct LookupTableDequantParam : ParamBase { lite::Tensor* W{nullptr}; lite::Tensor* Ids{nullptr}; lite::Tensor* Out{nullptr}; int64_t padding_idx{-1}; }; -struct Im2SequenceParam { +struct Im2SequenceParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -679,19 +751,19 @@ struct Im2SequenceParam { std::vector out_strides{1, 1}; }; -struct SequenceSoftmaxParam { +struct SequenceSoftmaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct NormParam { +struct NormParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Norm{}; int axis{1}; float epsilon{1e-10}; }; -struct LayerNormParam { +struct LayerNormParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -702,13 +774,13 @@ struct LayerNormParam { float epsilon{1e-5}; }; -struct LogicalParam { +struct LogicalParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; }; -struct CompareParam { +struct CompareParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; bool force_cpu{0}; @@ -716,7 +788,7 @@ struct CompareParam { lite::Tensor* Out{}; }; -struct WhileParam { +struct WhileParam : ParamBase { Scope* scope{}; Tensor* cond{}; cpp::BlockDesc* sub_block{}; @@ -724,32 +796,32 @@ struct WhileParam { std::vector outs{}; }; -struct TopkParam { +struct TopkParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Indices{}; int K{1}; }; -struct IncrementParam { +struct IncrementParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; float step{1}; }; -struct WriteToArrayParam { +struct WriteToArrayParam : ParamBase { const lite::Tensor* X{nullptr}; const lite::Tensor* I{nullptr}; std::vector* Out{nullptr}; }; -struct ReadFromArrayParam { +struct ReadFromArrayParam : ParamBase { const std::vector* X{nullptr}; const lite::Tensor* I{nullptr}; lite::Tensor* Out{nullptr}; }; -struct BeamSearchParam { +struct BeamSearchParam : ParamBase { const lite::Tensor* pre_ids{}; const lite::Tensor* pre_scores{}; const lite::Tensor* ids{}; @@ -763,7 +835,7 @@ struct BeamSearchParam { bool is_accumulated; }; -struct SequencePoolParam { +struct SequencePoolParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::string pool_type{"AVERAGE"}; @@ -773,7 +845,7 @@ struct SequencePoolParam { #endif }; -struct SequenceConvParam { +struct SequenceConvParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Filter{}; lite::Tensor* Out{}; @@ -782,13 +854,13 @@ struct SequenceConvParam { int contextLength; }; -struct SequencePoolConcatParam { +struct SequencePoolConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; std::vector pool_type{}; }; -struct SearchGroupPaddingParam { +struct SearchGroupPaddingParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out_emb_padding{}; lite::Tensor* out_new{}; @@ -796,36 +868,36 @@ struct SearchGroupPaddingParam { int pad_id; }; -struct SequenceReshapeParam { +struct SequenceReshapeParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int new_dim; }; -struct SequenceExpandParam { +struct SequenceExpandParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; int ref_level{-1}; }; -struct SequenceExpandAsParam { +struct SequenceExpandAsParam : ParamBase { const lite::Tensor* x{nullptr}; const lite::Tensor* y{nullptr}; lite::Tensor* out{nullptr}; }; -struct SequenceReverseParam { +struct 
SequenceReverseParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct SequenceConcatParam { +struct SequenceConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; }; -struct AttentionPaddingMaskParam { +struct AttentionPaddingMaskParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int pad_id; @@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam { lite::Tensor* pad_begin{}; }; -struct SequenceArithmeticParam { +struct SequenceArithmeticParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int op_type{1}; lite::Tensor* Out{}; }; -struct ReduceMaxParam { +struct ReduceMaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector dim{}; bool keep_dim{false}; }; -struct LodResetParam { +struct LodResetParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -856,12 +928,12 @@ struct LodResetParam { bool append; }; -struct IsEmptyParam { +struct IsEmptyParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct ReduceParam { +struct ReduceParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::vector dim{0}; @@ -869,7 +941,7 @@ struct ReduceParam { bool reduce_all{false}; }; -struct VarConv2DParam { +struct VarConv2DParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -888,19 +960,19 @@ struct VarConv2DParam { }; /// ----------------------- shape operators ---------------------- -struct ShapeParam { +struct ShapeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct CastParam { +struct CastParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; int out_dtype{2}; int in_dtype{2}; }; -struct SliceParam { +struct SliceParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector axes{}; @@ -914,7 +986,7 @@ struct SliceParam { lite::Tensor* EndsTensor{nullptr}; }; -struct AffineChannelParam { +struct AffineChannelParam : ParamBase { const lite::Tensor* X{}; // X is 4D tensor const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -922,7 +994,7 @@ struct AffineChannelParam { lite::Tensor* Out{}; }; -struct AnchorGeneratorParam { +struct AnchorGeneratorParam : ParamBase { const lite::Tensor* Input{}; std::vector anchor_sizes{}; std::vector aspect_ratios{}; @@ -934,7 +1006,7 @@ struct AnchorGeneratorParam { lite::Tensor* Variances{}; }; -struct GenerateProposalsParam { +struct GenerateProposalsParam : ParamBase { // inputs const lite::Tensor* Scores{}; const lite::Tensor* BboxDeltas{}; @@ -954,14 +1026,14 @@ struct GenerateProposalsParam { lite::Tensor* RpnRoiProbs{}; }; /// ----------------------- squeeze operators ---------------------- -struct SqueezeParam { +struct SqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; std::vector axes{}; }; -struct UnsqueezeParam { +struct UnsqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; @@ -971,14 +1043,14 @@ struct UnsqueezeParam { }; /// ----------------------- expand operators ---------------------- -struct ExpandParam { +struct ExpandParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector expand_times{}; }; /// ----------------------- matmul operators ---------------------- -struct MatMulParam { +struct MatMulParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -987,20 +1059,20 @@ struct MatMulParam { float alpha{1.0f}; }; -struct GatherParam { +struct 
GatherParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Index{}; lite::Tensor* Out{}; }; /// ----------------------- assign operators ----------------------- -struct AssignParam { +struct AssignParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- roi_align operators ----------------------- -struct RoiAlignParam { +struct RoiAlignParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ROIs{}; lite::Tensor* Out{}; @@ -1011,13 +1083,13 @@ struct RoiAlignParam { }; /// ----------------------- box_clip operators ----------------------- -struct BoxClipParam { +struct BoxClipParam : ParamBase { const lite::Tensor* Input{}; const lite::Tensor* ImInfo{}; lite::Tensor* Output{}; }; -struct RangeParam { +struct RangeParam : ParamBase { const lite::Tensor* Start; const lite::Tensor* End; const lite::Tensor* Step; @@ -1025,7 +1097,7 @@ struct RangeParam { }; /// ----------------------- assign_value operators ----------------------- -struct AssignValueParam { +struct AssignValueParam : ParamBase { std::vector shape{}; int dtype{}; std::vector fp32_values{}; @@ -1034,7 +1106,7 @@ struct AssignValueParam { }; /// --------------- sequence_topk_avg_pooling operators ------------------ -struct SequenceTopkAvgPoolingParam { +struct SequenceTopkAvgPoolingParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam { }; /// --------------- search_fc operators ------------------ -struct SearchFcParam { +struct SearchFcParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* W{}; const lite::Tensor* b{}; @@ -1053,7 +1125,7 @@ struct SearchFcParam { int out_size{}; }; /// --------------------- match_matrix_tensor operators -------------------- -struct MatchMatrixTensorParam { +struct MatchMatrixTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const lite::Tensor* w{}; @@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam { }; /// --------------------- search_seq_depadding operators -------------------- -struct SearchSeqDepaddingParam { +struct SearchSeqDepaddingParam : ParamBase { const lite::Tensor* pad{}; const lite::Tensor* src{}; lite::Tensor* out{}; }; /// --------------------- search_grnn operators -------------------- -struct SearchGrnnParam { +struct SearchGrnnParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* wi{}; const lite::Tensor* wh{}; @@ -1084,7 +1156,7 @@ struct SearchGrnnParam { lite::Tensor* layout_input{}; }; -struct SplitLodTensorParam { +struct SplitLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; lite::Tensor* out_true{}; @@ -1092,7 +1164,7 @@ struct SplitLodTensorParam { int level{}; }; -struct MergeLodTensorParam { +struct MergeLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; const lite::Tensor* in_true{}; @@ -1101,7 +1173,7 @@ struct MergeLodTensorParam { int level{}; }; -struct ConditionalBlockParam { +struct ConditionalBlockParam : ParamBase { const lite::Tensor* cond{}; std::vector x{}; std::vector outs{}; @@ -1110,14 +1182,14 @@ struct ConditionalBlockParam { bool is_scalar_condition{}; }; -struct CollectFpnProposalsParam { +struct CollectFpnProposalsParam : ParamBase { std::vector multi_level_rois{}; std::vector multi_level_scores{}; lite::Tensor* fpn_rois{}; int post_nms_topN{}; }; -struct DistributeFpnProposalsParam { +struct DistributeFpnProposalsParam : ParamBase { const lite::Tensor* fpn_rois{}; std::vector 
multi_fpn_rois{}; lite::Tensor* restore_index{}; @@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam { }; /// --------------------- instance_norm operators -------------------- -struct InstanceNormParam { +struct InstanceNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* bias{}; @@ -1138,12 +1210,12 @@ struct InstanceNormParam { float epsilon; }; /// --------------------- grid sampler operators -------------------- -struct GridSamplerParam { +struct GridSamplerParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* grid{}; }; -struct LstmParam { +struct LstmParam : ParamBase { lite::Tensor* Input{}; lite::Tensor* Weight{}; lite::Tensor* Bias{}; @@ -1160,7 +1232,7 @@ struct LstmParam { std::string candidate_activation; }; -struct CrfDecodingParam { +struct CrfDecodingParam : ParamBase { lite::Tensor* emission{}; lite::Tensor* transition{}; lite::Tensor* label{}; diff --git a/lite/operators/pad2d_op.cc b/lite/operators/pad2d_op.cc index ff522b94b95091b6df6d4d2f71e18907c5118619..7af657c888f9b1b28a1b273a193be59e2ace895c 100644 --- a/lite/operators/pad2d_op.cc +++ b/lite/operators/pad2d_op.cc @@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const { return true; } -bool Pad2dOpLite::InferShape() const { +bool Pad2dOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); int out_h = x_dims[2] + param_.paddings[0] + param_.paddings[1]; diff --git a/lite/operators/pad2d_op.h b/lite/operators/pad2d_op.h index c51a76a7aef5624b1480fd1b1cdf56bf23c63674..c6d2e565483655c6279af8318434f129ec92a5e5 100644 --- a/lite/operators/pad2d_op.h +++ b/lite/operators/pad2d_op.h @@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/pool_op.cc b/lite/operators/pool_op.cc index c6f6eed28f8cdb5f080b6f4367a1b88b1dbc0701..5fb990928ec1ae723bc12b695af1be5e50da5079 100644 --- a/lite/operators/pool_op.cc +++ b/lite/operators/pool_op.cc @@ -60,7 +60,7 @@ int PoolOutputSize(int input_size, return output_size; } -bool PoolOpLite::InferShape() const { +bool PoolOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); std::vector& ksize = param_.ksize; // dynamic update 4-pad diff --git a/lite/operators/pool_op.h b/lite/operators/pool_op.h index c44875ff95b554ca92cf5288597a5bdaf2cb1bf8..3fcf37e6348628d489e9a2097e2c8dac7eba3e3c 100644 --- a/lite/operators/pool_op.h +++ b/lite/operators/pool_op.h @@ -37,7 +37,7 @@ class PoolOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; // TODO(Superjomn) replace framework::OpDesc with a lite one. 
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { diff --git a/lite/operators/power_op.cc b/lite/operators/power_op.cc index 578d95ad53ffe0481288934a7a04d0f9e4442440..83c9edfaca1505746640280633bf6d47cddc6146 100644 --- a/lite/operators/power_op.cc +++ b/lite/operators/power_op.cc @@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const { return true; } -bool PowerOp::InferShape() const { +bool PowerOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/power_op.h b/lite/operators/power_op.h index a6d43f4394a8d3a2141f32e1fb633aef8c8227f8..e89dfa7b8f682e029bfba1059fda9c17340c420b 100644 --- a/lite/operators/power_op.h +++ b/lite/operators/power_op.h @@ -31,7 +31,7 @@ class PowerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index c4717c8185b24cfd9f6a551dcb932dc325a502d2..f1b715a46e1378f805d91312cc7804cb4097ec02 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const { return true; } -bool PriorBoxOpLite::InferShape() const { return true; } +bool PriorBoxOpLite::InferShapeImpl() const { return true; } bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { auto input = opdesc.Input("Input").front(); diff --git a/lite/operators/prior_box_op.h b/lite/operators/prior_box_op.h index a393e80315eab07cc8558da8c26d6acad8cc76c1..1348b7cc73f6b731453584ef455813fe0d1cf8be 100644 --- a/lite/operators/prior_box_op.h +++ b/lite/operators/prior_box_op.h @@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/range_op.cc b/lite/operators/range_op.cc index a179d8ffe7abc1665b13f7d0dfeaa8b3c18cf1d5..19f474ba43b15153a7e2cca38f5ff9b097b41342 100644 --- a/lite/operators/range_op.cc +++ b/lite/operators/range_op.cc @@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) { : std::ceil(std::abs((end - start) / step)); } -bool RangeOpLite::InferShape() const { +bool RangeOpLite::InferShapeImpl() const { int start = param_.Start->data()[0]; int end = param_.End->data()[0]; int step = param_.Step->data()[0]; diff --git a/lite/operators/range_op.h b/lite/operators/range_op.h index a1c7d4d4cc43d72001ac3519cb1c4f85ab8196ff..982ef5abf25aac816c00da918147bac8933424a9 100644 --- a/lite/operators/range_op.h +++ b/lite/operators/range_op.h @@ -29,7 +29,7 @@ class RangeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/read_from_array_op.cc b/lite/operators/read_from_array_op.cc index 930eff1ff5ff100c085a4fdb6bdf3a032d44c14b..495fd752c90da528e474b7aa726c65fd6e66c123 100644 --- a/lite/operators/read_from_array_op.cc +++ b/lite/operators/read_from_array_op.cc @@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const { return true; } -bool ReadFromArrayOp::InferShape() const { +bool ReadFromArrayOp::InferShapeImpl() const { int id = param_.I->data()[0]; auto out_dims = (*param_.X)[id].dims(); param_.Out->Resize(out_dims); 
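[Editor's note] The op_params.h hunk a few files back derives every parameter struct from the new ParamBase and, for a handful of frequently used params (FcParam, SoftmaxParam, ConvParam, ElementwiseParam), overloads input_tensor_ptrs() / output_tensor_ptrs() so the op's I/O tensors can be enumerated generically; all other structs inherit the nullptr defaults from ParamBase and are unaffected. A minimal sketch of that pattern follows. The template arguments do not survive in the flattened diff text, so the element types (const lite::Tensor* for inputs, lite::Tensor* for outputs) and the negated guard that builds the cache lazily on first use are stated assumptions, not verbatim code from the patch.

// Sketch only -- relies on the ParamBase cache members declared in op_params.h.
struct DemoSoftmaxLikeParam : ParamBase {
  lite::Tensor* x{};
  lite::Tensor* output{};
  int axis{-1};

  // Expose the inputs as a flat list; build it once, then hand out the cache.
  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
    if (!input_tensor_ptrs_cache_) {
      input_tensor_ptrs_cache_.reset(
          new std::vector<const lite::Tensor*>({x}));
    }
    return input_tensor_ptrs_cache_.get();
  }
  // Same idea for the outputs.
  std::vector<lite::Tensor*>* output_tensor_ptrs() {
    if (!output_tensor_ptrs_cache_) {
      output_tensor_ptrs_cache_.reset(
          new std::vector<lite::Tensor*>({output}));
    }
    return output_tensor_ptrs_cache_.get();
  }
};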
diff --git a/lite/operators/read_from_array_op.h b/lite/operators/read_from_array_op.h index 5c7ba1468f59e27a273b368014c707676c48e36a..299a3abaedcf3618f5e28a9636d427961a97b931 100644 --- a/lite/operators/read_from_array_op.h +++ b/lite/operators/read_from_array_op.h @@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_max_op.cc b/lite/operators/reduce_max_op.cc index d7d90ee1f454556baee1a87cfd0023f8cf8c119d..ba48acd11f3517f33b020ede92e07cfadc5d497b 100644 --- a/lite/operators/reduce_max_op.cc +++ b/lite/operators/reduce_max_op.cc @@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const { return true; } -bool ReduceMaxOp::InferShape() const { +bool ReduceMaxOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool reduce_all = false; diff --git a/lite/operators/reduce_max_op.h b/lite/operators/reduce_max_op.h index 60e263f1b9b72a31c223cc60f89a7ddf81949e8c..54b136a7576fb2bb078c5bcae727b15d319bdf8e 100644 --- a/lite/operators/reduce_max_op.h +++ b/lite/operators/reduce_max_op.h @@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite { ReduceMaxOp() {} explicit ReduceMaxOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_mean_op.cc b/lite/operators/reduce_mean_op.cc index bce31c315c22e93d7758a05ecf2ace0668dd0cc1..c5baca5e87068d267ada21854b7769bf2bc19461 100644 --- a/lite/operators/reduce_mean_op.cc +++ b/lite/operators/reduce_mean_op.cc @@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const { return true; } -bool ReduceMeanOp::InferShape() const { +bool ReduceMeanOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool reduce_all = false; diff --git a/lite/operators/reduce_mean_op.h b/lite/operators/reduce_mean_op.h index e701a1132aa1260b5f169f89dec546a0d80fc916..43fe955690b3e4569f75c88a4d7b9ba9e961fcca 100644 --- a/lite/operators/reduce_mean_op.h +++ b/lite/operators/reduce_mean_op.h @@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite { ReduceMeanOp() {} explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc index e2cc56b416dd166e6b22a0c642907844ab964cc5..1af6daf8c73e8e41f69be8f8af8f485ac767d702 100644 --- a/lite/operators/reduce_ops.cc +++ b/lite/operators/reduce_ops.cc @@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const { return true; } -bool ReduceOp::InferShape() const { +bool ReduceOp::InferShapeImpl() const { const auto &x_dims = param_.x->dims(); auto x_rank = x_dims.size(); auto dims = param_.dim; diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h index 0063aba1fa606c6228e7dcb1197bfb36f57aa33c..d4fdbd113586a57b0d5a1e6e5fbde6707efb7cc1 100644 --- a/lite/operators/reduce_ops.h +++ b/lite/operators/reduce_ops.h @@ -30,7 +30,7 @@ class ReduceOp : public 
OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_prod_op.cc b/lite/operators/reduce_prod_op.cc index 90da13c8643fa030c376ca25cb3a67b70f3485a4..5a6194b36b9c0b4a95fb47049999da093f979e3b 100644 --- a/lite/operators/reduce_prod_op.cc +++ b/lite/operators/reduce_prod_op.cc @@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const { return true; } -bool ReduceProdOpLite::InferShape() const { +bool ReduceProdOpLite::InferShapeImpl() const { auto x = param_.x; auto out = param_.output; std::vector dim = param_.dim; diff --git a/lite/operators/reduce_prod_op.h b/lite/operators/reduce_prod_op.h index 5f7a6dcdf98eb99d9145b7e3108972f4debeaeb5..d8bb1400b9aecf449499d4c6920c2ef88eb119b2 100644 --- a/lite/operators/reduce_prod_op.h +++ b/lite/operators/reduce_prod_op.h @@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/relu_op.cc b/lite/operators/relu_op.cc index 9fa3ac8f30784b8349788dfd4eaf39252db1a156..e5f51676c69bcde6b68a9e9d17f936874a5ea86f 100644 --- a/lite/operators/relu_op.cc +++ b/lite/operators/relu_op.cc @@ -20,7 +20,7 @@ namespace lite { namespace operators { bool ReluOp::CheckShape() const { return true; } -bool ReluOp::InferShape() const { +bool ReluOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. diff --git a/lite/operators/relu_op.h b/lite/operators/relu_op.h index 23ca7ff16b48de747069f006cddbb9504e6942e3..7577f2ffbab62298138b22970c00caf9ab01367f 100644 --- a/lite/operators/relu_op.h +++ b/lite/operators/relu_op.h @@ -30,7 +30,7 @@ class ReluOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc index 655ac58bdcbfc0f8d9cdbb0ef0078db5eb0333fa..5c55eb4aa516ae3aecf49250f42d38491c1270f1 100644 --- a/lite/operators/reshape_op.cc +++ b/lite/operators/reshape_op.cc @@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const { return true; } -bool ReshapeOp::InferShape() const { +bool ReshapeOp::InferShapeImpl() const { const auto &shape_tensor_vct = param_.shape_tensor_vct; auto *shape_tensor = param_.shape_tensor; const auto &shape_vct = param_.shape_vct; @@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const { return true; } -bool Reshape2Op::InferShape() const { - ReshapeOp::InferShape(); +bool Reshape2Op::InferShapeImpl() const { + ReshapeOp::InferShapeImpl(); const auto &x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; diff --git a/lite/operators/reshape_op.h b/lite/operators/reshape_op.h index 1df49fb5f44c88978b78f17885a5ba4412aa9ab7..9dc302ec9706512b16cd9e7db38b944d2d1324f5 100644 --- a/lite/operators/reshape_op.h +++ b/lite/operators/reshape_op.h @@ -30,7 +30,7 @@ class ReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/roi_align_op.cc b/lite/operators/roi_align_op.cc index 2f65c0197ecf1324678c63b6bd16018f83389702..001934dcf8f77527666c1b5cc0a01afcade2af81 100644 --- a/lite/operators/roi_align_op.cc +++ b/lite/operators/roi_align_op.cc @@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const { return true; } -bool RoiAlignOpLite::InferShape() const { +bool RoiAlignOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); auto rois_dims = param_.ROIs->dims(); diff --git a/lite/operators/roi_align_op.h b/lite/operators/roi_align_op.h index f3dd1a47f5e2d0dbb39439c9789573b9b7a33728..65cc72534a2e2b63a1e024a55c766f2c1983f5ab 100644 --- a/lite/operators/roi_align_op.h +++ b/lite/operators/roi_align_op.h @@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/scale_op.cc b/lite/operators/scale_op.cc index 1398ea481194cae545fc8f1fa803eff5f5b78a31..3236277187462dd1185e698e5cb8fe919fe20b97 100644 --- a/lite/operators/scale_op.cc +++ b/lite/operators/scale_op.cc @@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const { return true; } -bool ScaleOp::InferShape() const { +bool ScaleOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/scale_op.h b/lite/operators/scale_op.h index 684da4ed47370090c5cb690ea728fa4f9147c4bf..38970bfcfd82eebce51612e6afb531cbf3b10966 100644 --- a/lite/operators/scale_op.h +++ b/lite/operators/scale_op.h @@ -30,7 +30,7 @@ class ScaleOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_aligned_mat_mul_op.cc b/lite/operators/search_aligned_mat_mul_op.cc index 43a276e3c7a2f7481ade2ee18c1446593f7c5f43..65ccbc2b793cb3a64c16a5b3bf7d869d8e271327 100644 --- a/lite/operators/search_aligned_mat_mul_op.cc +++ b/lite/operators/search_aligned_mat_mul_op.cc @@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const { return true; } -bool SearchAlignedMatMulOpLite::InferShape() const { +bool SearchAlignedMatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); const auto& x_lod = param_.X->lod(); diff --git a/lite/operators/search_aligned_mat_mul_op.h b/lite/operators/search_aligned_mat_mul_op.h index 7321b7e9d15331e6aad36364436a99d3d4089c8c..8242e06d0170a8a4c178f0e460c64f93b0c2bc3c 100644 --- a/lite/operators/search_aligned_mat_mul_op.h +++ b/lite/operators/search_aligned_mat_mul_op.h @@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_fc_op.cc b/lite/operators/search_fc_op.cc index 2e77e361624e681aa93e36610674df0e1f9a13af..3c64f24e48f750b367b75431333401329721a9b9 100644 --- a/lite/operators/search_fc_op.cc +++ b/lite/operators/search_fc_op.cc @@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const { return true; } -bool SearchFcOpLite::InferShape() const { +bool 
SearchFcOpLite::InferShapeImpl() const { auto out_size = param_.out_size; lite::DDim dims(std::vector({-1, out_size})); param_.Out->Resize(dims); diff --git a/lite/operators/search_fc_op.h b/lite/operators/search_fc_op.h index a871cadd33b4f7d4b6130a0b8ac2974a738ac0c3..235c24c57ff0e925d763fa11a78f56cfe72613cd 100644 --- a/lite/operators/search_fc_op.h +++ b/lite/operators/search_fc_op.h @@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_grnn_op.cc b/lite/operators/search_grnn_op.cc index b56ae820bf9de4ffe6aa3f6db7a8e1385c8cc11f..1ced477c109d8cd93485f0193523887759939f17 100644 --- a/lite/operators/search_grnn_op.cc +++ b/lite/operators/search_grnn_op.cc @@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const { return true; } -bool SearchGrnnOpLite::InferShape() const { +bool SearchGrnnOpLite::InferShapeImpl() const { const auto& x_dims = param_.x->dims(); const auto& x_lod = param_.x->lod(); CHECK_OR_FALSE(!x_lod.empty()); diff --git a/lite/operators/search_grnn_op.h b/lite/operators/search_grnn_op.h index 670af8a6c9ff9eafa33018a0303ea1a36b0a1e01..de4b1d8a5c4d551970fcbb7b0c17de67214b5c9a 100644 --- a/lite/operators/search_grnn_op.h +++ b/lite/operators/search_grnn_op.h @@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_group_padding_op.cc b/lite/operators/search_group_padding_op.cc index 5ba4dde275f4b9662416bdf5190cacfafc56a40d..b97c710109ea9eb1ae3b1e50e3bdab3e1e97ac3e 100644 --- a/lite/operators/search_group_padding_op.cc +++ b/lite/operators/search_group_padding_op.cc @@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const { return true; } -bool SearchGroupPaddingOp::InferShape() const { +bool SearchGroupPaddingOp::InferShapeImpl() const { std::vector x_dims = param_.x->dims().Vectorize(); param_.out_emb_padding->Resize({-1, x_dims[1]}); diff --git a/lite/operators/search_group_padding_op.h b/lite/operators/search_group_padding_op.h index a8e96c9697b5f7de70349efa1f8b378a47c3823c..6a93c7410128aa86b034308562b8c3ccd4ca78df 100644 --- a/lite/operators/search_group_padding_op.h +++ b/lite/operators/search_group_padding_op.h @@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite { SearchGroupPaddingOp() {} explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "search_group_padding"; } diff --git a/lite/operators/search_seq_depadding_op.cc b/lite/operators/search_seq_depadding_op.cc index 12d5123e05b41665550fb7e6b90a636093959263..6ad4f1ab171486468bf34b8341344410ed99f59b 100644 --- a/lite/operators/search_seq_depadding_op.cc +++ b/lite/operators/search_seq_depadding_op.cc @@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const { return true; } -bool SearchSeqDepaddingOpLite::InferShape() const { +bool SearchSeqDepaddingOpLite::InferShapeImpl() const { DDim pad_dims = param_.pad->dims(); 
param_.out->Resize({-1, pad_dims[1]}); return true; diff --git a/lite/operators/search_seq_depadding_op.h b/lite/operators/search_seq_depadding_op.h index 445d9e0f3bcba6204243e80023d826bf53d90c60..aa1cc22d4b048ca81445e735e09226b7dfe2fd03 100644 --- a/lite/operators/search_seq_depadding_op.h +++ b/lite/operators/search_seq_depadding_op.h @@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_seq_fc_op.cc b/lite/operators/search_seq_fc_op.cc index c5cca5331ab80479656b1212df02c20d463a3707..2a4525ac6e6f7e0cdd62a0a653e7188b274545af 100644 --- a/lite/operators/search_seq_fc_op.cc +++ b/lite/operators/search_seq_fc_op.cc @@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const { return true; } -bool SearchSeqFcOpLite::InferShape() const { +bool SearchSeqFcOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto w_dims = param_.w->dims(); const auto& x_lod = param_.x->lod(); diff --git a/lite/operators/search_seq_fc_op.h b/lite/operators/search_seq_fc_op.h index 3c4f7d82bfa66c2f323063f0297438c81ce18397..bacafcfe6ffa2a2c518cf3b8f226fa29c9b95e95 100644 --- a/lite/operators/search_seq_fc_op.h +++ b/lite/operators/search_seq_fc_op.h @@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_seq_softmax_op.cc b/lite/operators/search_seq_softmax_op.cc index 973ffa04c4562334af6d379b5446902036de8c5e..9b0550341c50df9cd48fa922139fc759c5289e97 100644 --- a/lite/operators/search_seq_softmax_op.cc +++ b/lite/operators/search_seq_softmax_op.cc @@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const { return true; } -bool SearchSeqSoftmaxOp::InferShape() const { +bool SearchSeqSoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); param_.output->set_lod(param_.x->lod()); return true; diff --git a/lite/operators/search_seq_softmax_op.h b/lite/operators/search_seq_softmax_op.h index f97e8ddd3a6c446fb5c53d5e603f43bbdf1e2525..dca3619eab9013f22d962b16c577c73862ee5e64 100644 --- a/lite/operators/search_seq_softmax_op.h +++ b/lite/operators/search_seq_softmax_op.h @@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_arithmetic_op.cc b/lite/operators/sequence_arithmetic_op.cc index 29c39ebc23f54c2c3c052e322575d97570195cfc..e17a179a860e13622979e5b42b07ae3459876fc7 100644 --- a/lite/operators/sequence_arithmetic_op.cc +++ b/lite/operators/sequence_arithmetic_op.cc @@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const { return true; } -bool SequenceArithmeticOp::InferShape() const { +bool SequenceArithmeticOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); param_.Out->set_lod(param_.X->lod()); return true; diff --git a/lite/operators/sequence_arithmetic_op.h b/lite/operators/sequence_arithmetic_op.h index 9f844dfbf429599d829bc786c66ba6d05e40d79d..cf9ef1583aeaed977c515441ca629b2e66efb3d2 100644 --- a/lite/operators/sequence_arithmetic_op.h +++ 
b/lite/operators/sequence_arithmetic_op.h @@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_concat_op.cc b/lite/operators/sequence_concat_op.cc index 88afe5e00fe2bfc173a8a1d1d0e63562cfb52518..91c70c0d2ff2d506d29dbeb01780de962f9a27f1 100644 --- a/lite/operators/sequence_concat_op.cc +++ b/lite/operators/sequence_concat_op.cc @@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const { return true; } -bool SequenceConcatOp::InferShape() const { return true; } +bool SequenceConcatOp::InferShapeImpl() const { return true; } bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { diff --git a/lite/operators/sequence_concat_op.h b/lite/operators/sequence_concat_op.h index 8cdc07ebca83b9c400b00a0f40556a788c5854e6..c7d61db7852fb8894c5c4ed7c3d4283480c90e48 100644 --- a/lite/operators/sequence_concat_op.h +++ b/lite/operators/sequence_concat_op.h @@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite { SequenceConcatOp() {} explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_concat"; } diff --git a/lite/operators/sequence_conv_op.cc b/lite/operators/sequence_conv_op.cc index 89596a22c616b45d0e72cc14501e4f6c148ad86c..681e05c9b69953c4dde6c873e66bee2e93839aaf 100644 --- a/lite/operators/sequence_conv_op.cc +++ b/lite/operators/sequence_conv_op.cc @@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const { return true; } -bool SequenceConvOp::InferShape() const { +bool SequenceConvOp::InferShapeImpl() const { const auto *input = param_.X; const auto *filter = param_.Filter; auto in_dims = input->dims(); diff --git a/lite/operators/sequence_conv_op.h b/lite/operators/sequence_conv_op.h index 34d65d3cc9324aea7b50a1d939a594b817889896..3ec7ac4d3da7822335e047ca1c681809914c192b 100644 --- a/lite/operators/sequence_conv_op.h +++ b/lite/operators/sequence_conv_op.h @@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite { SequenceConvOp() {} explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_expand_as_op.cc b/lite/operators/sequence_expand_as_op.cc index 22a4743103fd4b188357d067a062ea827de7aaa0..02c787b5a51749851de1484101a6339142bc9726 100644 --- a/lite/operators/sequence_expand_as_op.cc +++ b/lite/operators/sequence_expand_as_op.cc @@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const { return true; } -bool SequenceExpandAsOpLite::InferShape() const { +bool SequenceExpandAsOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto y_lod = param_.y->lod(); auto out_dims = x_dims; diff --git a/lite/operators/sequence_expand_as_op.h b/lite/operators/sequence_expand_as_op.h index 2eae8a26da31eb2937ab88f15d70bd44515e6a5f..19d6905c1a428ce4ac8b2cdb545f194bf47ee62d 100644 --- 
a/lite/operators/sequence_expand_as_op.h +++ b/lite/operators/sequence_expand_as_op.h @@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_expand_op.cc b/lite/operators/sequence_expand_op.cc index 0a5427a62ffca44070c9551a4f1c869ae184f0be..4bb3c66b26673a27a961729d6fe22d54ef9298fe 100644 --- a/lite/operators/sequence_expand_op.cc +++ b/lite/operators/sequence_expand_op.cc @@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const { return true; } -bool SequenceExpandOp::InferShape() const { +bool SequenceExpandOp::InferShapeImpl() const { const auto x_lod = param_.X->lod(); auto x_dims = param_.X->dims(); int ref_level = param_.ref_level; diff --git a/lite/operators/sequence_expand_op.h b/lite/operators/sequence_expand_op.h index da4b2fe71edb7f731bf53872960612e16efbef93..fffe2110d871941522e5924943be764e3ee51db5 100644 --- a/lite/operators/sequence_expand_op.h +++ b/lite/operators/sequence_expand_op.h @@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_pool_concat_op.cc b/lite/operators/sequence_pool_concat_op.cc index 9ee0d4d5967e0d36bb893b42033f2c5319c940bb..ce490e8246c621cb23b3a3eecc0e8ddc4bca28b1 100644 --- a/lite/operators/sequence_pool_concat_op.cc +++ b/lite/operators/sequence_pool_concat_op.cc @@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const { return true; } -bool SequencePoolConcatOp::InferShape() const { +bool SequencePoolConcatOp::InferShapeImpl() const { int out_dim = 0; for (int i = 0; i < param_.X.size(); ++i) { out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size()); diff --git a/lite/operators/sequence_pool_concat_op.h b/lite/operators/sequence_pool_concat_op.h index 7a70ceaf298ebd7d02c319b08a86f40dc36cb648..58e6fc18ba49f6885e1f4ffb86cba47ca86f9623 100644 --- a/lite/operators/sequence_pool_concat_op.h +++ b/lite/operators/sequence_pool_concat_op.h @@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite { SequencePoolConcatOp() {} explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_pool_op.cc b/lite/operators/sequence_pool_op.cc index be3726ffe7a73c50f92bec2f2a96fb1625e31a9e..6b4f7d8b789f11c815b86f7dcc990e6db7855bbd 100644 --- a/lite/operators/sequence_pool_op.cc +++ b/lite/operators/sequence_pool_op.cc @@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const { return true; } -bool SequencePoolOp::InferShape() const { +bool SequencePoolOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); out_dims[0] = input->lod()[0].size() - 1; diff --git a/lite/operators/sequence_pool_op.h b/lite/operators/sequence_pool_op.h index 215dd113a3e5d9cdb1707a9b1b70c5712a43ec5d..7b9e36bb5e6e5f47cf49b1bd0df62795b7d57b7e 100644 --- a/lite/operators/sequence_pool_op.h +++ b/lite/operators/sequence_pool_op.h @@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite { 
SequencePoolOp() {} explicit SequencePoolOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc index c7e86af65033205bcb389cecff8db14721507142..37ebd8a2bae3919062bc0e71e3a10193850e7877 100644 --- a/lite/operators/sequence_reshape_op.cc +++ b/lite/operators/sequence_reshape_op.cc @@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const { return true; } -bool SequenceReshapeOp::InferShape() const { +bool SequenceReshapeOp::InferShapeImpl() const { int new_dim = param_.new_dim; auto x_numel = param_.x->dims().production(); std::vector out_shape{x_numel / new_dim, diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h index c8378aebc44acf22017eee17f5b58d6ff4dd65bf..4ef395bdaa762d178e925f088c5c2becd357f669 100644 --- a/lite/operators/sequence_reshape_op.h +++ b/lite/operators/sequence_reshape_op.h @@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_reverse_op.cc b/lite/operators/sequence_reverse_op.cc index dd8fa2e8fd5816cc92355c9c73caf1aa76baf36c..19a47cac9da666269fc5ef2a172ff0295b71e95d 100644 --- a/lite/operators/sequence_reverse_op.cc +++ b/lite/operators/sequence_reverse_op.cc @@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const { return true; } -bool SequenceReverseOp::InferShape() const { +bool SequenceReverseOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); param_.Out->Resize(out_dims); diff --git a/lite/operators/sequence_reverse_op.h b/lite/operators/sequence_reverse_op.h index 326d0f68927199e9353a5bbe8c072d342c9e3d69..68d9fdb0f16cf0b2e13b7ed7417572a7b971e785 100644 --- a/lite/operators/sequence_reverse_op.h +++ b/lite/operators/sequence_reverse_op.h @@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite { SequenceReverseOp() {} explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_reverse"; } diff --git a/lite/operators/sequence_softmax_op.cc b/lite/operators/sequence_softmax_op.cc index d106097ed5c2e3a712bbd87904164ccd612d1f9e..eb1821129d8b036a252fb36ab69094c8a58cce95 100644 --- a/lite/operators/sequence_softmax_op.cc +++ b/lite/operators/sequence_softmax_op.cc @@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool SequenceSoftmaxOp::InferShape() const { +bool SequenceSoftmaxOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto input_dims = param_.X->dims(); diff --git a/lite/operators/sequence_softmax_op.h b/lite/operators/sequence_softmax_op.h index 37dfc0d444be5c608c87c2418041237d4ac4643c..5942cb0441d7af7237c7761fe4ccd5d613321c87 100644 --- a/lite/operators/sequence_softmax_op.h +++ b/lite/operators/sequence_softmax_op.h @@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_topk_avg_pooling_op.cc b/lite/operators/sequence_topk_avg_pooling_op.cc index 6f5cbeeeee5816132d2ebcb7094949189931b931..cb6f12c4b33bfc04beae2574ca384fcd77ac5004 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.cc +++ b/lite/operators/sequence_topk_avg_pooling_op.cc @@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const { return true; } -bool SequenceTopkAvgPoolingOpLite::InferShape() const { +bool SequenceTopkAvgPoolingOpLite::InferShapeImpl() const { int channel_num = param_.channel_num; std::vector topks = param_.topks; auto row_dim = param_.ROW->dims(); diff --git a/lite/operators/sequence_topk_avg_pooling_op.h b/lite/operators/sequence_topk_avg_pooling_op.h index 1c1cfe3a9c7bc82c3e79fc372b98293183509dca..a619edc908a5e4d4a8db97a931acb2ce24e39008 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.h +++ b/lite/operators/sequence_topk_avg_pooling_op.h @@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sgd_op.cc b/lite/operators/sgd_op.cc index 621454259548d27f9dad23f01e1e392b007bcb5b..eb8cb6b72473310ca1df12e8510d74cc3d76f4aa 100644 --- a/lite/operators/sgd_op.cc +++ b/lite/operators/sgd_op.cc @@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const { return true; } -bool SGDOpLite::InferShape() const { +bool SGDOpLite::InferShapeImpl() const { param_.ParamOut->Resize(param_.Param->dims()); return true; } diff --git a/lite/operators/sgd_op.h b/lite/operators/sgd_op.h index 9159bf95a6a50b5cd7b5d0ffed15e06f8d0e11c5..6a29c8bfa61b455e2257600975e851860e8797cc 100644 --- a/lite/operators/sgd_op.h +++ b/lite/operators/sgd_op.h @@ -33,7 +33,7 @@ class SGDOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/shape_op.cc b/lite/operators/shape_op.cc index c6d5dc4d01a93dd4cc648358db0b6f462a116eb0..1661a909268eb15ea2c4b393e9a2831d438465c7 100644 --- a/lite/operators/shape_op.cc +++ b/lite/operators/shape_op.cc @@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const { return true; } -bool ShapeOpLite::InferShape() const { +bool ShapeOpLite::InferShapeImpl() const { std::vector shape_vec; shape_vec.push_back(static_cast(param_.X->dims().size())); param_.Out->Resize(shape_vec); diff --git a/lite/operators/shape_op.h b/lite/operators/shape_op.h index ada9961c75b1cbc6c91d94a4ed3473ca12d8dcd6..6512b8ac0213519b068a10a74fdcb9d715d73255 100644 --- a/lite/operators/shape_op.h +++ b/lite/operators/shape_op.h @@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc 
&opdesc, lite::Scope *scope) override; diff --git a/lite/operators/shuffle_channel_op.cc b/lite/operators/shuffle_channel_op.cc index 926aa932f3d278945b659b6113df6479c7515e20..d45643a3d82d9177f7719908ea572258e0029bef 100644 --- a/lite/operators/shuffle_channel_op.cc +++ b/lite/operators/shuffle_channel_op.cc @@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const { return true; } -bool ShuffleChannelOpLite::InferShape() const { +bool ShuffleChannelOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/shuffle_channel_op.h b/lite/operators/shuffle_channel_op.h index c48a47f61902087cecf874ee7ddee8313a3cf92a..768345898141dd869c6a59f69170559d68a9f498 100644 --- a/lite/operators/shuffle_channel_op.h +++ b/lite/operators/shuffle_channel_op.h @@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/slice_op.cc b/lite/operators/slice_op.cc index bbc3d1429e202dac7b9a53c00d83ee34de7ef3d1..cf7d94535cce5fa32d0f917c9d39e4746cee1c30 100644 --- a/lite/operators/slice_op.cc +++ b/lite/operators/slice_op.cc @@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const { return true; } -bool SliceOp::InferShape() const { +bool SliceOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto in_dims = param_.X->dims(); diff --git a/lite/operators/slice_op.h b/lite/operators/slice_op.h index 936a1405f46ffd9e3375da1cd57b0570b07fcbbf..ec69f23d8ded4a7435bec0a2bd1f838603c7a7be 100644 --- a/lite/operators/slice_op.h +++ b/lite/operators/slice_op.h @@ -30,7 +30,7 @@ class SliceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index 0989c9139763a435d67deb21a2ab233e1c2f3bd9..000953007c27e37bc05d85d810880f6ccd7728ce 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const { return true; } -bool SoftmaxOp::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (param_.x->dims() == last_input_shapes[0] && - param_.x->lod() == last_input_lods[0]) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - return true; -} - -bool SoftmaxOp::InferShape() const { +bool SoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x->lod(); diff --git a/lite/operators/softmax_op.h b/lite/operators/softmax_op.h index c65d039fda02c5396eff829bede3b4ffdeac0051..20dc2f461e4f83e0b363d44e07c4204c656f2cf3 100644 --- a/lite/operators/softmax_op.h +++ b/lite/operators/softmax_op.h @@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite { bool 
CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_lod_tensor_op.cc b/lite/operators/split_lod_tensor_op.cc index 9b665b6026a44caa31b89ec7806188f90f5f1595..2900c8165dba3b8f0b83ef288c89ed0e56b4820d 100644 --- a/lite/operators/split_lod_tensor_op.cc +++ b/lite/operators/split_lod_tensor_op.cc @@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const { return true; } -bool SplitLodTensorOpLite::InferShape() const { +bool SplitLodTensorOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); param_.out_true->Resize(x_dims); param_.out_false->Resize(x_dims); diff --git a/lite/operators/split_lod_tensor_op.h b/lite/operators/split_lod_tensor_op.h index c7feef4f85df652d0c24f830076a078e20c111f9..fb7f85de5cae69d3c0844ee0eeabe98d45acde4a 100644 --- a/lite/operators/split_lod_tensor_op.h +++ b/lite/operators/split_lod_tensor_op.h @@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 834d68a3156700605e621a1ba71faec33fb7b745..71deb5631dd3523ebb0367b7db5e4049b785be7b 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const { return true; } -bool SplitOp::InferShape() const { +bool SplitOp::InferShapeImpl() const { const auto &outs = param_.output; auto in_dims = param_.x->dims(); int axis = param_.axis; diff --git a/lite/operators/split_op.h b/lite/operators/split_op.h index 66190742155a8268e510d5a8da47ab958a043418..3bb40a8d35e25145057d8c5790b25028ea571cd5 100644 --- a/lite/operators/split_op.h +++ b/lite/operators/split_op.h @@ -30,7 +30,7 @@ class SplitOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index 01f96c28ff6be38e426030aa3c580f28f73b3a38..633a6b4d4e45fd30bd72c8dcdfbbd96b8a8e8ebe 100644 --- a/lite/operators/squeeze_op.cc +++ b/lite/operators/squeeze_op.cc @@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const { return true; } -bool SqueezeOp::InferShape() const { +bool SqueezeOp::InferShapeImpl() const { std::vector squeeze_dims = param_.axes; DDim in_dims = param_.X->dims(); DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true); @@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const { return true; } -bool Squeeze2Op::InferShape() const { - SqueezeOp::InferShape(); +bool Squeeze2Op::InferShapeImpl() const { + SqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/squeeze_op.h b/lite/operators/squeeze_op.h index 1a550c5fbee59d43170b5ffa16caa81521c14d87..983e17acf6483da9e3e33c83b48e6e61455a4914 100644 --- a/lite/operators/squeeze_op.h +++ b/lite/operators/squeeze_op.h @@ -30,7 +30,7 @@ class SqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 
+48,7 @@ class Squeeze2Op : public SqueezeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/stack_op.cc b/lite/operators/stack_op.cc index 8fdf61e8224aa06792bdbb3f41a4f1701039d8dd..0f9ba6662b16ce20acad497a4915cfc848b319cd 100644 --- a/lite/operators/stack_op.cc +++ b/lite/operators/stack_op.cc @@ -32,7 +32,7 @@ bool StackOp::CheckShape() const { return true; } -bool StackOp::InferShape() const { +bool StackOp::InferShapeImpl() const { auto input = param_.X; auto input_dims = input[0]->dims(); int axis = param_.axis; diff --git a/lite/operators/stack_op.h b/lite/operators/stack_op.h index 068d905338bde892b44630c64d3ec43771614f2a..9ce73057a313fd4b4f96914b3e962120de11ac43 100644 --- a/lite/operators/stack_op.h +++ b/lite/operators/stack_op.h @@ -31,7 +31,7 @@ class StackOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/subgraph_op.cc b/lite/operators/subgraph_op.cc index 58388669afa060d48ea4c3d674dff94c386f104a..9ac07e96334eda9f0001d33e0789f9de15c4ca67 100644 --- a/lite/operators/subgraph_op.cc +++ b/lite/operators/subgraph_op.cc @@ -22,7 +22,7 @@ namespace operators { bool SubgraphOp::CheckShape() const { return true; } -bool SubgraphOp::InferShape() const { return CheckShape(); /* enrich me */ } +bool SubgraphOp::InferShapeImpl() const { return CheckShape(); /* enrich me */ } bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { param_.input_names = op_desc.Input("Inputs"); diff --git a/lite/operators/subgraph_op.h b/lite/operators/subgraph_op.h index 7f593159c8651cc18fbea17e559f62297d5022e9..edbfb922044d60165e589d389cd8cfb3b2547796 100644 --- a/lite/operators/subgraph_op.h +++ b/lite/operators/subgraph_op.h @@ -35,7 +35,7 @@ class SubgraphOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/topk_op.cc b/lite/operators/topk_op.cc index fbfb825544870dfaf3e18d1595f2824970b7352b..4a68cbb4745473b21cc7b6c5f6c8fcef6e186e57 100644 --- a/lite/operators/topk_op.cc +++ b/lite/operators/topk_op.cc @@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const { return true; } -bool TopkOp::InferShape() const { +bool TopkOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); out_dims[out_dims.size() - 1] = param_.K; auto out = param_.Out; diff --git a/lite/operators/topk_op.h b/lite/operators/topk_op.h index 037fa413ea5ce6fcb5eb04502cf232cea7e109e0..d5888e5f1800ba37f4bed61c146b6af75e3f91fc 100644 --- a/lite/operators/topk_op.h +++ b/lite/operators/topk_op.h @@ -30,7 +30,7 @@ class TopkOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index 71086b492b538e293a1f08ed7f492a46d6eb02f8..40780346d038c875a2eb96b11aff9d1c2a578a2f 100644 --- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const { return true; } -bool TransposeOp::InferShape() const { +bool 
TransposeOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); @@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const { return true; } -bool Transpose2Op::InferShape() const { +bool Transpose2Op::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); diff --git a/lite/operators/transpose_op.h b/lite/operators/transpose_op.h index ce352a7d82f4a9dd3899f21c252c003c1924dda6..39b75b96d858bb80a51e428b8d7f402258dd9cc1 100644 --- a/lite/operators/transpose_op.h +++ b/lite/operators/transpose_op.h @@ -31,7 +31,7 @@ class TransposeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -50,7 +50,7 @@ class Transpose2Op : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/uniform_random_op.cc b/lite/operators/uniform_random_op.cc index 93e74e2b0172e8c3948925f3334b011f37bc097e..512648bfe4acf245286c9be21223520789134897 100644 --- a/lite/operators/uniform_random_op.cc +++ b/lite/operators/uniform_random_op.cc @@ -22,7 +22,7 @@ namespace operators { bool UniformRandomOpLite::CheckShape() const { return true; } -bool UniformRandomOpLite::InferShape() const { +bool UniformRandomOpLite::InferShapeImpl() const { param_.Out->Resize(param_.shape); return true; } diff --git a/lite/operators/uniform_random_op.h b/lite/operators/uniform_random_op.h index f7dde8882f47fc533e0d47dac99acdb431509341..a7890ea3e74afb3fd67f7ba4d1f02861a7e4ae48 100644 --- a/lite/operators/uniform_random_op.h +++ b/lite/operators/uniform_random_op.h @@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 39b275b7b55f79f2c8daf16ab0a6acd2e76e8b48..b5ae90248abb4f2496a4dbca1c12317cf3a7d325 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const { return true; } -bool UnsqueezeOp::InferShape() const { +bool UnsqueezeOp::InferShapeImpl() const { std::vector final_axes; auto axes = param_.axes; auto *axes_tensor = param_.axes_tensor; @@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const { return true; } -bool Unsqueeze2Op::InferShape() const { - UnsqueezeOp::InferShape(); +bool Unsqueeze2Op::InferShapeImpl() const { + UnsqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/unsqueeze_op.h b/lite/operators/unsqueeze_op.h index 1e88828c6c5fdef767850909c0dae8ec65e9d1e0..5139b69c63699f041973c3cf31b38d6c7e9fa847 100644 --- a/lite/operators/unsqueeze_op.h +++ b/lite/operators/unsqueeze_op.h @@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc index 51f43c709990d7ac1e664336e252ed684479b783..8cf11f6465d73646ec9bf846cbe6347bdc4b9f5b 100644 --- a/lite/operators/var_conv_2d_op.cc +++ b/lite/operators/var_conv_2d_op.cc @@ -21,7 +21,7 @@ namespace operators { bool VarConv2dOp::CheckShape() const { return true; } -bool VarConv2dOp::InferShape() const { return true; } +bool VarConv2dOp::InferShapeImpl() const { return true; } bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = const_cast( diff --git a/lite/operators/var_conv_2d_op.h b/lite/operators/var_conv_2d_op.h index ce6309419cc582c2f93250dd6e8e59c04a951f91..5fa492d28ec858426bea7d3d06598813d94dbbb8 100644 --- a/lite/operators/var_conv_2d_op.h +++ b/lite/operators/var_conv_2d_op.h @@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite { VarConv2dOp() {} explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "var_conv_2d"; } diff --git a/lite/operators/while_op.cc b/lite/operators/while_op.cc index dba266af770183698680a49cb7ba4fe5dda2f5b2..1dcf9553f331ee6646ad6d93de048728a0886116 100644 --- a/lite/operators/while_op.cc +++ b/lite/operators/while_op.cc @@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const { return true; } -bool WhileOpLite::InferShape() const { return true; } +bool WhileOpLite::InferShapeImpl() const { return true; } bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("X"); diff --git a/lite/operators/while_op.h b/lite/operators/while_op.h index fcba722dbc182d0de617c3bf397a0266dc3d9cb2..94aec15a6d3eb60036bf9c2168fdbd855b84a396 100644 --- a/lite/operators/while_op.h +++ b/lite/operators/while_op.h @@ -30,7 +30,7 @@ class WhileOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/write_to_array_op.cc b/lite/operators/write_to_array_op.cc index bf2d9bc4b755c5800497e895f597aff22147e34f..d2cf7b4f94513d1058c3b4f4de1ec70c8c244b7e 100644 --- a/lite/operators/write_to_array_op.cc +++ b/lite/operators/write_to_array_op.cc @@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const { return true; } -bool WriteToArrayOp::InferShape() const { +bool WriteToArrayOp::InferShapeImpl() const { int id = param_.I->data()[0]; if (param_.Out->size() < id + 1) { param_.Out->resize(id + 1); diff --git a/lite/operators/write_to_array_op.h b/lite/operators/write_to_array_op.h index 8c987a24509d915d2ec59b90808993abe779623e..9460b7e364047750991d03468956462497fc4cc1 100644 --- a/lite/operators/write_to_array_op.h +++ b/lite/operators/write_to_array_op.h @@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/yolo_box_op.cc b/lite/operators/yolo_box_op.cc index 
c8186d3f3182e21856919c46b83fe96a6e2bef93..0a5481a8fb01b5401734beacbc18a0bafcc48457 100644 --- a/lite/operators/yolo_box_op.cc +++ b/lite/operators/yolo_box_op.cc @@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const { return true; } -bool YoloBoxOp::InferShape() const { +bool YoloBoxOp::InferShapeImpl() const { auto* X = param_.X; auto anchors = param_.anchors; int anchor_num = anchors.size() / 2; diff --git a/lite/operators/yolo_box_op.h b/lite/operators/yolo_box_op.h index 2e2ea6d63408ca7d1a1cd7db48b82bf1ced294de..85448000f34bb1f0b768f78bb5929d1a26462043 100644 --- a/lite/operators/yolo_box_op.h +++ b/lite/operators/yolo_box_op.h @@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
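Note that the reshape2, squeeze2 and unsqueeze2 hunks above also rewrite the qualified call into the parent class (ReshapeOp::InferShape() becomes ReshapeOp::InferShapeImpl(), and likewise for SqueezeOp and UnsqueezeOp), since the shape logic now lives under the new name. Below is a sketch of that derived-op pattern; Foo2Op, FooOp, param_.XShape and the int64_t dims type are assumptions for illustration, not code from this patch.

// Sketch of a "...2"-style operator after the rename: reuse the parent's
// shape logic, then fill the extra XShape output, following the pattern of
// Reshape2Op / Squeeze2Op / Unsqueeze2Op above. Lives in the op's .cc file.
bool Foo2Op::InferShapeImpl() const {
  // The qualified call targets the renamed implementation, not InferShape().
  FooOp::InferShapeImpl();

  // XShape carries the input dims prefixed with a 0, as in the hunks above.
  const auto &x_dims = param_.X->dims();
  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
  xshape_dims[0] = 0;
  for (size_t i = 0; i < x_dims.size(); i++) {
    xshape_dims[i + 1] = x_dims[i];
  }
  param_.XShape->Resize(DDim(xshape_dims));
  return true;
}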