提交 804afc51 编写于 作者: B baojun 提交者: tensor-tang

Minor ngraph fix (#16270)

* take care edge cases test=develop

* use pragma test=develop
上级 9195c3bb
......@@ -92,12 +92,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int size = ops->size();
int left = 0;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType) {
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != framework::kFetchOpType) {
++left;
}
if (left == size) {
return intervals;
}
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
for (auto& var_name_item : ops->at(left)->Outputs()) {
......@@ -112,10 +110,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
++right;
}
if (right == size) {
return intervals;
}
if (left >= right) return intervals;
int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
......@@ -127,6 +121,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
++index;
}
if (left == size || ops->at(left)->Type() == framework::kFetchOpType) {
left = 0;
}
// (left, right - 1) represents indices between feed and fetch
int pivot = left;
while (pivot < right) {
......@@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
}
void NgraphEngine::Prepare(const std::vector<int>& interval) {
bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
......@@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) {
ops_desc.emplace_back(op_desc);
if (op_desc->Type() == framework::kFetchOpType) {
has_fetch = true;
}
}
for (auto op_desc : ops_desc) {
......@@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
if (interval[0] > 0 &&
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
interval[1] < static_cast<int>(ops_desc.size()) &&
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
this->op_state_ = OpState::FULL;
ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
is_full = true;
}
if (this->op_state_ == OpState::FULL) {
if (is_full) {
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
} else {
this->op_state_ =
......@@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx;
}
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
......@@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
++idx;
}
if (!has_fetch) {
op_state_ = OpState::UNKNOWN;
}
BuildNgIO(ops_desc, interval);
}
......@@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
}
}
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
......
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#pragma once
#include <memory>
#include <set>
#include <string>
......@@ -35,7 +35,6 @@ enum class OpState { /* nGraph support state on ops */
PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST, /* Support partial list of ops for test */
FULL, /* All ops supported from feed to fetch */
UNKNOWN /* Output all for debug purpose */
};
......@@ -119,4 +118,3 @@ class NgraphEngine {
} // namespace operators
} // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册