diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index 9f73bbc1fdc72766a0b57bc72c62d208277c2f20..5ef385d2fcbaf01dce5c9b85321b41c103e5655a 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -75,6 +75,7 @@ std::vector NgraphEngine::feed_vars = {}; std::vector NgraphEngine::fetch_vars = {}; framework::Variable* NgraphEngine::pre_var_ptr = nullptr; const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr; +bool NgraphEngine::is_training = false; std::unordered_map NgraphEngine::engine_cache = {}; std::unordered_map> NgraphOpIntervals( int size = ops->size(); int left = 0; while (left < size && ops->at(left)->Type() != framework::kFeedOpType && + ops->at(left)->Type() != "read" && ops->at(left)->Type() != framework::kFetchOpType) { ++left; } - while (left < size && ops->at(left)->Type() == framework::kFeedOpType) { + while (left < size && (ops->at(left)->Type() == framework::kFeedOpType || + ops->at(left)->Type() == "read")) { for (auto& var_name_item : ops->at(left)->Outputs()) { for (auto& var_name : var_name_item.second) { NgraphEngine::feed_vars.emplace_back(var_name); @@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector& interval) { for (auto op_desc : ops_desc) { if (op_desc->Type().find("_grad") != std::string::npos) { + is_training = true; this->is_test_ = false; break; } @@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope, } bool is_persistable = (p_persistables->find(vi) != p_persistables->end()) ? true : false; - if (is_test && is_persistable) { + if (!is_training && is_test && is_persistable) { ti->set_stale(false); } (*p_t_in).emplace_back(ti); diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h index b6532519e947bc59f0605c4f2008270f5e51b0e0..19400ac5b0ecd9d3254583b8db9889fc6cf8bc0f 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.h +++ b/paddle/fluid/operators/ngraph/ngraph_engine.h @@ -57,6 +57,7 @@ class NgraphEngine { void Run(const framework::Scope& scope, const platform::Place& place) const; + static bool is_training; static const framework::BlockDesc* p_bdesc; static std::vector feed_vars, fetch_vars;