From 6ee6700facb4419771ed1a6fda6ac45ef83b4ef6 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Wed, 15 May 2019 02:23:14 -0700 Subject: [PATCH] Eanble stack operator for a Ngraph, test=develop (#17406) --- paddle/fluid/operators/ngraph/ops/stack_op.h | 56 +++++++++++++++++++ .../unittests/ngraph/test_stack_ngraph_op.py | 22 ++++++++ 2 files changed, 78 insertions(+) create mode 100644 paddle/fluid/operators/ngraph/ops/stack_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_stack_ngraph_op.py diff --git a/paddle/fluid/operators/ngraph/ops/stack_op.h b/paddle/fluid/operators/ngraph/ops/stack_op.h new file mode 100644 index 000000000..d0e9545fd --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/stack_op.h @@ -0,0 +1,56 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildStackNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = framework::AttrReader(op->Attrs()); + auto axis = op_attrs.Get("axis"); + std::vector> args; + for (auto& var_name_item : op->Inputs()) { + for (auto& var_name : var_name_item.second) { + auto& node = ngb_node_map->at(var_name); + auto shape = node->get_shape(); + axis = (axis < 0) ? axis + shape.size() + 1 : axis; + shape.insert(shape.begin() + axis, 1); + std::vector input_order(shape.size() - 1); + std::iota(std::begin(input_order), std::end(input_order), 0); + args.push_back(std::make_shared( + node, ngraph::AxisVector(input_order), shape)); + } + } + auto out = std::make_shared(args, axis); + platform::SetOutputNode(op, "Y", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(stack, BuildStackNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_stack_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_stack_ngraph_op.py new file mode 100644 index 000000000..23ef26133 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_stack_ngraph_op.py @@ -0,0 +1,22 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest, sys +sys.path.append("../") +from test_stack_op import TestStackOpBase, TestStackOp1, TestStackOp2, TestStackOp3, TestStackOp4, TestStackOp5, TestStackOp6 + +if __name__ == '__main__': + unittest.main() -- GitLab