Commit 50152b80 authored by xiaogang, committed by GitHub

fix: fix infershape profile (#3240)

test=develop
Parent 1ff03aec
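This commit speeds up repeated inference by caching shape-inference results: each op remembers the input dims/LoD it last saw and the output dims/LoD it produced, and a new SmartInferShape() entry point reuses the cached outputs when the inputs are unchanged, falling back to the full InferShape() otherwise. Below is a minimal, self-contained sketch of that compare-then-recompute pattern; Dims, Lod, Tensor and UnaryOp are simplified, hypothetical stand-ins for illustration, not the actual Paddle-Lite classes touched in the diff.

#include <cstdint>
#include <vector>

// Simplified stand-ins (hypothetical): a real op works on lite::Tensor,
// DDimLite and the LoD type used in the diff below.
using Dims = std::vector<int64_t>;
using Lod = std::vector<std::vector<uint64_t>>;

struct Tensor {
  Dims dims;
  Lod lod;
};

struct UnaryOp {
  Tensor* x = nullptr;
  Tensor* out = nullptr;

  // Full shape inference: recompute the output dims/LoD from the input.
  // Trivial here; for an op like conv this is the costly part being avoided.
  void InferShape() {
    out->dims = x->dims;
    out->lod = x->lod;
  }

  // Cached path: if the input dims/LoD match the previous run, reuse the
  // remembered output dims/LoD and skip InferShape() entirely.
  void SmartInferShape() {
    if (!last_input_dims.empty() && last_input_dims[0] == x->dims &&
        last_input_lods[0] == x->lod) {
      out->dims = last_output_dims[0];
      out->lod = last_output_lods[0];
      return;
    }
    InferShape();
    // Refresh the cache with what this run saw and produced.
    last_input_dims = {x->dims};
    last_input_lods = {x->lod};
    last_output_dims = {out->dims};
    last_output_lods = {out->lod};
  }

  // Cache of the last run, mirroring the last_input_shapes / last_*_lods
  // vectors the commit adds to OpLite.
  std::vector<Dims> last_input_dims, last_output_dims;
  std::vector<Lod> last_input_lods, last_output_lods;
};

The diff applies this same pattern per op, keeping the cache vectors on the OpLite base class so conv, elementwise, fc and softmax can all share the bookkeeping.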
@@ -65,6 +65,7 @@ class OpLite : public Registry {
   virtual bool CheckShape() const { return true; }
   // Inference the outputs' shape.
   virtual bool InferShape() const { return true; }
+  virtual bool SmartInferShape() { return this->InferShape(); }
   // Run this operator.
   virtual bool Run();
   // Indicate whether the Op runs only once or not
@@ -150,6 +151,10 @@ class OpLite : public Registry {
   std::vector<Place> valid_places_;
   Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
   std::unique_ptr<OpInfo> op_info_;
+  std::vector<DDimLite> last_input_shapes;
+  std::vector<DDimLite> last_output_shapes;
+  std::vector<std::vector<std::vector<uint64_t>>> last_output_lods;
+  std::vector<std::vector<std::vector<uint64_t>>> last_input_lods;
 };
 /*
......
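The four vectors added to OpLite above (last_input_shapes, last_input_lods, last_output_shapes, last_output_lods) hold the dims and LoDs cached from the previous run; the derived ops below read and refresh them in their SmartInferShape() overrides, so they need to be accessible from the subclasses. The base-class SmartInferShape() simply forwards to InferShape(), which keeps the old behavior for any op that does not override it.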
@@ -286,7 +286,8 @@ void Instruction::Run() {
     return;
   }
-  op_->InferShape();
+  // op_->InferShape();
+  op_->SmartInferShape();
   kernel_->Launch();
   has_run_ = true;
 }
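With this change the runtime calls SmartInferShape() instead of InferShape() on every Instruction::Run(). Ops that override it in this commit (conv, elementwise, fc, softmax) therefore redo full shape inference only when an input's dims or LoD actually changed since the previous run; all other ops fall through to InferShape() via the base-class default.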
......
@@ -80,6 +80,34 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
   }
 }
+bool ConvOpLite::SmartInferShape() {
+  if (!last_input_shapes.empty()) {
+    if (last_input_shapes[0] == param_.x->dims() &&
+        last_input_lods[0] == param_.x->lod()) {
+      param_.output->Resize(last_output_shapes[0]);
+      param_.output->set_lod(last_output_lods[0]);
+      return true;
+    }
+  }
+  this->InferShape();
+  if (!last_input_shapes.empty()) {
+    last_input_shapes.clear();
+    last_input_lods.clear();
+  }
+  last_input_shapes.push_back(param_.x->dims());
+  last_input_lods.push_back(param_.x->lod());
+  if (!last_output_shapes.empty()) {
+    last_output_shapes.clear();
+    last_output_lods.clear();
+  }
+  last_output_shapes.push_back(param_.output->dims());
+  last_output_lods.push_back(param_.output->lod());
+  return true;
+}
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
@@ -104,9 +132,9 @@ bool ConvOpLite::InferShape() const {
   // Set output dims
   param_.output->Resize(lite::DDim(output_shape));
   // share LoD
-  // param_.output->set_lod(param_.x->lod());
+  param_.output->set_lod(param_.x->lod());
   return true;
 }
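Note that ConvOpLite::InferShape() now also shares the input LoD with the output (the previously commented-out set_lod call is enabled), so the LoD recorded by SmartInferShape() matches what a full recomputation would produce.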
......
@@ -36,6 +36,7 @@ class ConvOpLite : public OpLite {
   bool CheckShape() const override;
   bool InferShape() const override;
+  bool SmartInferShape() override;
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
   bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
@@ -26,7 +26,38 @@ bool ElementwiseOp::CheckShape() const {
   CHECK_OR_FALSE(param_.Out);
   return true;
 }
+bool ElementwiseOp::SmartInferShape() {
+  if (!last_input_shapes.empty()) {
+    if (last_input_shapes[0] == param_.X->dims() &&
+        last_input_shapes[1] == param_.Y->dims() &&
+        last_input_lods[0] == param_.X->lod() &&
+        last_input_lods[1] == param_.Y->lod()) {
+      param_.Out->Resize(last_output_shapes[0]);
+      param_.Out->set_lod(last_output_lods[0]);
+      return true;
+    }
+  }
+  this->InferShape();
+  if (!last_input_shapes.empty()) {
+    last_input_shapes.clear();
+    last_input_lods.clear();
+  }
+  last_input_shapes.push_back(param_.X->dims());
+  last_input_lods.push_back(param_.X->lod());
+  last_input_shapes.push_back(param_.Y->dims());
+  last_input_lods.push_back(param_.Y->lod());
+  if (!last_output_shapes.empty()) {
+    last_output_shapes.clear();
+    last_output_lods.clear();
+  }
+  last_output_shapes.push_back(param_.Out->dims());
+  last_output_lods.push_back(param_.Out->lod());
+  return true;
+}
 bool ElementwiseOp::InferShape() const {
   auto x_dim = param_.X->dims();
   auto y_dim = param_.Y->dims();
@@ -81,6 +112,7 @@ bool ElementwiseOp::InferShape() const {
     auto out_lod = param_.Out->mutable_lod();
     *out_lod = param_.X->lod();
   }
   return true;
 }
......
@@ -28,6 +28,7 @@ class ElementwiseOp : public OpLite {
   bool CheckShape() const override;
   bool InferShape() const override;
+  bool SmartInferShape() override;
   bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
@@ -48,6 +48,33 @@ bool FcOpLite::CheckShape() const {
   return true;
 }
+bool FcOpLite::SmartInferShape() {
+  if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
+    if (last_input_shapes[0] == param_.input->dims() &&
+        last_input_lods[0] == param_.input->lod()) {
+      param_.output->Resize(last_output_shapes[0]);
+      param_.output->set_lod(last_output_lods[0]);
+      return true;
+    }
+  }
+  this->InferShape();
+  if (!last_input_shapes.empty()) {
+    last_input_shapes.clear();
+    last_input_lods.clear();
+  }
+  last_input_shapes.push_back(param_.input->dims());
+  last_input_lods.push_back(param_.input->lod());
+  if (!last_output_shapes.empty()) {
+    last_output_shapes.clear();
+    last_output_lods.clear();
+  }
+  last_output_shapes.push_back(param_.output->dims());
+  last_output_lods.push_back(param_.output->lod());
+  return true;
+}
 bool FcOpLite::InferShape() const {
   const auto& input_dims = param_.input->dims();
   const auto& w_dims = param_.w->dims();
@@ -64,6 +91,7 @@ bool FcOpLite::InferShape() const {
   // share LoD
   param_.output->set_lod(param_.input->lod());
   return true;
 }
......
@@ -36,6 +36,7 @@ class FcOpLite : public OpLite {
   bool CheckShape() const override;
   bool InferShape() const override;
+  bool SmartInferShape() override;
   bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
@@ -29,10 +29,39 @@ bool SoftmaxOp::CheckShape() const {
   return true;
 }
+bool SoftmaxOp::SmartInferShape() {
+  if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
+    if (param_.x->dims() == last_input_shapes[0] &&
+        param_.x->lod() == last_input_lods[0]) {
+      param_.output->Resize(last_output_shapes[0]);
+      param_.output->set_lod(last_output_lods[0]);
+      return true;
+    }
+  }
+  this->InferShape();
+  if (!last_input_shapes.empty()) {
+    last_input_shapes.clear();
+    last_input_lods.clear();
+  }
+  last_input_shapes.push_back(param_.x->dims());
+  last_input_lods.push_back(param_.x->lod());
+  if (!last_output_shapes.empty()) {
+    last_output_shapes.clear();
+    last_output_lods.clear();
+  }
+  last_output_shapes.push_back(param_.output->dims());
+  last_output_lods.push_back(param_.output->lod());
+  return true;
+}
 bool SoftmaxOp::InferShape() const {
   param_.output->Resize(param_.x->dims());
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x->lod();
   return true;
 }
......
@@ -31,6 +31,7 @@ class SoftmaxOp : public OpLite {
   bool CheckShape() const override;
   bool InferShape() const override;
+  bool SmartInferShape() override;
   bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......