From 739b4ac1874c03c6a8217d485545f60763f3238d Mon Sep 17 00:00:00 2001 From: jiaopu Date: Thu, 7 May 2020 18:28:01 +0800 Subject: [PATCH] fix runtime shape --- lite/kernels/mlu/bridges/graph.h | 4 ++-- lite/kernels/mlu/subgraph_compute.h | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lite/kernels/mlu/bridges/graph.h b/lite/kernels/mlu/bridges/graph.h index eb670f9a5f..f785bd04a6 100644 --- a/lite/kernels/mlu/bridges/graph.h +++ b/lite/kernels/mlu/bridges/graph.h @@ -123,8 +123,8 @@ class Graph { void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que, - std::vector> in, - std::vector> out) { + const std::vector>& in, + const std::vector>& out) { std::vector in_tensor; std::vector out_tensor; input_addrs_.resize(in.size()); diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h index 6cc6f6e686..eb116321d1 100644 --- a/lite/kernels/mlu/subgraph_compute.h +++ b/lite/kernels/mlu/subgraph_compute.h @@ -238,7 +238,8 @@ class SubgraphEngine : public subgraph::Engine { for (size_t i = 0; i < origin_itensors_.size(); ++i) { paddle::lite::subgraph::mlu::MLUTensor tmp( - graph_input->at(i)->get_origin_shape()); + origin_itensors_[i]->dims().Vectorize()); + // graph_input->at(i)->get_origin_shape()); tmp.set_mlu_dtype(graph_input->at(i)->dtype()); tmp.set_mlu_ptr(const_cast(origin_itensors_[i]->raw_data())); graph_in.push_back( @@ -251,7 +252,8 @@ class SubgraphEngine : public subgraph::Engine { ->mutable_data::T>(TARGET(kMLU))); paddle::lite::subgraph::mlu::MLUTensor tmp( - graph_output->at(i)->get_origin_shape()); + origin_otensors_[i]->dims().Vectorize()); + // graph_output->at(i)->get_origin_shape()); tmp.set_mlu_dtype(graph_output->at(i)->dtype()); tmp.set_mlu_ptr(p_data); graph_out.push_back( -- GitLab