From ecfa68ecaa3f8f620aacfef3a4759566b954d0ae Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Wed, 19 Dec 2018 02:50:15 +0100 Subject: [PATCH] Enable fill_constant operator for a ngraph test=develop --- paddle/fluid/framework/ngraph_bridge.cc | 1 + paddle/fluid/operators/ngraph/ngraph_ops.h | 1 + .../operators/ngraph/ops/binary_unnary_op.h | 1 - .../operators/ngraph/ops/fill_constant_op.h | 61 +++++++++++++++++++ .../ngraph/test_fill_constant_ngraph_op.py | 37 +++++++++++ 5 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/ngraph/ops/fill_constant_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index 5fcb17b9f..ca29e7c09 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -31,6 +31,7 @@ std::map>>)>> NgraphBridge::NG_NODE_MAP = { + {"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode}, {"mul", paddle::operators::ngraphs::BuildMulNode}, {"mul_grad", paddle::operators::ngraphs::BuildMulGradNode}, {"relu", paddle::operators::ngraphs::BuildUnaryNode}, diff --git a/paddle/fluid/operators/ngraph/ngraph_ops.h b/paddle/fluid/operators/ngraph/ngraph_ops.h index 0ed77ff55..956932d94 100644 --- a/paddle/fluid/operators/ngraph/ngraph_ops.h +++ b/paddle/fluid/operators/ngraph/ngraph_ops.h @@ -22,4 +22,5 @@ limitations under the License. */ #pragma once #include "ops/binary_unnary_op.h" +#include "ops/fill_constant_op.h" #include "ops/mul_op.h" diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h index 4e2f5e231..6610380fc 100644 --- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h +++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h @@ -45,7 +45,6 @@ static void BuildUnaryNode( auto out = std::make_shared(input); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); } - } // namespace ngraphs } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h new file mode 100644 index 000000000..bd818c893 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h @@ -0,0 +1,61 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_NGRAPH +#pragma once + +#include +#include +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +static void BuildFillConstantNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + auto vsp = op_attrs.Get>("shape"); + ngraph::Shape shape; + for (auto& sp : vsp) { + shape.push_back(sp); + } + float value = op_attrs.Get("value"); + ngraph::element::Type ng_dtype; + auto data_type = static_cast( + op_attrs.Get("dtype")); + if (data_type == paddle::framework::proto::VarType::FP32) { + ng_dtype = ngraph::element::f32; + } else if (data_type == paddle::framework::proto::VarType::FP64) { + ng_dtype = ngraph::element::f64; + } else if (data_type == paddle::framework::proto::VarType::INT64) { + ng_dtype = ngraph::element::i64; + } else if (data_type == paddle::framework::proto::VarType::INT32) { + ng_dtype = ngraph::element::i32; + } else if (data_type == paddle::framework::proto::VarType::BOOL) { + ng_dtype = ngraph::element::boolean; + } else { + PADDLE_THROW("unsupported data type: %s", data_type); + } + auto out = ngraph::op::Constant::create(ng_dtype, shape, {value}); + paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle +#endif diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py new file mode 100644 index 000000000..835376ffe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py @@ -0,0 +1,37 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from paddle.fluid.tests.unittests.test_fill_constant_op import TestFillConstantOp1, TestFillConstantOp2, TestFillConstantOpWithSelectedRows + + +class TestNGRAPHFillConstantOp1(TestFillConstantOp1): + def setUp(self): + super(TestNGRAPHFillConstantOp1, self).setUp() + + +class TestNGRAPHFillConstantOp2(TestFillConstantOp2): + def setUp(self): + super(TestNGRAPHFillConstantOp2, self).setUp() + + +class TestNGRAPHFillConstantOpWithSelectedRows( + TestFillConstantOpWithSelectedRows): + def setUp(self): + super(TestFillConstantOpWithSelectedRows, self).setUp() + + +if __name__ == "__main__": + unittest.main() -- GitLab