From bacc8224920352467adc92c70a690ed204f35183 Mon Sep 17 00:00:00 2001 From: pawelpiotrowicz <48519735+pawelpiotrowicz@users.noreply.github.com> Date: Fri, 31 May 2019 04:04:29 +0200 Subject: [PATCH] [NGraph] Enable transpose ngraph operator (#17636) test=develop --- .../fluid/operators/ngraph/ops/transpose_op.h | 92 +++++++++++++++++++ .../ngraph/test_transpose_ngraph_op.py | 22 +++++ 2 files changed, 114 insertions(+) create mode 100644 paddle/fluid/operators/ngraph/ops/transpose_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_transpose_ngraph_op.py diff --git a/paddle/fluid/operators/ngraph/ops/transpose_op.h b/paddle/fluid/operators/ngraph/ops/transpose_op.h new file mode 100644 index 00000000000..2d4e49f79b1 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/transpose_op.h @@ -0,0 +1,92 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +template +static void BuildTransposeNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "X", ngb_node_map); + auto op_attrs = framework::AttrReader(op->Attrs()); + std::vector axis = op_attrs.Get>("axis"); + + auto input_shape = input->get_shape(); + ngraph::Shape x_reshape_shape; + ngraph::AxisVector axis_vec; + for (auto& v : axis) { + axis_vec.push_back(v); + x_reshape_shape.push_back(input_shape[v]); + } + std::shared_ptr x_transpose = + std::make_shared(input, axis_vec, input_shape); + x_transpose = platform::NgReshaper(x_transpose, x_reshape_shape); + platform::SetOutputNode(op, "Out", x_transpose, ngb_node_map); + if (is_v2) { + platform::SetOutputNode(op, "XShape", input, ngb_node_map); + } +} + +template +static void BuildTransposeGradNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + auto op_attrs = framework::AttrReader(op->Attrs()); + std::vector axis = op_attrs.Get>("axis"); + + ngraph::AxisVector axis_vec(axis.size()); + for (size_t i = 0; i < axis.size(); ++i) { + axis_vec[axis.at(i)] = i; + } + + ngraph::Shape out_shape; + if (is_v2) { + out_shape = platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape(); + } else { + out_shape = platform::GetInputNode(op, "X", ngb_node_map)->get_shape(); + } + + std::shared_ptr x_transpose = + std::make_shared(input, axis_vec, out_shape); + + platform::SetOutputNode(op, "X@GRAD", x_transpose, ngb_node_map); +} + +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(transpose, BuildTransposeNode); +REGISTER_NG_OP(transpose_grad, BuildTransposeGradNode); +REGISTER_NG_OP(transpose2, BuildTransposeNode); +REGISTER_NG_OP(transpose2_grad, BuildTransposeGradNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_transpose_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_transpose_ngraph_op.py new file mode 100644 index 00000000000..27bf82fc598 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_transpose_ngraph_op.py @@ -0,0 +1,22 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest, sys +sys.path.append("../") +from test_transpose_op import TestTransposeOp, TestCase0, TestCase1, TestCase2, TestCase3, TestCase4 + +if __name__ == '__main__': + unittest.main() -- GitLab