From 2c58f1a83c9d119e541ecfbddacb1df94467e1bf Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Thu, 30 May 2019 20:08:58 -0700 Subject: [PATCH] [NGraph] Added lookup table to ngraph engine test=develop (#17647) --- cmake/external/ngraph.cmake | 2 +- .../operators/ngraph/ops/lookup_table_op.h | 103 ++++++++++++++++++ .../ngraph/test_lookup_table_ngraph_op.py | 21 ++++ 3 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/ngraph/ops/lookup_table_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_lookup_table_ngraph_op.py diff --git a/cmake/external/ngraph.cmake b/cmake/external/ngraph.cmake index 28f36973fd..cdcbdd46a8 100644 --- a/cmake/external/ngraph.cmake +++ b/cmake/external/ngraph.cmake @@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs) INCLUDE(ExternalProject) SET(NGRAPH_PROJECT "extern_ngraph") -SET(NGRAPH_GIT_TAG "096ad6ef0c04d57db1522940dbdf9a0652768065") +SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a") SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph) SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph) SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include) diff --git a/paddle/fluid/operators/ngraph/ops/lookup_table_op.h b/paddle/fluid/operators/ngraph/ops/lookup_table_op.h new file mode 100644 index 0000000000..5126854dc2 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/lookup_table_op.h @@ -0,0 +1,103 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "ngraph/ngraph.hpp" +#include "ngraph/op/embedding_lookup.hpp" +#include "paddle/fluid/operators/lookup_table_op.h" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildLookupTableNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + const bool is_sparse = op_attrs.Get("is_sparse"); + const int64_t padding_idx = op_attrs.Get("padding_idx"); + + auto ng_ids = paddle::platform::GetInputNode(op, "Ids", ngb_node_map); + PADDLE_ENFORCE_NOT_NULL(ng_ids); + + const auto ng_w = paddle::platform::GetInputNode(op, "W", ngb_node_map); + PADDLE_ENFORCE_NOT_NULL(ng_w); + + if (is_sparse) { + PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op."); + } + + if (padding_idx != kNoPadding) { + PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op."); + } + auto shape = ng_ids->get_shape(); + if (shape.back() == 1) { + shape.pop_back(); + ng_ids = platform::NgReshaper(ng_ids, shape); + } + auto ng_lookup = std::make_shared(ng_w, ng_ids); + platform::SetOutputNode(op, "Out", ng_lookup, ngb_node_map); +} + +void BuildLookupTableGradNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + const bool is_sparse = op_attrs.Get("is_sparse"); + const int64_t padding_idx = op_attrs.Get("padding_idx"); + + auto ng_ids = paddle::platform::GetInputNode(op, "Ids", ngb_node_map); + PADDLE_ENFORCE_NOT_NULL(ng_ids); + + const auto ng_w = paddle::platform::GetInputNode(op, "W", ngb_node_map); + PADDLE_ENFORCE_NOT_NULL(ng_w); + + auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + + if (is_sparse) { + PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op."); + } + + if (padding_idx != kNoPadding) { + PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op."); + } + auto shape = ng_ids->get_shape(); + if (shape.back() == 1) { + shape.pop_back(); + ng_ids = platform::NgReshaper(ng_ids, shape); + } + + std::shared_ptr W0 = paddle::platform::CreateConstant( + dout->get_element_type(), ng_w->get_shape(), {0}); + auto dW = std::make_shared(W0, ng_ids, dout); + platform::SetOutputNode(op, "W@GRAD", dW, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(lookup_table, BuildLookupTableNode); +REGISTER_NG_OP(lookup_table_grad, BuildLookupTableGradNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_lookup_table_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_lookup_table_ngraph_op.py new file mode 100644 index 0000000000..c9111c2210 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_lookup_table_ngraph_op.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest, sys +sys.path.append("../") +from test_lookup_table_op import * + +if __name__ == "__main__": + unittest.main() -- GitLab