diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index d21cb2697b3368540ea8ca721a7f9e21e04f3d48..aad64406f7f60e1582a7c4c8c12ad97a127188ae 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -164,6 +164,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
 paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
+paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 8da0aaaafeb151e8f1900bc66f06e771c857fc00..e73d31562a93af4e9bdb3d5806e9182d0eea8167 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
 op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
+op_library(unstack_op DEPS stack_op)
 
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4ff3249cc333231a0624cd5aab9603a6a75f4480
--- /dev/null
+++ b/paddle/fluid/operators/unstack_op.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/unstack_op.h"
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+USE_OP(stack);
+
+REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
+                  ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker);
+
+REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp,
+                  ops::UnStackOpGradInferShape);
diff --git a/paddle/fluid/operators/unstack_op.h b/paddle/fluid/operators/unstack_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..348a1038804ccb2551e5f729cc1a38bcef1511f5
--- /dev/null
+++ b/paddle/fluid/operators/unstack_op.h
@@ -0,0 +1,135 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class UnStackOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist.");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    int num = ctx->Attrs().Get<int>("num");
+    auto x_dim = ctx->GetInputDim("X");
+    int rank = x_dim.size();
+    PADDLE_ENFORCE(axis >= -rank && axis < rank,
+                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
+                   rank);
+    if (axis < 0) axis += rank;
+
+    PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
+                      "Number of Outputs(Y) is wrong");
+    if (x_dim[axis] > 0) {
+      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
+    }
+    auto vec = framework::vectorize2int(x_dim);
+    vec.erase(vec.begin() + axis);
+    ctx->SetOutputsDim("Y", std::vector<framework::DDim>(  // NOLINT
+                                x_dim[axis], framework::make_ddim(vec)));
+  }
+};
+
+class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of unstack op.");
+    AddOutput("Y", "The output of unstack op.").AsDuplicable();
+    AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
+        .SetDefault(0);
+    AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
+    AddComment(R"DOC(
+      UnStack Operator.
+
+      UnStack Input(X) into several tensors along Attr(axis).
+ )DOC"); + } +}; + +class UnStackOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_grad_op = framework::OpRegistry::CreateOp( + "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}}, + {{framework::GradVarName("X"), Outputs("Y")}}, Attrs()); + stack_grad_op->Run(scope, place); + } +}; + +class UnStackOpGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0, + "Number of Inputs(Y@Grad) must be larger than 0"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) must exist."); + + auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y")); + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], + "Dims of all Inputs(Y@Grad) must be the same"); + } + + int axis = ctx->Attrs().Get("axis"); + int rank = input_dims[0].size(); + PADDLE_ENFORCE( + axis >= -(rank + 1) && axis < rank + 1, + "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank); + if (axis < 0) axis += (rank + 1); + + auto vec = framework::vectorize2int(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec)); + } +}; + +class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("unstack_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +class UnStackGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto stack_op = framework::OpRegistry::CreateOp( + "stack", {{"X", Inputs(framework::GradVarName("Y"))}}, + {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs()); + stack_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ca10d73b0899d257d997ac435fef5481c50f8c59..9a6aa14c4ffb8865344140381a410036749fa959 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -105,6 +105,7 @@ __all__ = [ 'flatten', 'sequence_mask', 'stack', + 'unstack', ] @@ -5601,3 +5602,44 @@ def stack(x, axis=0): type='stack', inputs={'X': x}, outputs={'Y': out}, attrs={'axis': axis}) return out + + +def unstack(x, axis=0, num=None): + """ + **UnStack Layer** + + This layer unstacks input :code:`x` into several tensors along axis. + + If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`. + If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`, + and if :code:`x.shape[axis]` <= 0 or is unknown, :code:`ValueError` is + raised. + + Args: + x (Variable): Input variable. + axis (int): The axis along which the input is unstacked. + num (int|None): The number of output variables. + + Returns: + list(Variable): The unstacked variables. 
+ + """ + + helper = LayerHelper('unstack', **locals()) + if num is None: + if axis is None or x.shape[axis] <= 0: + raise ValueError('unknown unstack number') + else: + num = x.shape[axis] + + outs = [] + for _ in num: + outs.append(helper.create_tmp_variable(x.dtype)) + + helper.append_op( + type='unstack', + inputs={'X': [x]}, + outputs={'Y': outs}, + attrs={'axis': axis, + 'num': num}) + return outs diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7cbac8928ec40dc3e1c0e91e7779ec9ec978d884 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_test import OpTest +import numpy as np +import unittest + + +class TestUnStackOpBase(OpTest): + def initDefaultParameters(self): + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'float32' + + def initParameters(self): + pass + + def get_y_names(self): + y_names = [] + for i in range(self.input_dim[self.axis]): + y_names.append('y{}'.format(i)) + return y_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'unstack' + self.x = np.random.random(size=self.input_dim).astype(self.dtype) + + outs = np.split(self.x, self.input_dim[self.axis], self.axis) + new_shape = list(self.input_dim) + del new_shape[self.axis] + y_names = self.get_y_names() + tmp = [] + for i in range(self.input_dim[self.axis]): + tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + + self.inputs = {'X': self.x} + self.outputs = {'Y': tmp} + self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad('X', self.get_y_names()) + + +class TestStackOp3(TestUnStackOpBase): + def initParameters(self): + self.axis = -1 + + +class TestStackOp4(TestUnStackOpBase): + def initParameters(self): + self.axis = -3 + + +class TestStackOp5(TestUnStackOpBase): + def initParameters(self): + self.axis = 1 + + +class TestStackOp6(TestUnStackOpBase): + def initParameters(self): + self.axis = 2 + + +if __name__ == '__main__': + unittest.main()