diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.cc b/paddle/fluid/operators/ngraph/ngraph_bridge.cc
index dafc31b546e3ca6d8dc8d5634dd51cff9fe5bfb7..4ff50935d6c78a01db222dcc8bcca3b22985d943 100644
--- a/paddle/fluid/operators/ngraph/ngraph_bridge.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_bridge.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <algorithm>
 #include <map>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 #include "ngraph/ngraph.hpp"
@@ -24,6 +25,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/ngraph_helper.h"
 
+constexpr int64_t kNoPadding = -1;
+
 namespace paddle {
 namespace operators {
 
@@ -31,6 +34,37 @@ bool NgraphBridge::isRegister(const std::string& str) {
   return ops::NgraphSingleton::Lookup(str);
 }
 
+bool NgraphBridge::isSupported(
+    const std::unique_ptr<framework::OperatorBase>& op) {
+  static std::unordered_set<std::string> skip_op_list{"reshape", "reshape2",
+                                                      "lookup_table"};
+  bool result = true;
+  auto& op_type = op->Type();
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  // Note: NgraphSingleton::Lookup() returns true when op_type has *no*
+  // registered nGraph builder, so !isRegister(op_type) means the op is
+  // convertible and only the attribute checks below may reject it.
+  if (!isRegister(op_type)) {
+    if (skip_op_list.count(op_type)) {
+      if (op_type == "lookup_table") {
+        if (op_attrs.Get<bool>("is_sparse") ||
+            (op_attrs.Get<int64_t>("padding_idx") != kNoPadding)) {
+          result = false;
+        }
+      } else if ((op_type == "reshape") || (op_type == "reshape2")) {
+        if (op->Input("Shape") != paddle::framework::kEmptyVarName) {
+          result = false;
+        }
+      } else {
+        result = false;
+      }
+    }
+  } else {
+    result = false;
+  }
+  return result;
+}
+
 void NgraphBridge::BuildNgNode(
     const std::shared_ptr<framework::OperatorBase>& op) {
   auto& op_type = op->Type();
diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.h b/paddle/fluid/operators/ngraph/ngraph_bridge.h
index b609c284959238689eaf35c87d1bc4e4330b5c1f..0b43ec53874d962699abef3cf843c5518d6f072d 100644
--- a/paddle/fluid/operators/ngraph/ngraph_bridge.h
+++ b/paddle/fluid/operators/ngraph/ngraph_bridge.h
@@ -39,6 +39,8 @@ class NgraphBridge {
 
   static bool isRegister(const std::string& str);
 
+  static bool isSupported(const std::unique_ptr<framework::OperatorBase>& op);
+
  private:
   std::shared_ptr<
       std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc
index 2486ae6bb5caf387cc44f89bf6f1a1cb5d20ad1d..e459bb9edc68636a6a4e915c2733ffe14fc19a02 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -134,12 +134,11 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
   int pivot = left;
   while (pivot < right) {
     auto op_type = ops->at(pivot)->Type();
-    if (NgraphBridge::isRegister(op_type)) {
+    if (!NgraphBridge::isSupported(ops->at(pivot))) {
       ++pivot;
     } else {
       int start = pivot, end = start;
-      while (pivot < right &&
-             (!NgraphBridge::isRegister(ops->at(pivot)->Type()))) {
+      while (pivot < right && (NgraphBridge::isSupported(ops->at(pivot)))) {
         ++pivot;
         ++end;
       }
diff --git a/paddle/fluid/operators/ngraph/ops/reshape_op.h b/paddle/fluid/operators/ngraph/ops/reshape_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..be3d38af492d1eae6ff64c531734e50965dea40c
--- /dev/null
+++ b/paddle/fluid/operators/ngraph/ops/reshape_op.h
@@ -0,0 +1,107 @@
+/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ngraph/ngraph.hpp"
+#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
+#include "paddle/fluid/platform/ngraph_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace ngraphs {
+
+ngraph::Shape calc_output_shape(const ngraph::Shape& input_shape,
+                                const std::vector<int>& v_shape) {
+  auto out_shape = v_shape;
+  for (size_t i = 0; i < v_shape.size(); ++i) {
+    if (v_shape[i] == 0) {
+      out_shape[i] = input_shape[i];
+    }
+  }
+  int size_input = ngraph::shape_size(input_shape);
+  int size_out = 1;
+  for (auto o : out_shape) {
+    if (o > 0) size_out *= o;
+  }
+  for (auto& o : out_shape) {
+    if (o == -1) o = size_input / size_out;
+  }
+  return ngraph::Shape(out_shape.begin(), out_shape.end());
+}
+
+template <bool is_v2>
+static void BuildReshapeNode(
+    const std::shared_ptr<framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  std::shared_ptr<ngraph::Node> input =
+      platform::GetInputNode(op, "X", ngb_node_map);
+  auto input_shape = input->get_shape();
+  // TODO(mozga-intel): A runtime "Shape" tensor input is not supported yet;
+  // the op declares it AsDispensable(), and isSupported() filters it out.
+  std::shared_ptr<ngraph::Node> shape =
+      platform::GetInputNode(op, "Shape", ngb_node_map);
+
+  auto op_attrs = framework::AttrReader(op->Attrs());
+  std::vector<int> v_shape = op_attrs.Get<std::vector<int>>("shape");
+  auto out = input;
+  if (shape != nullptr) {
+    ngraph::Shape new_shape;
+    for (auto& it : shape->get_shape()) {
+      new_shape.push_back(it);
+    }
+    out = platform::NgReshaper(input, new_shape);
+  } else {
+    auto out_shape = calc_output_shape(input_shape, v_shape);
+    out = platform::NgReshaper(input, out_shape);
+  }
+
+  if (is_v2) {
+    platform::SetOutputNode(op, "XShape", input, ngb_node_map);
+  }
+  platform::SetOutputNode(op, "Out", out, ngb_node_map);
+}
+
+template <bool is_v2>
+void BuildReshapeGradNode(
+    const std::shared_ptr<framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+  std::shared_ptr<ngraph::Node> input;
+  if (is_v2) {
+    input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map);
+  } else {
+    input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  }
+  auto dx = platform::NgReshaper(dout, input->get_shape());
+  paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
+}
+}  // namespace ngraphs
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_NG_OP(reshape, BuildReshapeNode<false>);
+REGISTER_NG_OP(reshape2, BuildReshapeNode<true>);
+REGISTER_NG_OP(reshape_grad, BuildReshapeGradNode<false>);
+REGISTER_NG_OP(reshape2_grad, BuildReshapeGradNode<true>);
diff --git a/paddle/fluid/platform/ngraph_helper.h b/paddle/fluid/platform/ngraph_helper.h
index 9e6521653b80abec1c5212f5deb84153335c2a9c..9c75f8dc6342ea8a5f3fa580e84610956d86555d 100644
--- a/paddle/fluid/platform/ngraph_helper.h
+++ b/paddle/fluid/platform/ngraph_helper.h
@@ -77,9 +77,7 @@ std::shared_ptr<ngraph::Node> GetNode(
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
   auto& var_names = var_map.at(name);
-  PADDLE_ENFORCE_EQ(var_names.size(), 1,
-                    "op %s name %s expects one associated var", op->Type(),
-                    name);
+  if (var_names.size() == 0) return nullptr;
   if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
     return (*ngb_node_map)[var_names[0]];
   } else {
@@ -189,6 +187,22 @@ inline void TrimTrailingSingularDims(ngraph::Shape* shape) {
     }
   }
 }
+
+ngraph::element::Type GetNgType(paddle::framework::proto::VarType::Type dtype) {
+  ngraph::element::Type ng_dtype;
+  if (dtype == paddle::framework::proto::VarType::FP32) {
+    ng_dtype = ngraph::element::f32;
+  } else if (dtype == paddle::framework::proto::VarType::FP64) {
+    ng_dtype = ngraph::element::f64;
+  } else if (dtype == paddle::framework::proto::VarType::INT64) {
+    ng_dtype = ngraph::element::i64;
+  } else if (dtype == paddle::framework::proto::VarType::INT32) {
+    ng_dtype = ngraph::element::i32;
+  } else {
+    PADDLE_THROW("unsupported data type: %s", dtype);
+  }
+  return ng_dtype;
+}
 
 }  // namespace platform
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..cffa28327143960274f38f8a7844031293b0995e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest, sys
+sys.path.append("../")
+
+from test_reshape_op import TestReshapeOp, TestReshapeOpDimInfer1, TestReshapeOpDimInfer2, TestReshapeOpWithInputShape
+
+if __name__ == '__main__':
+    unittest.main()
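
For reviewers, a minimal standalone sketch (not part of the patch) of the shape-inference rule that `calc_output_shape` in `reshape_op.h` implements: a `0` in the requested shape copies the corresponding input dimension, and a single `-1` entry is inferred from the remaining element count. The `InferReshape` name and the use of `std::vector<size_t>` in place of `ngraph::Shape` are illustrative assumptions so the snippet compiles without the nGraph headers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone re-statement of the patch's calc_output_shape rule:
// a 0 entry copies the input dim, a single -1 entry is inferred so
// that the element count is preserved.
std::vector<size_t> InferReshape(const std::vector<size_t>& input_shape,
                                 std::vector<int> v_shape) {
  for (size_t i = 0; i < v_shape.size(); ++i) {
    if (v_shape[i] == 0) v_shape[i] = static_cast<int>(input_shape[i]);
  }
  size_t size_input = 1;
  for (auto d : input_shape) size_input *= d;
  int size_out = 1;
  for (auto o : v_shape) {
    if (o > 0) size_out *= o;
  }
  for (auto& o : v_shape) {
    if (o == -1) o = static_cast<int>(size_input) / size_out;
  }
  return std::vector<size_t>(v_shape.begin(), v_shape.end());
}

int main() {
  // A [2, 4, 6] tensor reshaped with {0, -1}: dim 0 is copied (2),
  // dim 1 is inferred as 48 / 2 = 24.
  assert((InferReshape({2, 4, 6}, {0, -1}) == std::vector<size_t>{2, 24}));
  return 0;
}
```

Note that, like the patched code, this rule assumes at most one `-1` entry; `isSupported()` never has to check that because the framework-side reshape `InferShape` already enforces it before the nGraph bridge runs.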