From b18948073454b3db4b51c3335bb620c0e09c50a9 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Tue, 14 May 2019 20:58:25 -0700 Subject: [PATCH] Ngraph Enable gather operator test=develop (#17296) --- paddle/fluid/operators/ngraph/ops/gather_op.h | 65 +++++++++++++++++++ .../unittests/ngraph/test_gather_ngraph_op.py | 21 ++++++ 2 files changed, 86 insertions(+) create mode 100644 paddle/fluid/operators/ngraph/ops/gather_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_gather_ngraph_op.py diff --git a/paddle/fluid/operators/ngraph/ops/gather_op.h b/paddle/fluid/operators/ngraph/ops/gather_op.h new file mode 100644 index 00000000000..273a328c520 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/gather_op.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildGatherNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto x = platform::GetInputNode(op, "X", ngb_node_map); + auto index = platform::GetInputNode(op, "Index", ngb_node_map); + auto x_shape = x->get_shape(); + size_t axis_1 = x_shape[0]; + size_t axis_2 = 1; + if (x_shape.size() > 1) { + axis_2 = std::accumulate(std::begin(x_shape) + 1, std::end(x_shape), 1, + std::multiplies()); + } + std::vector x_order(x_shape.size()); + std::iota(std::begin(x_order), std::end(x_order), 0); + auto x_reshape = std::make_shared( + x, ngraph::AxisVector(x_order), ngraph::Shape{axis_1, axis_2}); + auto x_reshape_shape = x_reshape->get_shape(); + auto result = std::make_shared(index, x_reshape); + auto result_shape = result->get_shape(); + std::vector out_shape(x_shape); + out_shape[0] = result_shape[0]; + std::vector axis_vector; + for (size_t i = 0; i < result_shape.size(); i++) { + axis_vector.push_back(i); + } + auto out = std::make_shared( + result, ngraph::AxisVector(axis_vector), out_shape); + paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(gather, BuildGatherNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_gather_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_gather_ngraph_op.py new file mode 100644 index 00000000000..403145dd734 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_gather_ngraph_op.py @@ -0,0 +1,21 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import unittest, sys +sys.path.append("../") +from test_gather_op import TestGatherOp, TestCase1 + +if __name__ == "__main__": + unittest.main() -- GitLab