From 7bd1d03ee5b786bd361338046eeb86ddeea0c026 Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Tue, 7 May 2019 20:07:13 -0700 Subject: [PATCH] Adding lrn op for ngraph engine (#17189) * added lrn op test=develop * Added CreateConstant method test=develop * avoid duplicates test=develop --- paddle/fluid/operators/ngraph/ops/lrn_op.h | 54 +++++++++++++++++++ paddle/fluid/platform/ngraph_helper.h | 21 ++++++++ .../unittests/ngraph/test_lrn_ngraph_op.py | 29 ++++++++++ 3 files changed, 104 insertions(+) create mode 100644 paddle/fluid/operators/ngraph/ops/lrn_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_lrn_ngraph_op.py diff --git a/paddle/fluid/operators/ngraph/ops/lrn_op.h b/paddle/fluid/operators/ngraph/ops/lrn_op.h new file mode 100644 index 00000000000..68a0eea0892 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/lrn_op.h @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { +static void BuildLrnNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "X", ngb_node_map); + + auto op_attrs = framework::AttrReader(op->Attrs()); + const int n = op_attrs.Get("n"); + const float alpha = op_attrs.Get("alpha") * static_cast(n); + const float beta = op_attrs.Get("beta"); + const float k = op_attrs.Get("k"); + + auto lrn_out = std::make_shared(input, alpha, beta, k, n); + std::shared_ptr mid_out = paddle::platform::CreateConstant( + input->get_element_type(), input->get_shape(), {k}); + + platform::SetOutputNode(op, "MidOut", mid_out, ngb_node_map); + platform::SetOutputNode(op, "Out", lrn_out, ngb_node_map); +} + +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(lrn, BuildLrnNode); diff --git a/paddle/fluid/platform/ngraph_helper.h b/paddle/fluid/platform/ngraph_helper.h index e74f57a79a6..9e6521653b8 100644 --- a/paddle/fluid/platform/ngraph_helper.h +++ b/paddle/fluid/platform/ngraph_helper.h @@ -16,7 +16,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include "ngraph/ngraph.hpp" @@ -103,6 +105,25 @@ std::shared_ptr GetOutputNode( return GetNode(op, name, op->Outputs(), ngb_node_map); } +template +std::shared_ptr CreateConstant(const ngraph::element::Type& type, + ngraph::Shape shape, + std::initializer_list values) { + std::shared_ptr result; + if (values.size() == 1 && shape != ngraph::Shape{} && // NOLINT + shape != ngraph::Shape{1}) { + result = std::make_shared(type, ngraph::Shape{}, + std::vector{values}); + ngraph::AxisSet axis_set; + for (size_t i = 0; i < shape.size(); ++i) axis_set.insert(i); + result = std::make_shared(result, shape, axis_set); + } else { + result = std::make_shared(type, shape, + std::vector{values}); + } + return result; +} + void SetOutputNode( const std::shared_ptr& op, const std::string name, std::shared_ptr node, diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_lrn_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_lrn_ngraph_op.py new file mode 100644 index 00000000000..4c998c6ca2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_lrn_ngraph_op.py @@ -0,0 +1,29 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp + + +class TestLRNNGRAPHOp(TestLRNOp): + def test_check_output(self): + self.check_output(atol=0.002) + + +del TestLRNOp + +if __name__ == '__main__': + unittest.main() -- GitLab