diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index cd32200e925193b393f4531b87ed6b1e4291109d..014c9ecca4aa7875e5188b33595f5e8f79c1a9db 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -92,12 +92,10 @@ static std::vector> NgraphOpIntervals( int size = ops->size(); int left = 0; - while (left < size && ops->at(left)->Type() != framework::kFeedOpType) { + while (left < size && ops->at(left)->Type() != framework::kFeedOpType && + ops->at(left)->Type() != framework::kFetchOpType) { ++left; } - if (left == size) { - return intervals; - } while (left < size && ops->at(left)->Type() == framework::kFeedOpType) { for (auto& var_name_item : ops->at(left)->Outputs()) { @@ -112,10 +110,6 @@ static std::vector> NgraphOpIntervals( while (right < size && ops->at(right)->Type() != framework::kFetchOpType) { ++right; } - if (right == size) { - return intervals; - } - if (left >= right) return intervals; int index = right; while (index < size && ops->at(index)->Type() == framework::kFetchOpType) { @@ -127,6 +121,10 @@ static std::vector> NgraphOpIntervals( ++index; } + if (left == size || ops->at(left)->Type() == framework::kFetchOpType) { + left = 0; + } + // (left, right - 1) represents indices between feed and fetch int pivot = left; while (pivot < right) { @@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, } void NgraphEngine::Prepare(const std::vector& interval) { + bool has_fetch = false, is_full = false; for (auto& var : p_bdesc->AllVars()) { if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || var->GetType() == framework::proto::VarType::LOD_TENSOR || @@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector& interval) { std::vector ops_desc; for (auto op_desc : p_bdesc->AllOps()) { ops_desc.emplace_back(op_desc); + if (op_desc->Type() == framework::kFetchOpType) { + has_fetch = true; + } } for (auto op_desc : ops_desc) { @@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector& interval) { if (interval[0] > 0 && ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType && interval[1] < static_cast(ops_desc.size()) && - ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) { - this->op_state_ = OpState::FULL; + ops_desc.at(interval[1])->Type() == framework::kFetchOpType) { + is_full = true; } - if (this->op_state_ == OpState::FULL) { + if (is_full) { this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN; } else { this->op_state_ = @@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector& interval) { framework::OpRegistry::CreateOp(*(ops_desc[idx]))); ++idx; } - while (ops_desc.at(idx)->Type() != framework::kFetchOpType) { + while (idx < static_cast(ops_desc.size()) && + ops_desc.at(idx)->Type() != framework::kFetchOpType) { auto op_desc = ops_desc.at(idx); for (auto& var_name_item : op_desc->Inputs()) { for (auto& var_name : var_name_item.second) { @@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector& interval) { ++idx; } + if (!has_fetch) { + op_state_ = OpState::UNKNOWN; + } + BuildNgIO(ops_desc, interval); } @@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector& ops_desc, } } } + for (size_t i = 0; i < var_in_.size(); ++i) { auto var_name = var_in_[i]; if (persistables_.find(var_name) == persistables_.end()) { diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h index fef51464b5702e61d052f28050f6aefaecf0f615..b6532519e947bc59f0605c4f2008270f5e51b0e0 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.h +++ b/paddle/fluid/operators/ngraph/ngraph_engine.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_ -#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_ +#pragma once + #include #include #include @@ -35,7 +35,6 @@ enum class OpState { /* nGraph support state on ops */ PARTIAL_TRAIN, /* Support partial ops for train */ FULL_TEST, /* Support full list of ops for test */ PARTIAL_TEST, /* Support partial list of ops for test */ - FULL, /* All ops supported from feed to fetch */ UNKNOWN /* Output all for debug purpose */ }; @@ -119,4 +118,3 @@ class NgraphEngine { } // namespace operators } // namespace paddle -#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_