diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bbf1623c39c8010f7d48fc8d7f653c34e92fb99a..8c5cc44528a754f7612a23b1de09c247ca3f0c8e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/framework/array.h b/paddle/fluid/framework/array.h new file mode 100644 index 0000000000000000000000000000000000000000..be9efcd74924a2050a2fd9ab83059590a1a2a2fd --- /dev/null +++ b/paddle/fluid/framework/array.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace framework { +template +class Array { + static_assert(N > 0, "The size of array must be larger than 0"); + + public: + HOSTDEVICE Array() {} + + HOSTDEVICE explicit Array(const T &val) { + for (size_t i = 0; i < N; ++i) data_[i] = val; + } + + HOSTDEVICE const T *Get() const { return data_; } + + HOSTDEVICE T *GetMutable() { return data_; } + + HOSTDEVICE T &operator[](size_t index) { return data_[index]; } + + HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; } + + HOSTDEVICE constexpr size_t size() const { return N; } + + private: + T data_[N]; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f4b48bc7391def082c82ed451fc5a752009a2f1 --- /dev/null +++ b/paddle/fluid/operators/stack_op.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/stack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; +REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker, + ops::StackGradOpDescMaker); +REGISTER_OPERATOR(stack_grad, ops::StackOpGrad); + +REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel, + ops::StackKernel); + +REGISTER_OP_CPU_KERNEL(stack_grad, + ops::StackGradKernel, + ops::StackGradKernel); diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..92c1bde2bcf089e5c715e90e564408e6ad37ba17 --- /dev/null +++ b/paddle/fluid/operators/stack_op.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/stack_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel, + ops::StackKernel); + +REGISTER_OP_CUDA_KERNEL(stack_grad, + ops::StackGradKernel, + ops::StackGradKernel); diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c777d5feaec1c3a6216b01359a250072a674b700 --- /dev/null +++ b/paddle/fluid/operators/stack_op.h @@ -0,0 +1,278 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +#ifdef __NVCC__ +#include +#include "paddle/fluid/framework/array.h" +#endif + +namespace paddle { +namespace operators { + +class StackOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0, + "Number of Inputs(X) must be larger than 0"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist."); + + auto input_dims = ctx->GetInputsDim("X"); + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], + "Dims of all Inputs(X) must be the same"); + } + + // Only lod of X[0] would be shared with Y + ctx->ShareLoD("X", /*->*/ "Y"); + + int axis = ctx->Attrs().Get("axis"); + int rank = input_dims[0].size(); + PADDLE_ENFORCE( + axis >= -(rank + 1) && axis < rank + 1, + "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank); + if (axis < 0) axis += (rank + 1); + + auto vec = framework::vectorize2int(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + ctx->SetOutputDim("Y", framework::make_ddim(vec)); + } +}; + +class StackOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of stack op.").AsDuplicable(); + AddOutput("Y", "The output of stack op."); + AddAttr("axis", + "The axis along which all of the Inputs(X) should be stacked.") + .SetDefault(0); + AddComment(R"DOC( + Stack Operator. + + Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same. + )DOC"); + } +}; + +template +struct StackFunctor { + HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post) + : x_(x), y_(y), n_(n), post_(post) {} + + HOSTDEVICE void operator()(int idx) { + int i = idx / (n_ * post_); + int which_x = idx / post_ - i * n_; + int x_index = i * post_ + idx % post_; + y_[idx] = x_[which_x][x_index]; + } + + private: + VecXType x_; + T *y_; + int n_; + int post_; +}; + +template +struct StackGradFunctor { + HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post) + : dx_(dx), dy_(dy), n_(n), post_(post) {} + + HOSTDEVICE void operator()(int idx) { + int i = idx / (n_ * post_); + int which_x = idx / post_ - i * n_; + int x_index = i * post_ + idx % post_; + dx_[which_x][x_index] = dy_[idx]; + } + + private: + VecDxType dx_; + const T *dy_; + int n_; + int post_; +}; + +template +static inline void StackFunctorForRange(const DeviceContext &ctx, + const VecXType &x, T *y, int total_num, + int n, int post) { + platform::ForRange for_range(ctx, total_num); + for_range(StackFunctor(x, y, n, post)); +} + +template +static inline void StackGradFunctorForRange(const DeviceContext &ctx, + const VecDxType &dx, const T *dy, + int total_num, int n, int post) { + platform::ForRange for_range(ctx, total_num); + for_range(StackGradFunctor(dx, dy, n, post)); +} + +template +class StackKernel : public framework::OpKernel { + using Tensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto x = ctx.MultiInput("X"); + auto *y = ctx.Output("Y"); + + int axis = ctx.Attr("axis"); + if (axis < 0) axis += (x[0]->dims().size() + 1); + + int n = static_cast(x.size()); + auto *y_data = y->mutable_data(ctx.GetPlace()); + std::vector x_datas(n); + for (int i = 0; i < n; i++) x_datas[i] = x[i]->data(); + + int pre = 1, post = 1; + auto &dim = x[0]->dims(); + for (auto i = 0; i < axis; ++i) pre *= dim[i]; + for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; + int total_num = pre * n * post; + + auto &dev_ctx = ctx.template device_context(); + constexpr auto kMaxThreshold = 16; + if (std::is_same::value || + n > kMaxThreshold) { +#ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; + thrust::device_vector device_x_vec(x_datas); + auto x_data_arr = device_x_vec.data().get(); +#else + auto x_data_arr = x_datas.data(); +#endif + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); +#ifdef __NVCC__ + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif + } +#ifdef __NVCC__ + else { // NOLINT + framework::Array x_data_arr; + for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i]; + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + } +#endif + } +}; + +class StackOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@Grad) must exist."); + + int axis = ctx->Attrs().Get("axis"); + auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = dy_dim.size(); + PADDLE_ENFORCE(axis >= -rank && axis < rank, + "Attr(axis) must be inside [-rank, rank), where rank = %d", + rank); + if (axis < 0) axis += rank; + + PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(), + static_cast(dy_dim[axis]), + "Number of Outputs(X@Grad) is wrong"); + auto vec = framework::vectorize2int(dy_dim); + vec.erase(vec.begin() + axis); + ctx->SetOutputsDim( + framework::GradVarName("X"), + std::vector(dy_dim[axis], framework::make_ddim(vec))); + } +}; + +class StackGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("stack_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X", false)); + op->SetAttrMap(Attrs()); + return op; + } +}; + +template +class StackGradKernel : public framework::OpKernel { + using Tensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *dy = ctx.Input(framework::GradVarName("Y")); + auto dx = ctx.MultiOutput(framework::GradVarName("X")); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += dy->dims().size(); + + int n = dy->dims()[axis]; + std::vector dx_datas(n); // NOLINT + for (int i = 0; i < n; i++) { + dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + } + auto dy_data = dy->data(); + + int pre = 1; + for (int i = 0; i < axis; ++i) pre *= dy->dims()[i]; + int total_num = dy->numel(); + int post = total_num / (n * pre); + + auto &dev_ctx = ctx.template device_context(); + constexpr auto kMaxThreshold = 16; + if (std::is_same::value || + n > kMaxThreshold) { +#ifdef __NVCC__ + VLOG(10) << "Stack more than " << kMaxThreshold + << " tensors on GPU may be slow."; + thrust::device_vector device_dx_vec(dx_datas); + auto dx_data_arr = device_dx_vec.data().get(); +#else + auto dx_data_arr = dx_datas.data(); +#endif + StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, + post); +#ifdef __NVCC__ + // Wait() must be called because device_dx_vec may be destructed before + // kernel ends + dev_ctx.Wait(); +#endif + } +#ifdef __NVCC__ + else { // NOLINT + framework::Array dx_data_arr; + for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i]; + StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, + post); + } +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a815ba0f2f4a946f37da6baaafcd56fbb880adda..83250f65e4fadf1799f6473d03e087a3eb76fa69 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -104,6 +104,7 @@ __all__ = [ 'rank_loss', 'prelu', 'flatten', + 'stack', ] @@ -5522,3 +5523,17 @@ def flatten(x, axis=1, name=None): outputs={'Out': out}, attrs={"axis": axis}) return out + + +def stack(x, axis=0): + helper = LayerHelper('stack', **locals()) + axis = 0 if axis is None else axis + + if not isinstance(x, list) and not isinstance(x, tuple): + x = [x] + + out = helper.create_tmp_variable(x[0].dtype) + helper.append_op( + type='stack', inputs={'X': x}, outputs={'Y': out}, + attrs={'axis': axis}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py new file mode 100644 index 0000000000000000000000000000000000000000..defdeb5d70df4c39ed8e23247270e6eb3dd14a7a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -0,0 +1,92 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_test import OpTest +import numpy as np +import unittest + + +class TestStackOpBase(OpTest): + def initDefaultParameters(self): + self.num_inputs = 4 + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'float32' + + def initParameters(self): + pass + + def get_x_names(self): + x_names = [] + for i in range(self.num_inputs): + x_names.append('x{}'.format(i)) + return x_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'stack' + self.x = [] + for i in range(self.num_inputs): + self.x.append( + np.random.random(size=self.input_dim).astype(self.dtype)) + + tmp = [] + x_names = self.get_x_names() + for i in range(self.num_inputs): + tmp.append((x_names[i], self.x[i])) + + self.inputs = {'X': tmp} + self.outputs = {'Y': np.stack(self.x, axis=self.axis)} + self.attrs = {'axis': self.axis} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(self.get_x_names(), 'Y') + + +class TestStackOp1(TestStackOpBase): + def initParameters(self): + self.num_inputs = 16 + + +class TestStackOp2(TestStackOpBase): + def initParameters(self): + self.num_inputs = 20 + + +class TestStackOp3(TestStackOpBase): + def initParameters(self): + self.axis = -1 + + +class TestStackOp4(TestStackOpBase): + def initParameters(self): + self.axis = -4 + + +class TestStackOp5(TestStackOpBase): + def initParameters(self): + self.axis = 1 + + +class TestStackOp6(TestStackOpBase): + def initParameters(self): + self.axis = 3 + + +if __name__ == '__main__': + unittest.main()