提交 1c8b34dd 编写于 作者: B baojun 提交者: tensor-tang

fix training validation test=develop (#16698)

上级 a06f4b2b
......@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string,
......@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int size = ops->size();
int left = 0;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != "read" &&
ops->at(left)->Type() != framework::kFetchOpType) {
++left;
}
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
ops->at(left)->Type() == "read")) {
for (auto& var_name_item : ops->at(left)->Outputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::feed_vars.emplace_back(var_name);
......@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false;
break;
}
......@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (is_test && is_persistable) {
if (!is_training && is_test && is_persistable) {
ti->set_stale(false);
}
(*p_t_in).emplace_back(ti);
......
......@@ -57,6 +57,7 @@ class NgraphEngine {
void Run(const framework::Scope& scope, const platform::Place& place) const;
static bool is_training;
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册