diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h
index d477d6a32aa9fb87adb727d035f07862c446647e..e846be82a4c30e94c6f8fbc1e9b3ef995adccd68 100644
--- a/lite/kernels/mlu/subgraph_compute.h
+++ b/lite/kernels/mlu/subgraph_compute.h
@@ -239,33 +239,56 @@ class SubgraphEngine : public subgraph::Engine {
 
     std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
         graph_in;
-    graph_in.reserve(origin_itensors_.size());
+    if (shape_tensor_map_in_.find(inputs_shape_) !=
+        shape_tensor_map_in_.end()) {
+      graph_in = shape_tensor_map_in_[inputs_shape_];
+      for (size_t i = 0; i < origin_itensors_.size(); ++i) {
+        graph_in[i]->set_mlu_ptr(
+            const_cast<void*>(origin_itensors_[i]->raw_data()));
+      }
+    } else {
+      graph_in.reserve(origin_itensors_.size());
+      for (size_t i = 0; i < origin_itensors_.size(); ++i) {
+        paddle::lite::subgraph::mlu::MLUTensor tmp(
+            origin_itensors_[i]->dims().Vectorize());
+        // graph_input->at(i)->get_origin_shape());
+        tmp.set_mlu_dtype(graph_input->at(i)->dtype());
+        tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
+        graph_in.push_back(
+            std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
+      }
+      shape_tensor_map_in_[inputs_shape_] = graph_in;
+    }
+
     std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
         graph_out;
-    graph_out.reserve(origin_otensors_.size());
-
-    for (size_t i = 0; i < origin_itensors_.size(); ++i) {
-      paddle::lite::subgraph::mlu::MLUTensor tmp(
-          origin_itensors_[i]->dims().Vectorize());
-      // graph_input->at(i)->get_origin_shape());
-      tmp.set_mlu_dtype(graph_input->at(i)->dtype());
-      tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
-      graph_in.push_back(
-          std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
-    }
-    for (size_t i = 0; i < origin_otensors_.size(); ++i) {
-      origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
-      void* p_data = static_cast<void*>(
-          origin_otensors_[i]
-              ->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
-                  Precision>::T>(TARGET(kMLU)));
-      paddle::lite::subgraph::mlu::MLUTensor tmp(
-          origin_otensors_[i]->dims().Vectorize());
-      // graph_output->at(i)->get_origin_shape());
-      tmp.set_mlu_dtype(graph_output->at(i)->dtype());
-      tmp.set_mlu_ptr(p_data);
-      graph_out.push_back(
-          std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
+    if (shape_tensor_map_out_.find(inputs_shape_) !=
+        shape_tensor_map_out_.end()) {
+      graph_out = shape_tensor_map_out_[inputs_shape_];
+      for (size_t i = 0; i < origin_otensors_.size(); ++i) {
+        void* p_data = static_cast<void*>(
+            origin_otensors_[i]
+                ->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
+                    Precision>::T>(TARGET(kMLU)));
+        graph_out[i]->set_mlu_ptr(p_data);
+      }
+    } else {
+      graph_out.reserve(origin_otensors_.size());
+      for (size_t i = 0; i < origin_otensors_.size(); ++i) {
+        origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
+        void* p_data = static_cast<void*>(
+            origin_otensors_[i]
+                ->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
+                    Precision>::T>(TARGET(kMLU)));
+        paddle::lite::subgraph::mlu::MLUTensor tmp(
+            origin_otensors_[i]->dims().Vectorize());
+        // graph_output->at(i)->get_origin_shape());
+        tmp.set_mlu_dtype(graph_output->at(i)->dtype());
+        tmp.set_mlu_ptr(p_data);
+        graph_out.push_back(
+            std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
+      }
+      shape_tensor_map_out_[inputs_shape_] = graph_out;
     }
 
     auto& mlu_context = this->ctx_->template As<MLUContext>();
@@ -314,6 +337,12 @@
   std::map<std::vector<std::vector<int64_t>>,
            std::shared_ptr<paddle::lite::subgraph::mlu::Graph>>
       shape_graph_map_{};
+  std::map<std::vector<std::vector<int64_t>>,
+           std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>>
+      shape_tensor_map_out_{};
+  std::map<std::vector<std::vector<int64_t>>,
+           std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>>
+      shape_tensor_map_in_{};
 };  // namespace mlu
 
 template <PrecisionType Precision>
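
Note on the pattern: the change above memoizes the MLUTensor wrappers per input-shape key (shape_tensor_map_in_ / shape_tensor_map_out_), so an invocation whose input shapes have been seen before reuses the cached std::shared_ptr wrappers and only refreshes the raw device data pointers, instead of rebuilding the wrapper objects every time. The stand-alone sketch below illustrates that cache-hit / cache-miss structure; FakeTensor, ShapeKey, TensorList, and GetOrBuild are hypothetical names used for illustration, not Paddle-Lite APIs.

// Minimal, self-contained sketch of the shape-keyed wrapper cache introduced
// above. FakeTensor, ShapeKey, TensorList and GetOrBuild are hypothetical
// illustration names, not Paddle-Lite APIs.
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

// Stand-in for MLUTensor: owns a shape, borrows a raw data pointer.
struct FakeTensor {
  explicit FakeTensor(std::vector<int64_t> shape) : shape_(std::move(shape)) {}
  void set_ptr(void* p) { ptr_ = p; }
  std::vector<int64_t> shape_;
  void* ptr_{nullptr};
};

using ShapeKey = std::vector<std::vector<int64_t>>;  // one shape per tensor
using TensorList = std::vector<std::shared_ptr<FakeTensor>>;

// Build the wrapper list once per distinct shape combination; on later calls
// with the same key, reuse the cached wrappers and only swap in the new
// per-invocation raw pointers (mirrors the cache-hit branch in the diff).
TensorList GetOrBuild(std::map<ShapeKey, TensorList>* cache,
                      const ShapeKey& key,
                      const std::vector<void*>& raw_ptrs) {
  auto it = cache->find(key);
  if (it != cache->end()) {
    TensorList tensors = it->second;     // cache hit: reuse wrappers
    for (size_t i = 0; i < tensors.size(); ++i) {
      tensors[i]->set_ptr(raw_ptrs[i]);  // refresh data pointers only
    }
    return tensors;
  }
  TensorList tensors;                    // cache miss: build and remember
  tensors.reserve(key.size());
  for (size_t i = 0; i < key.size(); ++i) {
    auto t = std::make_shared<FakeTensor>(key[i]);
    t->set_ptr(raw_ptrs[i]);
    tensors.push_back(t);
  }
  (*cache)[key] = tensors;
  return tensors;
}

int main() {
  std::map<ShapeKey, TensorList> cache;
  int buf_a = 0;
  int buf_b = 0;
  ShapeKey key{{1, 3, 224, 224}};

  TensorList first = GetOrBuild(&cache, key, {&buf_a});
  TensorList second = GetOrBuild(&cache, key, {&buf_b});

  // Same wrapper object is reused; only the pointer was updated.
  std::cout << (first[0] == second[0]) << " "
            << (second[0]->ptr_ == &buf_b) << "\n";  // prints: 1 1
}

Keying the input and output wrapper caches on the same inputs_shape_ key used by shape_graph_map_ means all three caches hit or miss together for a given combination of input shapes.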