提交 adcfc53b 编写于 作者: B baojun 提交者: Tao Luo

upgrade ngraph version and simplify ngraph engine (#18853)

* upgrade ngraph to v0.24 test=develop

* simplify io test=develop
上级 2bb296df
......@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a")
SET(NGRAPH_GIT_TAG "v0.24.0-rc.2")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
......
......@@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
std::vector<std::vector<int>> intervals;
int size = ops->size();
int left = 0;
int left = 0, feed_idx = -1;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != "read" &&
ops->at(left)->Type() != framework::kFetchOpType) {
++left;
}
if (left < size) {
auto op_type = ops->at(left)->Type();
if (op_type == framework::kFeedOpType || op_type == "read") {
feed_idx = left;
}
}
while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
ops->at(left)->Type() == "read")) {
for (auto& var_name_item : ops->at(left)->Outputs()) {
......@@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
++end;
}
std::vector<int> interval = {start, end};
intervals.emplace_back(interval);
if (feed_idx != -1 && start > feed_idx) {
intervals.emplace_back(interval);
}
}
} // end while
return intervals;
......@@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
NgraphEngine::p_bdesc = &block_desc;
}
bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
......@@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) {
ops_desc.emplace_back(op_desc);
if (op_desc->Type() == framework::kFetchOpType) {
has_fetch = true;
}
}
for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false;
break;
}
}
if (interval[0] > 0 &&
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
interval[1] < static_cast<int>(ops_desc.size()) &&
ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
is_full = true;
}
if (is_full) {
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
} else {
this->op_state_ =
this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
}
int idx = interval[0];
while (idx < interval[1]) {
this->fused_ops_.emplace_back(
......@@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
++idx;
}
if (!has_fetch) {
op_state_ = OpState::UNKNOWN;
}
if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval);
}
......@@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
switch (this->op_state_) {
case OpState::PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::FULL_TEST:
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::PARTIAL_TRAIN:
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::FULL_TRAIN:
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
break;
default:
if (this->is_test_) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
}
} else {
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
}
}
}
......
......@@ -30,14 +30,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum class OpState { /* nGraph support state on ops */
FULL_TRAIN, /* Support full ops for train */
PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST, /* Support partial list of ops for test */
UNKNOWN /* Output all for debug purpose */
};
// cache engine repetitives
struct EngineCache {
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
......@@ -78,7 +70,6 @@ class NgraphEngine {
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_;
OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true};
std::string func_cache_key_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册