From 3b436c621f85df4945f28399004f3cba168fef9e Mon Sep 17 00:00:00 2001 From: dingminghui Date: Wed, 20 May 2020 20:00:46 +0800 Subject: [PATCH] refactor(cast): use mlu cast as default --- lite/core/mir/mlu_postprocess_pass.cc | 10 +++++----- lite/kernels/mlu/subgraph_compute.h | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index e794e6a313..61a56fa7e1 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -888,14 +888,14 @@ void MLUPostprocessPass::Apply(const std::unique_ptr& graph) { #endif g_stream_id = static_cast(reinterpret_cast(graph.get())); - bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST"); - ModifyValidPlaces(graph.get(), use_mlu_cast); + bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST"); + ModifyValidPlaces(graph.get(), !disable_mlu_cast); // insert io_copy, layout and precision cast of subgraph's inputs and outputs for (auto& node : graph->mutable_nodes()) { if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") { const Type* subgraph_arg_type = nullptr; GetSubgraphOpArgType(&node, &subgraph_arg_type, graph.get()); - if (use_mlu_cast) { + if (!disable_mlu_cast) { AdjustSubgraph(&node, subgraph_arg_type); } @@ -903,14 +903,14 @@ void MLUPostprocessPass::Apply(const std::unique_ptr& graph) { for (auto p_in : links_tmp) { if (NeedInsert(p_in, subgraph_arg_type)) { InsertBefore( - graph.get(), p_in, &node, subgraph_arg_type, use_mlu_cast); + graph.get(), p_in, &node, subgraph_arg_type, !disable_mlu_cast); } } links_tmp.assign(node.outlinks.begin(), node.outlinks.end()); for (auto p_out : links_tmp) { if (NeedInsert(p_out, subgraph_arg_type)) { InsertAfter( - graph.get(), p_out, &node, subgraph_arg_type, use_mlu_cast); + graph.get(), p_out, &node, subgraph_arg_type, !disable_mlu_cast); } } } diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h index 70c429dd93..51381094ef 100644 --- a/lite/kernels/mlu/subgraph_compute.h +++ b/lite/kernels/mlu/subgraph_compute.h @@ -343,7 +343,7 @@ class SubgraphEngine : public subgraph::Engine { CHECK_EQ(graph_input->size(), origin_itensors_.size()); CHECK_EQ(graph_output->size(), origin_otensors_.size()); - bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST"); + bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST"); if (!disable_batch_size_changeable_) { std::vector> @@ -386,7 +386,7 @@ class SubgraphEngine : public subgraph::Engine { for (size_t i = 0; i < origin_otensors_.size(); ++i) { // origin_otensors_[i]->Resize(new_output_size.at(i)); graph_out[i]->set_mlu_ptr( - GetOutputDataPtr(origin_otensors_[i], use_mlu_cast)); + GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast)); } } else { graph_out.reserve(origin_otensors_.size()); @@ -395,7 +395,8 @@ class SubgraphEngine : public subgraph::Engine { paddle::lite::subgraph::mlu::MLUTensor tmp( origin_otensors_[i]->dims().Vectorize()); tmp.set_mlu_dtype(graph_output->at(i)->dtype()); - tmp.set_mlu_ptr(GetOutputDataPtr(origin_otensors_[i], use_mlu_cast)); + tmp.set_mlu_ptr( + GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast)); graph_out.push_back( std::make_shared(tmp)); } @@ -410,7 +411,7 @@ class SubgraphEngine : public subgraph::Engine { for (size_t i = 0; i < origin_otensors_.size(); ++i) { origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape()); graph_output->at(i)->set_mlu_ptr( - GetOutputDataPtr(origin_otensors_[i], use_mlu_cast)); + GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast)); } graph->Compute(forward_param, exec_queue); } -- GitLab