diff --git a/cmake/external/ngraph.cmake b/cmake/external/ngraph.cmake
index cdcbdd46a8d55cc75706de7bc415478f4fe4f256..95440fc3f72a7fe89d32965effd31c85fe05ee60 100644
--- a/cmake/external/ngraph.cmake
+++ b/cmake/external/ngraph.cmake
@@ -37,7 +37,7 @@
 INCLUDE(GNUInstallDirs)
 INCLUDE(ExternalProject)
 SET(NGRAPH_PROJECT "extern_ngraph")
-SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a")
+SET(NGRAPH_GIT_TAG "v0.24.0-rc.2")
 SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
 SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
 SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc
index 3a943686577e31152356b2bc3c0b77eacb64f300..7d78c61739a9d1ba4577079ea48ea4d3467f3fd8 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
   std::vector<std::vector<int>> intervals;
 
   int size = ops->size();
-  int left = 0;
+  int left = 0, feed_idx = -1;
   while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
          ops->at(left)->Type() != "read" &&
          ops->at(left)->Type() != framework::kFetchOpType) {
     ++left;
   }
 
+  if (left < size) {
+    auto op_type = ops->at(left)->Type();
+    if (op_type == framework::kFeedOpType || op_type == "read") {
+      feed_idx = left;
+    }
+  }
+
   while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
                          ops->at(left)->Type() == "read")) {
     for (auto& var_name_item : ops->at(left)->Outputs()) {
@@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
         ++end;
       }
       std::vector<int> interval = {start, end};
-      intervals.emplace_back(interval);
+      if (feed_idx != -1 && start > feed_idx) {
+        intervals.emplace_back(interval);
+      }
     }
   }  // end while
   return intervals;
 }
@@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
     NgraphEngine::p_bdesc = &block_desc;
   }
 
-  bool has_fetch = false, is_full = false;
   for (auto& var : p_bdesc->AllVars()) {
     if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
           var->GetType() == framework::proto::VarType::LOD_TENSOR ||
@@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
   std::vector<framework::OpDesc*> ops_desc;
   for (auto op_desc : p_bdesc->AllOps()) {
     ops_desc.emplace_back(op_desc);
-    if (op_desc->Type() == framework::kFetchOpType) {
-      has_fetch = true;
-    }
-  }
-
-  for (auto op_desc : ops_desc) {
     if (op_desc->Type().find("_grad") != std::string::npos) {
       is_training = true;
       this->is_test_ = false;
-      break;
     }
   }
 
-  if (interval[0] > 0 &&
-      ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
-      interval[1] < static_cast<int>(ops_desc.size()) &&
-      ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
-    is_full = true;
-  }
-
-  if (is_full) {
-    this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
-  } else {
-    this->op_state_ =
-        this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
-  }
-
   int idx = interval[0];
   while (idx < interval[1]) {
     this->fused_ops_.emplace_back(
@@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
     ++idx;
   }
 
-  if (!has_fetch) {
-    op_state_ = OpState::UNKNOWN;
-  }
-
   if (var_in_.empty() && var_out_.empty()) {
     BuildNgIO(ops_desc, interval);
   }
@@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
                         "op %s has more than 1 output - Not handling yet",
                         op->Type());
       for (auto& var_name : var_name_item.second) {
-        switch (this->op_state_) {
-          case OpState::PARTIAL_TEST:
-            if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
-                find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::FULL_TEST:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                fetch_vars.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::PARTIAL_TRAIN:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end() ||
-                post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
-                persistables_.find(var_name) != persistables_.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::FULL_TRAIN:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end() ||
-                persistables_.find(var_name) != persistables_.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          default:
+        if (this->is_test_) {
+          if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
+              find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
+                  fetch_vars.end()) {
+            this->var_out_.emplace_back(var_name);
+          }
+        } else {
+          if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
+                  fetch_vars.end() ||
+              post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
+              persistables_.find(var_name) != persistables_.end()) {
             this->var_out_.emplace_back(var_name);
+          }
         }
       }
     }
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h
index 4cb1465371356cd1ef76113fc4e78d7f1e188746..7fa443a5d49b17d116895bdd3227561fb3f8515a 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.h
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.h
@@ -30,14 +30,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-enum class OpState { /* nGraph support state on ops */
-  FULL_TRAIN,        /* Support full ops for train */
-  PARTIAL_TRAIN,     /* Support partial ops for train */
-  FULL_TEST,         /* Support full list of ops for test */
-  PARTIAL_TEST,      /* Support partial list of ops for test */
-  UNKNOWN            /* Output all for debug purpose */
-};
-
 // cache engine repetitives
 struct EngineCache {
   std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
@@ -78,7 +70,6 @@ class NgraphEngine {
   std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
   std::set<std::string> persistables_;
   std::unordered_set<std::string> post_op_inputs_;
-  OpState op_state_ = OpState::UNKNOWN;
   bool is_test_{true};
   std::string func_cache_key_;