提交 739b4ac1 编写于 作者: J jiaopu

fix runtime shape: build graph input/output MLUTensors from the origin tensors' current dims (dims().Vectorize()) instead of the graph's cached origin shape, and pass the tensor vectors to Graph::Compute by const reference

上级 6871c0c7
...@@ -123,8 +123,8 @@ class Graph { ...@@ -123,8 +123,8 @@ class Graph {
void Compute(cnrtInvokeFuncParam_t forward_param, void Compute(cnrtInvokeFuncParam_t forward_param,
cnrtQueue_t que, cnrtQueue_t que,
std::vector<std::shared_ptr<MLUTensor>> in, const std::vector<std::shared_ptr<MLUTensor>>& in,
std::vector<std::shared_ptr<MLUTensor>> out) { const std::vector<std::shared_ptr<MLUTensor>>& out) {
std::vector<cnmlTensor_t> in_tensor; std::vector<cnmlTensor_t> in_tensor;
std::vector<cnmlTensor_t> out_tensor; std::vector<cnmlTensor_t> out_tensor;
input_addrs_.resize(in.size()); input_addrs_.resize(in.size());
......
...@@ -238,7 +238,8 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -238,7 +238,8 @@ class SubgraphEngine : public subgraph::Engine {
for (size_t i = 0; i < origin_itensors_.size(); ++i) { for (size_t i = 0; i < origin_itensors_.size(); ++i) {
paddle::lite::subgraph::mlu::MLUTensor tmp( paddle::lite::subgraph::mlu::MLUTensor tmp(
graph_input->at(i)->get_origin_shape()); origin_itensors_[i]->dims().Vectorize());
// graph_input->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_input->at(i)->dtype()); tmp.set_mlu_dtype(graph_input->at(i)->dtype());
tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data())); tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
graph_in.push_back( graph_in.push_back(
...@@ -251,7 +252,8 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -251,7 +252,8 @@ class SubgraphEngine : public subgraph::Engine {
->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits< ->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
Precision>::T>(TARGET(kMLU))); Precision>::T>(TARGET(kMLU)));
paddle::lite::subgraph::mlu::MLUTensor tmp( paddle::lite::subgraph::mlu::MLUTensor tmp(
graph_output->at(i)->get_origin_shape()); origin_otensors_[i]->dims().Vectorize());
// graph_output->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_output->at(i)->dtype()); tmp.set_mlu_dtype(graph_output->at(i)->dtype());
tmp.set_mlu_ptr(p_data); tmp.set_mlu_ptr(p_data);
graph_out.push_back( graph_out.push_back(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册