From 50152b80c00cbabbcbb696361720dacc0d1f54c3 Mon Sep 17 00:00:00 2001 From: xiaogang Date: Wed, 25 Mar 2020 21:23:21 +0800 Subject: [PATCH] fix: fix infershape profile (#3240) test=develop --- lite/core/op_lite.h | 5 +++++ lite/core/program.cc | 3 ++- lite/operators/conv_op.cc | 32 +++++++++++++++++++++++++++++-- lite/operators/conv_op.h | 1 + lite/operators/elementwise_ops.cc | 32 +++++++++++++++++++++++++++++++ lite/operators/elementwise_ops.h | 1 + lite/operators/fc_op.cc | 28 +++++++++++++++++++++++++++ lite/operators/fc_op.h | 1 + lite/operators/softmax_op.cc | 29 ++++++++++++++++++++++++++++ lite/operators/softmax_op.h | 1 + 10 files changed, 130 insertions(+), 3 deletions(-) diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 5dec9ed7aa..77d8091b4b 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -65,6 +65,7 @@ class OpLite : public Registry { virtual bool CheckShape() const { return true; } // Inference the outputs' shape. virtual bool InferShape() const { return true; } + virtual bool SmartInferShape() { return this->InferShape(); } // Run this operator. virtual bool Run(); // Indicate whether the Op runs only once or not @@ -150,6 +151,10 @@ class OpLite : public Registry { std::vector valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; std::unique_ptr op_info_; + std::vector last_input_shapes; + std::vector last_output_shapes; + std::vector>> last_output_lods; + std::vector>> last_input_lods; }; /* diff --git a/lite/core/program.cc b/lite/core/program.cc index 7284c3983c..580389fbad 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -286,7 +286,8 @@ void Instruction::Run() { return; } - op_->InferShape(); + // op_->InferShape(); + op_->SmartInferShape(); kernel_->Launch(); has_run_ = true; } diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 9ae52d1cb6..70ad3a32a8 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -80,6 +80,34 @@ void UpdatePaddingAndDilation(std::vector* paddings, } } +bool ConvOpLite::SmartInferShape() { + if (!last_input_shapes.empty()) { + if (last_input_shapes[0] == param_.x->dims() && + last_input_lods[0] == param_.x->lod()) { + param_.output->Resize(last_output_shapes[0]); + param_.output->set_lod(last_output_lods[0]); + return true; + } + } + + this->InferShape(); + + if (!last_input_shapes.empty()) { + last_input_shapes.clear(); + last_input_lods.clear(); + } + last_input_shapes.push_back(param_.x->dims()); + last_input_lods.push_back(param_.x->lod()); + + if (!last_output_shapes.empty()) { + last_output_shapes.clear(); + last_output_lods.clear(); + } + last_output_shapes.push_back(param_.output->dims()); + last_output_lods.push_back(param_.output->lod()); + + return true; +} bool ConvOpLite::InferShape() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); @@ -104,9 +132,9 @@ bool ConvOpLite::InferShape() const { // Set output dims param_.output->Resize(lite::DDim(output_shape)); - // share LoD - // param_.output->set_lod(param_.x->lod()); + param_.output->set_lod(param_.x->lod()); + return true; } diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 63107022f1..3379fb4095 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -36,6 +36,7 @@ class ConvOpLite : public OpLite { bool CheckShape() const override; bool InferShape() const override; + bool SmartInferShape() override; // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index 3dc6f06955..044126b3c2 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -26,7 +26,38 @@ bool ElementwiseOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } +bool ElementwiseOp::SmartInferShape() { + if (!last_input_shapes.empty()) { + if (last_input_shapes[0] == param_.X->dims() && + last_input_shapes[1] == param_.Y->dims() && + last_input_lods[0] == param_.X->lod() && + last_input_lods[1] == param_.Y->lod()) { + param_.Out->Resize(last_output_shapes[0]); + param_.Out->set_lod(last_output_lods[0]); + return true; + } + } + + this->InferShape(); + + if (!last_input_shapes.empty()) { + last_input_shapes.clear(); + last_input_lods.clear(); + } + last_input_shapes.push_back(param_.X->dims()); + last_input_lods.push_back(param_.X->lod()); + last_input_shapes.push_back(param_.Y->dims()); + last_input_lods.push_back(param_.Y->lod()); + + if (!last_output_shapes.empty()) { + last_output_shapes.clear(); + last_output_lods.clear(); + } + last_output_shapes.push_back(param_.Out->dims()); + last_output_lods.push_back(param_.Out->lod()); + return true; +} bool ElementwiseOp::InferShape() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); @@ -81,6 +112,7 @@ bool ElementwiseOp::InferShape() const { auto out_lod = param_.Out->mutable_lod(); *out_lod = param_.X->lod(); } + return true; } diff --git a/lite/operators/elementwise_ops.h b/lite/operators/elementwise_ops.h index d888e3d1c1..9d6e5781b9 100644 --- a/lite/operators/elementwise_ops.h +++ b/lite/operators/elementwise_ops.h @@ -28,6 +28,7 @@ class ElementwiseOp : public OpLite { bool CheckShape() const override; bool InferShape() const override; + bool SmartInferShape() override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index eff9300fea..345fc0d605 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -48,6 +48,33 @@ bool FcOpLite::CheckShape() const { return true; } +bool FcOpLite::SmartInferShape() { + if (!last_input_shapes.empty() && !last_output_shapes.empty()) { + if (last_input_shapes[0] == param_.input->dims() && + last_input_lods[0] == param_.input->lod()) { + param_.output->Resize(last_output_shapes[0]); + param_.output->set_lod(last_output_lods[0]); + return true; + } + } + + this->InferShape(); + + if (!last_input_shapes.empty()) { + last_input_shapes.clear(); + last_input_lods.clear(); + } + last_input_shapes.push_back(param_.input->dims()); + last_input_lods.push_back(param_.input->lod()); + if (!last_output_shapes.empty()) { + last_output_shapes.clear(); + last_output_lods.clear(); + } + last_output_shapes.push_back(param_.output->dims()); + last_output_lods.push_back(param_.output->lod()); + + return true; +} bool FcOpLite::InferShape() const { const auto& input_dims = param_.input->dims(); const auto& w_dims = param_.w->dims(); @@ -64,6 +91,7 @@ bool FcOpLite::InferShape() const { // share LoD param_.output->set_lod(param_.input->lod()); + return true; } diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index ec449cd4bd..f5dc302e27 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h @@ -36,6 +36,7 @@ class FcOpLite : public OpLite { bool CheckShape() const override; bool InferShape() const override; + bool SmartInferShape() override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index 1e89fc1a2a..0989c91397 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -29,10 +29,39 @@ bool SoftmaxOp::CheckShape() const { return true; } +bool SoftmaxOp::SmartInferShape() { + if (!last_input_shapes.empty() && !last_output_shapes.empty()) { + if (param_.x->dims() == last_input_shapes[0] && + param_.x->lod() == last_input_lods[0]) { + param_.output->Resize(last_output_shapes[0]); + param_.output->set_lod(last_output_lods[0]); + return true; + } + } + + this->InferShape(); + + if (!last_input_shapes.empty()) { + last_input_shapes.clear(); + last_input_lods.clear(); + } + last_input_shapes.push_back(param_.x->dims()); + last_input_lods.push_back(param_.x->lod()); + + if (!last_output_shapes.empty()) { + last_output_shapes.clear(); + last_output_lods.clear(); + } + last_output_shapes.push_back(param_.output->dims()); + last_output_lods.push_back(param_.output->lod()); + return true; +} + bool SoftmaxOp::InferShape() const { param_.output->Resize(param_.x->dims()); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x->lod(); + return true; } diff --git a/lite/operators/softmax_op.h b/lite/operators/softmax_op.h index bb24acad34..c65d039fda 100644 --- a/lite/operators/softmax_op.h +++ b/lite/operators/softmax_op.h @@ -31,6 +31,7 @@ class SoftmaxOp : public OpLite { bool CheckShape() const override; bool InferShape() const override; + bool SmartInferShape() override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; -- GitLab