diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index 5fcb17b9f3ac390548aba33db7d0b8350cde7e00..58091ee3bc2eecf2ea07767885461071ea285c2e 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -34,7 +34,8 @@ std::map}, - {"tanh", paddle::operators::ngraphs::BuildUnaryNode}}; + {"tanh", paddle::operators::ngraphs::BuildUnaryNode}, + {"top_k", paddle::operators::ngraphs::BuildTopKNode}}; void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { auto& op_type = op->Type(); diff --git a/paddle/fluid/operators/ngraph/ngraph_ops.h b/paddle/fluid/operators/ngraph/ngraph_ops.h index 0ed77ff5577cf4f45a8865db9b42e8bda9839478..869a4fd574224d2adba5e4dc21669c7477604236 100644 --- a/paddle/fluid/operators/ngraph/ngraph_ops.h +++ b/paddle/fluid/operators/ngraph/ngraph_ops.h @@ -23,3 +23,4 @@ limitations under the License. */ #include "ops/binary_unnary_op.h" #include "ops/mul_op.h" +#include "ops/top_k_op.h" diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c7922e35e3c04d343a0bfb63ab31614836f190c --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h @@ -0,0 +1,51 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_NGRAPH +#pragma once + +#include +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +static void BuildTopKNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + int k = op_attrs.Get("k"); + auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map); + auto top_k = std::make_shared( + input, input->get_shape().size() - 1, ngraph::element::i64, k); + std::shared_ptr indices = + std::make_shared(top_k, 0); + std::shared_ptr out = + std::make_shared(top_k, 1); + auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map); + if (dummy_out && dummy_out->get_element_type() != out->get_element_type()) { + out = std::make_shared(out, + dummy_out->get_element_type()); + } + paddle::platform::SetOutputNode(op, "Indices", indices, ngb_node_map); + paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle +#endif diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0171087dce5d4c7b72eca7f7e4fb955af94812 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py @@ -0,0 +1,41 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import unittest +from paddle.fluid.tests.unittests.test_top_k_op import TestTopkOp, TestTopkOp3d, TestTopkOp2, TestTopkOp3, TestTopkOp4 + + +class TestNGRAPHTopkOp(TestTopkOp): + def setUp(self): + super(TestNGRAPHTopkOp, self).setUp() + + +class TestNGRAPHTopkOp2(TestTopkOp2): + def setUp(self): + super(TestNGRAPHTopkOp2, self).setUp() + + +class TestNGRAPHTopkOp3(TestTopkOp3): + def setUp(self): + super(TestNGRAPHTopkOp3, self).setUp() + + +class TestNGRAPHTopkOp4(TestTopkOp4): + def setUp(self): + super(TestNGRAPHTopkOp4, self).setUp() + + +if __name__ == "__main__": + unittest.main()