diff --git a/paddle/fluid/operators/ngraph/ops/slice_op.h b/paddle/fluid/operators/ngraph/ops/slice_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1ae4d198c23b7f92ebb571c6ef576a8c2a7e0feb --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/slice_op.h @@ -0,0 +1,111 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildSliceNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = paddle::platform::GetInputNode(op, "Input", ngb_node_map); + auto input_shape = input->get_shape(); + auto op_attrs = framework::AttrReader(op->Attrs()); + auto axes = op_attrs.Get>("axes"); + auto starts = op_attrs.Get>("starts"); + auto ends = op_attrs.Get>("ends"); + ngraph::Coordinate ng_start, ng_end; + int axis, start, end; + for (size_t i = 0; i < input_shape.size(); ++i) { + ng_start.push_back(0); + ng_end.push_back(input_shape[i]); + } + for (size_t i = 0; i < axes.size(); ++i) { + axis = input_shape[axes[i]]; + start = starts[i] < 0 ? (starts[i] + axis) : starts[i]; + end = ends[i] < 0 ? (ends[i] + axis) : ends[i]; + start = std::max(start, 0); + end = std::max(end, 0); + start = std::min(start, axis); + end = std::min(end, axis); + start = std::min(start, end); + ng_start[axes[i]] = start; + ng_end[axes[i]] = end; + } + auto out = std::make_shared(input, ng_start, ng_end); + platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +void BuildSliceGradNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = paddle::platform::GetInputNode(op, "Input", ngb_node_map); + auto input_shape = input->get_shape(); + auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + auto op_attrs = framework::AttrReader(op->Attrs()); + auto axes = op_attrs.Get>("axes"); + auto starts = op_attrs.Get>("starts"); + auto ends = op_attrs.Get>("ends"); + auto reshape = input_shape; + ngraph::Coordinate ng_start, ng_end; + int axis, start, end; + for (size_t i = 0; i < input_shape.size(); ++i) { + ng_start.push_back(0); + ng_end.push_back(input_shape[i]); + } + for (size_t i = 0; i < axes.size(); ++i) { + axis = input_shape[axes[i]]; + start = starts[i] < 0 ? (starts[i] + axis) : starts[i]; + end = ends[i] < 0 ? (ends[i] + axis) : ends[i]; + start = std::max(start, 0); + end = std::max(end, 0); + start = std::min(start, axis); + end = std::min(end, axis); + start = std::min(start, end); + ng_start[axes[i]] = start; + ng_end[axes[i]] = end; + reshape[axes[i]] = end - start; + } + std::vector axisVec(dout->get_shape().size()); + std::iota(axisVec.begin(), axisVec.end(), 0); + auto dout_reshape = std::make_shared( + dout, ngraph::AxisVector(axisVec), reshape); + + std::shared_ptr input0 = paddle::platform::CreateConstant( + dout->get_element_type(), input_shape, {0}); + + auto din = std::make_shared(input0, dout_reshape, + ng_start, ng_end); + platform::SetOutputNode(op, "Input@GRAD", din, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(slice, BuildSliceNode); +REGISTER_NG_OP(slice_grad, BuildSliceGradNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_slice_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_slice_ngraph_op.py new file mode 100644 index 0000000000000000000000000000000000000000..dc41e8a98a797bb3ad8bf503694bcee12fdf840c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_slice_ngraph_op.py @@ -0,0 +1,22 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest, sys +sys.path.append("../") +from test_slice_op import TestSliceOp, TestCase1, TestCase2 + +if __name__ == '__main__': + unittest.main()