提交 50152b80 编写于 作者: X xiaogang 提交者: GitHub

fix: fix infershape profile (#3240)

test=develop
上级 1ff03aec
......@@ -65,6 +65,7 @@ class OpLite : public Registry {
virtual bool CheckShape() const { return true; }
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
virtual bool SmartInferShape() { return this->InferShape(); }
// Run this operator.
virtual bool Run();
// Indicate whether the Op runs only once or not
......@@ -150,6 +151,10 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_input_shapes;
std::vector<DDimLite> last_output_shapes;
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods;
std::vector<std::vector<std::vector<uint64_t>>> last_input_lods;
};
/*
......
......@@ -286,7 +286,8 @@ void Instruction::Run() {
return;
}
op_->InferShape();
// op_->InferShape();
op_->SmartInferShape();
kernel_->Launch();
has_run_ = true;
}
......
......@@ -80,6 +80,34 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
}
}
bool ConvOpLite::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.x->dims() &&
last_input_lods[0] == param_.x->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.x->dims());
last_input_lods.push_back(param_.x->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......@@ -104,9 +132,9 @@ bool ConvOpLite::InferShape() const {
// Set output dims
param_.output->Resize(lite::DDim(output_shape));
// share LoD
// param_.output->set_lod(param_.x->lod());
param_.output->set_lod(param_.x->lod());
return true;
}
......
......@@ -36,6 +36,7 @@ class ConvOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
......@@ -26,7 +26,38 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out);
return true;
}
bool ElementwiseOp::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.X->dims() &&
last_input_shapes[1] == param_.Y->dims() &&
last_input_lods[0] == param_.X->lod() &&
last_input_lods[1] == param_.Y->lod()) {
param_.Out->Resize(last_output_shapes[0]);
param_.Out->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.X->dims());
last_input_lods.push_back(param_.X->lod());
last_input_shapes.push_back(param_.Y->dims());
last_input_lods.push_back(param_.Y->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.Out->dims());
last_output_lods.push_back(param_.Out->lod());
return true;
}
bool ElementwiseOp::InferShape() const {
auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims();
......@@ -81,6 +112,7 @@ bool ElementwiseOp::InferShape() const {
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
}
return true;
}
......
......@@ -28,6 +28,7 @@ class ElementwiseOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -48,6 +48,33 @@ bool FcOpLite::CheckShape() const {
return true;
}
bool FcOpLite::SmartInferShape() {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (last_input_shapes[0] == param_.input->dims() &&
last_input_lods[0] == param_.input->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.input->dims());
last_input_lods.push_back(param_.input->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool FcOpLite::InferShape() const {
const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims();
......@@ -64,6 +91,7 @@ bool FcOpLite::InferShape() const {
// share LoD
param_.output->set_lod(param_.input->lod());
return true;
}
......
......@@ -36,6 +36,7 @@ class FcOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
......@@ -29,10 +29,39 @@ bool SoftmaxOp::CheckShape() const {
return true;
}
bool SoftmaxOp::SmartInferShape() {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (param_.x->dims() == last_input_shapes[0] &&
param_.x->lod() == last_input_lods[0]) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.x->dims());
last_input_lods.push_back(param_.x->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool SoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims());
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x->lod();
return true;
}
......
......@@ -31,6 +31,7 @@ class SoftmaxOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册