提交 adcfc53b 编写于 作者: B baojun 提交者: Tao Luo

upgrade ngraph version and simplify ngraph engine (#18853)

* upgrade ngraph to v0.24 test=develop

* simplify io test=develop
上级 2bb296df
...@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs) ...@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph") SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a") SET(NGRAPH_GIT_TAG "v0.24.0-rc.2")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph) SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph) SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include) SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
......
...@@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
std::vector<std::vector<int>> intervals; std::vector<std::vector<int>> intervals;
int size = ops->size(); int size = ops->size();
int left = 0; int left = 0, feed_idx = -1;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType && while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != "read" && ops->at(left)->Type() != "read" &&
ops->at(left)->Type() != framework::kFetchOpType) { ops->at(left)->Type() != framework::kFetchOpType) {
++left; ++left;
} }
if (left < size) {
auto op_type = ops->at(left)->Type();
if (op_type == framework::kFeedOpType || op_type == "read") {
feed_idx = left;
}
}
while (left < size && (ops->at(left)->Type() == framework::kFeedOpType || while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
ops->at(left)->Type() == "read")) { ops->at(left)->Type() == "read")) {
for (auto& var_name_item : ops->at(left)->Outputs()) { for (auto& var_name_item : ops->at(left)->Outputs()) {
...@@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
++end; ++end;
} }
std::vector<int> interval = {start, end}; std::vector<int> interval = {start, end};
intervals.emplace_back(interval); if (feed_idx != -1 && start > feed_idx) {
intervals.emplace_back(interval);
}
} }
} // end while } // end while
return intervals; return intervals;
...@@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
NgraphEngine::p_bdesc = &block_desc; NgraphEngine::p_bdesc = &block_desc;
} }
bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) { for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR || var->GetType() == framework::proto::VarType::LOD_TENSOR ||
...@@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
std::vector<paddle::framework::OpDesc*> ops_desc; std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) { for (auto op_desc : p_bdesc->AllOps()) {
ops_desc.emplace_back(op_desc); ops_desc.emplace_back(op_desc);
if (op_desc->Type() == framework::kFetchOpType) {
has_fetch = true;
}
}
for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) { if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true; is_training = true;
this->is_test_ = false; this->is_test_ = false;
break;
} }
} }
if (interval[0] > 0 &&
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
interval[1] < static_cast<int>(ops_desc.size()) &&
ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
is_full = true;
}
if (is_full) {
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
} else {
this->op_state_ =
this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
}
int idx = interval[0]; int idx = interval[0];
while (idx < interval[1]) { while (idx < interval[1]) {
this->fused_ops_.emplace_back( this->fused_ops_.emplace_back(
...@@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
++idx; ++idx;
} }
if (!has_fetch) {
op_state_ = OpState::UNKNOWN;
}
if (var_in_.empty() && var_out_.empty()) { if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval); BuildNgIO(ops_desc, interval);
} }
...@@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
"op %s has more than 1 output - Not handling yet", "op %s has more than 1 output - Not handling yet",
op->Type()); op->Type());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
switch (this->op_state_) { if (this->is_test_) {
case OpState::PARTIAL_TEST: if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
find(fetch_vars.begin(), fetch_vars.end(), var_name) != fetch_vars.end()) {
fetch_vars.end()) { this->var_out_.emplace_back(var_name);
this->var_out_.emplace_back(var_name); }
} } else {
break; if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
case OpState::FULL_TEST: fetch_vars.end() ||
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) != post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetch_vars.end()) { persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::PARTIAL_TRAIN:
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::FULL_TRAIN:
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
break;
default:
this->var_out_.emplace_back(var_name); this->var_out_.emplace_back(var_name);
}
} }
} }
} }
......
...@@ -30,14 +30,6 @@ limitations under the License. */ ...@@ -30,14 +30,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
enum class OpState { /* nGraph support state on ops */
FULL_TRAIN, /* Support full ops for train */
PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST, /* Support partial list of ops for test */
UNKNOWN /* Output all for debug purpose */
};
// cache engine repetitives // cache engine repetitives
struct EngineCache { struct EngineCache {
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle; std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
...@@ -78,7 +70,6 @@ class NgraphEngine { ...@@ -78,7 +70,6 @@ class NgraphEngine {
std::unordered_map<std::string, ngraph::element::Type> var_type_map_; std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::set<std::string> persistables_; std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_; std::unordered_set<std::string> post_op_inputs_;
OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true}; bool is_test_{true};
std::string func_cache_key_; std::string func_cache_key_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册