提交 804afc51 编写于 作者: B baojun 提交者: tensor-tang

Minor ngraph fix (#16270)

* take care edge cases test=develop

* use pragma test=develop
上级 9195c3bb
...@@ -92,12 +92,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -92,12 +92,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int size = ops->size(); int size = ops->size();
int left = 0; int left = 0;
while (left < size && ops->at(left)->Type() != framework::kFeedOpType) { while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
ops->at(left)->Type() != framework::kFetchOpType) {
++left; ++left;
} }
if (left == size) {
return intervals;
}
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) { while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
for (auto& var_name_item : ops->at(left)->Outputs()) { for (auto& var_name_item : ops->at(left)->Outputs()) {
...@@ -112,10 +110,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -112,10 +110,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
while (right < size && ops->at(right)->Type() != framework::kFetchOpType) { while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
++right; ++right;
} }
if (right == size) {
return intervals;
}
if (left >= right) return intervals;
int index = right; int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) { while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
...@@ -127,6 +121,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -127,6 +121,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
++index; ++index;
} }
if (left == size || ops->at(left)->Type() == framework::kFetchOpType) {
left = 0;
}
// (left, right - 1) represents indices between feed and fetch // (left, right - 1) represents indices between feed and fetch
int pivot = left; int pivot = left;
while (pivot < right) { while (pivot < right) {
...@@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, ...@@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
} }
void NgraphEngine::Prepare(const std::vector<int>& interval) { void NgraphEngine::Prepare(const std::vector<int>& interval) {
bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) { for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR || var->GetType() == framework::proto::VarType::LOD_TENSOR ||
...@@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
std::vector<paddle::framework::OpDesc*> ops_desc; std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) { for (auto op_desc : p_bdesc->AllOps()) {
ops_desc.emplace_back(op_desc); ops_desc.emplace_back(op_desc);
if (op_desc->Type() == framework::kFetchOpType) {
has_fetch = true;
}
} }
for (auto op_desc : ops_desc) { for (auto op_desc : ops_desc) {
...@@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
if (interval[0] > 0 && if (interval[0] > 0 &&
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType && ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
interval[1] < static_cast<int>(ops_desc.size()) && interval[1] < static_cast<int>(ops_desc.size()) &&
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) { ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
this->op_state_ = OpState::FULL; is_full = true;
} }
if (this->op_state_ == OpState::FULL) { if (is_full) {
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN; this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
} else { } else {
this->op_state_ = this->op_state_ =
...@@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
framework::OpRegistry::CreateOp(*(ops_desc[idx]))); framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx; ++idx;
} }
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) { while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx); auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) { for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
...@@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
++idx; ++idx;
} }
if (!has_fetch) {
op_state_ = OpState::UNKNOWN;
}
BuildNgIO(ops_desc, interval); BuildNgIO(ops_desc, interval);
} }
...@@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
} }
} }
} }
for (size_t i = 0; i < var_in_.size(); ++i) { for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i]; auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) { if (persistables_.find(var_name) == persistables_.end()) {
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_ #pragma once
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
...@@ -35,7 +35,6 @@ enum class OpState { /* nGraph support state on ops */ ...@@ -35,7 +35,6 @@ enum class OpState { /* nGraph support state on ops */
PARTIAL_TRAIN, /* Support partial ops for train */ PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */ FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST, /* Support partial list of ops for test */ PARTIAL_TEST, /* Support partial list of ops for test */
FULL, /* All ops supported from feed to fetch */
UNKNOWN /* Output all for debug purpose */ UNKNOWN /* Output all for debug purpose */
}; };
...@@ -119,4 +118,3 @@ class NgraphEngine { ...@@ -119,4 +118,3 @@ class NgraphEngine {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册