提交 1c8b34dd 编写于 作者: B baojun 提交者: tensor-tang

fix training validation test=develop (#16698)

上级 a06f4b2b
...@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {}; ...@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {}; std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr; framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr; const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {}; std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string, std::unordered_map<std::string,
...@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int size = ops->size(); int size = ops->size();
int left = 0; int left = 0;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType && while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != "read" &&
ops->at(left)->Type() != framework::kFetchOpType) { ops->at(left)->Type() != framework::kFetchOpType) {
++left; ++left;
} }
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) { while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
ops->at(left)->Type() == "read")) {
for (auto& var_name_item : ops->at(left)->Outputs()) { for (auto& var_name_item : ops->at(left)->Outputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
NgraphEngine::feed_vars.emplace_back(var_name); NgraphEngine::feed_vars.emplace_back(var_name);
...@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
for (auto op_desc : ops_desc) { for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) { if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false; this->is_test_ = false;
break; break;
} }
...@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
} }
bool is_persistable = bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false; (p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (is_test && is_persistable) { if (!is_training && is_test && is_persistable) {
ti->set_stale(false); ti->set_stale(false);
} }
(*p_t_in).emplace_back(ti); (*p_t_in).emplace_back(ti);
......
...@@ -57,6 +57,7 @@ class NgraphEngine { ...@@ -57,6 +57,7 @@ class NgraphEngine {
void Run(const framework::Scope& scope, const platform::Place& place) const; void Run(const framework::Scope& scope, const platform::Place& place) const;
static bool is_training;
static const framework::BlockDesc* p_bdesc; static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars; static std::vector<std::string> feed_vars, fetch_vars;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册