Commit 3b436c62 authored by dingminghui, committed by jackzhang235

refactor(cast): use mlu cast as default

Parent 1ebad864
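This commit flips the MLU cast switch from opt-in to opt-out: the old LITE_MLU_CAST environment variable enabled casting on the MLU, while the new LITE_DISABLE_MLU_CAST variable turns it off, so every former use of use_mlu_cast becomes !disable_mlu_cast and MLU cast runs by default. A minimal sketch of the resulting default, assuming GetBoolFromEnv returns false when the variable is unset (consistent with the call sites below; the exact truthy values it accepts are not shown in this diff):

    // Sketch, not part of the patch: the effective switch after this change.
    bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST");
    bool use_mlu_cast = !disable_mlu_cast;  // true unless the env var is set truthy

To recover the old non-MLU cast path, set LITE_DISABLE_MLU_CAST in the environment before running (e.g. LITE_DISABLE_MLU_CAST=1, assuming the usual truthy parsing).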
@@ -888,14 +888,14 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 #endif
   g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
-  bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
-  ModifyValidPlaces(graph.get(), use_mlu_cast);
+  bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST");
+  ModifyValidPlaces(graph.get(), !disable_mlu_cast);
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
   for (auto& node : graph->mutable_nodes()) {
     if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
       const Type* subgraph_arg_type = nullptr;
       GetSubgraphOpArgType(&node, &subgraph_arg_type, graph.get());
-      if (use_mlu_cast) {
+      if (!disable_mlu_cast) {
         AdjustSubgraph(&node, subgraph_arg_type);
       }
@@ -903,14 +903,14 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       for (auto p_in : links_tmp) {
         if (NeedInsert(p_in, subgraph_arg_type)) {
           InsertBefore(
-              graph.get(), p_in, &node, subgraph_arg_type, use_mlu_cast);
+              graph.get(), p_in, &node, subgraph_arg_type, !disable_mlu_cast);
         }
       }
       links_tmp.assign(node.outlinks.begin(), node.outlinks.end());
       for (auto p_out : links_tmp) {
         if (NeedInsert(p_out, subgraph_arg_type)) {
           InsertAfter(
-              graph.get(), p_out, &node, subgraph_arg_type, use_mlu_cast);
+              graph.get(), p_out, &node, subgraph_arg_type, !disable_mlu_cast);
         }
       }
     }
@@ -343,7 +343,7 @@ class SubgraphEngine : public subgraph::Engine {
     CHECK_EQ(graph_input->size(), origin_itensors_.size());
     CHECK_EQ(graph_output->size(), origin_otensors_.size());
-    bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
+    bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST");
     if (!disable_batch_size_changeable_) {
       std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
@@ -386,7 +386,7 @@ class SubgraphEngine : public subgraph::Engine {
         for (size_t i = 0; i < origin_otensors_.size(); ++i) {
           // origin_otensors_[i]->Resize(new_output_size.at(i));
           graph_out[i]->set_mlu_ptr(
-              GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
+              GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast));
         }
       } else {
         graph_out.reserve(origin_otensors_.size());
@@ -395,7 +395,8 @@ class SubgraphEngine : public subgraph::Engine {
           paddle::lite::subgraph::mlu::MLUTensor tmp(
               origin_otensors_[i]->dims().Vectorize());
           tmp.set_mlu_dtype(graph_output->at(i)->dtype());
-          tmp.set_mlu_ptr(GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
+          tmp.set_mlu_ptr(
+              GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast));
           graph_out.push_back(
               std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
         }
@@ -410,7 +411,7 @@ class SubgraphEngine : public subgraph::Engine {
       for (size_t i = 0; i < origin_otensors_.size(); ++i) {
         origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
         graph_output->at(i)->set_mlu_ptr(
-            GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
+            GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast));
       }
       graph->Compute(forward_param, exec_queue);
     }