diff --git a/paddle/fluid/framework/array.h b/paddle/fluid/framework/array.h
new file mode 100644
index 0000000000000000000000000000000000000000..be9efcd74924a2050a2fd9ab83059590a1a2a2fd
--- /dev/null
+++ b/paddle/fluid/framework/array.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstddef>
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace framework {
+// Fixed-size array holding exactly N (> 0) elements of type T in a plain
+// member array. Every member function is marked HOSTDEVICE; the object holds
+// no pointers of its own, so copying it copies all N elements.
+template <typename T, size_t N>
+class Array {
+  static_assert(N > 0, "The size of array must be larger than 0");
+
+ public:
+  // Leaves the elements default-initialized (indeterminate for scalar T).
+  HOSTDEVICE Array() {}
+
+  // Fills every element with a copy of `val`.
+  HOSTDEVICE explicit Array(const T &val) {
+    for (size_t i = 0; i < N; ++i) data_[i] = val;
+  }
+
+  // Pointer to the first element: read-only here, mutable in GetMutable().
+  HOSTDEVICE const T *Get() const { return data_; }
+
+  HOSTDEVICE T *GetMutable() { return data_; }
+
+  // Unchecked element access; the caller must keep `index` below N.
+  HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
+
+  HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
+
+  // Number of elements, i.e. the template argument N.
+  HOSTDEVICE constexpr size_t size() const { return N; }
+
+ private:
+  T data_[N];
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..34e7351073e0f72a50ba8753f8d7c9c4ba59056a
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/stack_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Copies x into y viewed as [pre, n, post]: y[i][j][k] = x[j][i * post + k].
+struct CPUStackFunctor {
+  template <typename DeviceContext, typename T>
+  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
+                  T* y, int pre, int n, int post) const {
+    for (int i = 0; i < pre; ++i) {
+      for (int j = 0; j < n; ++j) {
+        const T* src = x[j] + i * post;
+        for (int k = 0; k < post; ++k) *y++ = src[k];
+      }
+    }
+  }
+};
+
+// Inverse of CPUStackFunctor: dx[j][i * post + k] = dy[i][j][k].
+struct CPUStackGradFunctor {
+  template <typename DeviceContext, typename T>
+  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
+                  const T* dy, int pre, int n, int post) const {
+    for (int i = 0; i < pre; ++i) {
+      for (int j = 0; j < n; ++j) {
+        T* dst = dx[j] + i * post;
+        for (int k = 0; k < post; ++k) dst[k] = *dy++;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
+                  ops::StackGradOpDescMaker);
+REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    stack,
+    ops::StackKernel<plat::CPUDeviceContext, float, ops::CPUStackFunctor>,
+    ops::StackKernel<plat::CPUDeviceContext, double, ops::CPUStackFunctor>);
+
+REGISTER_OP_CPU_KERNEL(stack_grad,
+                       ops::StackGradKernel<plat::CPUDeviceContext, float,
+                                            ops::CPUStackGradFunctor>,
+                       ops::StackGradKernel<plat::CPUDeviceContext, double,
+                                            ops::CPUStackGradFunctor>);
diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu
new file mode 100644
index
0000000000000000000000000000000000000000..8f13018b3b5ad10e2639e9d93827f0bf1aa32cf9
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.cu
@@ -0,0 +1,120 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <thrust/device_vector.h>
+#include "paddle/fluid/framework/array.h"
+#include "paddle/fluid/operators/stack_op.h"
+
+namespace paddle {
+namespace operators {
+
+// One thread per element of y, viewed as [pre, n, post]:
+//   y[i][j][k] = x[j][i * post + k].
+// VecXType is anything indexable as x[j][offset] in device code; the callers
+// below pass a framework::Array by value or a raw device pointer.
+template <typename T, typename VecXType>
+__global__ void StackCUDAKernel(VecXType x, T* y, int total_num, int n,
+                                int post) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < total_num) {
+    int i = idx / (n * post);
+    int which_x = idx / post - i * n;
+    int x_index = i * post + idx % post;
+    y[idx] = x[which_x][x_index];
+  }
+}
+
+// Inverse scatter of StackCUDAKernel: dx[j][i * post + k] = dy[i][j][k].
+template <typename T, typename VecDxType>
+__global__ void StackGradCUDAKernel(VecDxType dx, const T* dy, int total_num,
+                                    int n, int post) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < total_num) {
+    int i = idx / (n * post);
+    int which_x = idx / post - i * n;
+    int x_index = i * post + idx % post;
+    dx[which_x][x_index] = dy[idx];
+  }
+}
+
+// Host-side launcher: one thread per output element, 512 threads per block.
+struct GPUStackFunctor {
+  template <typename DeviceContext, typename T>
+  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
+                  T* y, int pre, int n, int post) const {
+    int total_num = pre * post * n;
+    int threads = 512;
+    int grid = (total_num + threads - 1) / threads;
+
+    // Up to kMaxThreshold pointers travel inside the kernel arguments; larger
+    // pointer tables are first copied to device memory through thrust.
+    constexpr auto kMaxThreshold = 16;
+    if (n <= kMaxThreshold) {
+      framework::Array<const T*, kMaxThreshold> arr;
+      for (int i = 0; i < n; ++i) arr[i] = x[i];
+      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(arr, y, total_num, n,
+                                                          post);
+    } else {
+      VLOG(10) << "Stack more than " << kMaxThreshold
+               << " tensors may be slow on GPU.";
+      // NOTE(review): dev_x is released at the end of this scope although the
+      // launch is asynchronous; confirm the free synchronizes with the stream.
+      thrust::device_vector<const T*> dev_x(x);
+      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(dev_x.data().get(), y,
+                                                          total_num, n, post);
+    }
+  }
+};
+
+// Launches StackGradCUDAKernel; same pointer-table strategy as GPUStackFunctor.
+struct GPUStackGradFunctor {
+  template <typename DeviceContext, typename T>
+  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
+                  const T* dy, int pre, int n, int post) const {
+    int total_num = pre * post * n;
+    int threads = 512;
+    int grid = (total_num + threads - 1) / threads;
+
+    constexpr auto kMaxThreshold = 16;
+    if (n <= kMaxThreshold) {
+      framework::Array<T*, kMaxThreshold> arr;
+      for (int i = 0; i < n; ++i) arr[i] = dx[i];
+      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
+          arr, dy, total_num, n, post);
+    } else {
+      VLOG(10) << "Stack more than " << kMaxThreshold
+               << " tensors may be slow on GPU.";
+      thrust::device_vector<T*> dev_dx(dx);
+      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
+          dev_dx.data().get(), dy, total_num, n, post);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    stack,
+    ops::StackKernel<plat::CUDADeviceContext, float, ops::GPUStackFunctor>,
+    ops::StackKernel<plat::CUDADeviceContext, double, ops::GPUStackFunctor>);
+
+REGISTER_OP_CUDA_KERNEL(stack_grad,
+                        ops::StackGradKernel<plat::CUDADeviceContext, float,
+                                             ops::GPUStackGradFunctor>,
+                        ops::StackGradKernel<plat::CUDADeviceContext, double,
+                                             ops::GPUStackGradFunctor>);
diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..c438bdb63ae88eee0df61c5c34c989a2c52532cf
--- /dev/null
+++ b/paddle/fluid/operators/stack_op.h
@@ -0,0 +1,192 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// Splits `dim` at `axis`: *pre is the product of dims [0, axis) and *post is
+// the product of dims [axis, rank). An empty range yields 1.
+inline void GetPrePostForStackOp(const framework::DDim &dim, int axis, int *pre,
+                                 int *post) {
+  *pre = 1;
+  for (auto i = 0; i < axis; ++i) (*pre) *= dim[i];
+  *post = 1;
+  for (auto i = axis; i < dim.size(); ++i) (*post) *= dim[i];
+}
+
+// Forward op: stacks all Inputs(X) into Output(Y) along a new axis.
+class StackOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Requires all X dims to match; Y's dims are X[0]'s dims with the number of
+  // inputs inserted at `axis` (a negative axis is offset by rank + 1).
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
+                      "Number of Inputs(X) must be larger than 0");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
+
+    auto input_dims = ctx->GetInputsDim("X");
+    for (size_t i = 1; i < input_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
+                        "Dims of all Inputs(X) must be the same");
+    }
+
+    // Only lod of X[0] would be shared with Y
+    ctx->ShareLoD("X", /*->*/ "Y");
+
+    // The new axis may be placed at any of the rank + 1 positions of Y.
+    int axis = ctx->Attrs().Get<int>("axis");
+    int rank = input_dims[0].size();
+    PADDLE_ENFORCE(
+        axis >= -(rank + 1) && axis < rank + 1,
+        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    if (axis < 0) axis += (rank + 1);
+
+    auto vec = framework::vectorize2int(input_dims[0]);
+    vec.insert(vec.begin() + axis, input_dims.size());
+    ctx->SetOutputDim("Y", framework::make_ddim(vec));
+  }
+};
+
+// Declares the inputs, output, attribute and documentation of the stack op.
+class StackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of stack op.").AsDuplicable();
+    AddOutput("Y", "The output of stack op.");
+    AddAttr<int>("axis",
+                 "The axis along which all of the Inputs(X) should be stacked.")
+        .SetDefault(0);
+    AddComment(R"DOC(
+      Stack Operator.
+
+      Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
+    )DOC");
+  }
+};
+
+// Forward kernel: collects the raw data pointers of all X and delegates the
+// element copy to Functor(ctx, x_ptrs, y_ptr, pre, n, post).
+template <typename DeviceContext, typename T, typename Functor>
+class StackKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto x = ctx.MultiInput<Tensor>("X");
+    auto *y = ctx.Output<Tensor>("Y");
+
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+
+    int n = static_cast<int>(x.size());
+    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+    std::vector<const T *> x_datas(n);
+    for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
+
+    // Treat Y as [pre, n, post], where n is the number of stacked inputs.
+    int pre = 1, post = 1;
+    GetPrePostForStackOp(x[0]->dims(), axis, &pre, &post);
+
+    Functor functor;
+    functor(ctx.template device_context<DeviceContext>(), x_datas, y_data, pre,
+            n, post);
+  }
+};
+
+// Backward op: splits Input(Y@GRAD) along `axis` into the Outputs(X@GRAD).
+class StackOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Each X@GRAD gets dY's dims with `axis` erased; the number of X@GRAD
+  // outputs must equal dY's extent along `axis`.
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@Grad) must exist.");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
+    int rank = dy_dim.size();
+    PADDLE_ENFORCE(axis >= -rank && axis < rank,
+                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
+                   rank);
+    if (axis < 0) axis += rank;
+
+    PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
+                      static_cast<size_t>(dy_dim[axis]),
+                      "Number of Outputs(X@Grad) is wrong");
+    auto vec = framework::vectorize2int(dy_dim);
+    vec.erase(vec.begin() + axis);
+    ctx->SetOutputsDim(
+        framework::GradVarName("X"),
+        std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
+  }
+};
+
+// Emits the single `stack_grad` op desc: Y@GRAD in, X@GRAD out, with the
+// forward op's attributes copied over.
+class StackGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("stack_grad");
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+// Backward kernel: gathers the mutable data pointers of all X@GRAD and
+// delegates the scatter to GradFunctor(ctx, dx_ptrs, dy_ptr, pre, n, post).
+template <typename DeviceContext, typename T, typename GradFunctor>
+class StackGradKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += dy->dims().size();
+
+    // n: how many X@GRAD tensors dY splits into along `axis`.
+    int n = dy->dims()[axis];
+    std::vector<T *> dx_datas(n);  // NOLINT
+    for (int i = 0; i < n; i++)
+      dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
+    auto dy_data = dy->data<T>();
+
+    // pre = product of dY's dims before `axis`; post = elements per
+    // (pre, n) slice, i.e. the product of the dims after `axis`.
+    int pre = 1;
+    for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
+    int post = dy->numel() / (n * pre);
+    GradFunctor functor;
+    functor(ctx.template device_context<DeviceContext>(), dx_datas, dy_data,
+            pre, n, post);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc
index
733157ea05ed39434b9a750e3a94ea548f512ce6..48e37796e1b4190e50602421106a105e4d4f6d74 100644
--- a/paddle/fluid/operators/while_op.cc
+++ b/paddle/fluid/operators/while_op.cc
@@ -57,12 +57,13 @@ class WhileOp : public framework::OperatorBase {
 
     PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
                    "Condition of while op must in CPU memory.");
+
+    // Prepare the step block once; every iteration reuses the context.
+    auto ctx = executor.Prepare(*program, block->ID());
     while (cond.data<bool>()[0]) {
       auto &current_scope = scope.NewScope();
       step_scopes->push_back(&current_scope);
-
-      executor.Run(*program, &current_scope, block->ID(),
-                   false /*create_local_scope*/);
+      executor.RunPreparedContext(ctx.get(), &current_scope, false);
     }
   }
 };
@@ -109,6 +110,8 @@ class WhileGradOp : public framework::OperatorBase {
     framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
+    // Prepared once and reused for every step scope processed below.
+    auto ctx = executor.Prepare(*program, block->ID());
 
     auto *step_scopes =
         scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
@@ -161,8 +164,7 @@ class WhileGradOp : public framework::OperatorBase {
           }
         }
       }
-
-      executor.Run(*program, *cur_scope_iter, block->ID(), false);
+      executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
 
       auto &pg_names = Outputs(kXGRAD);
       auto &p_names = Inputs(kX);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 71592618f540a8f42d9a25dd8a1af5e67a592f21..b0fb09385a6f97cb6b0d3dcb10bcc26f91de2aaf 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -29,80 +29,21 @@ from ..
import unique_name
 from functools import reduce
 
 __all__ = [
-    'fc',
-    'embedding',
-    'dynamic_lstm',
-    'dynamic_lstmp',
-    'dynamic_gru',
-    'gru_unit',
-    'linear_chain_crf',
-    'crf_decoding',
-    'cos_sim',
-    'cross_entropy',
-    'square_error_cost',
-    'chunk_eval',
-    'sequence_conv',
-    'conv2d',
-    'conv3d',
-    'sequence_pool',
-    'sequence_softmax',
-    'softmax',
-    'pool2d',
-    'pool3d',
-    'batch_norm',
-    'beam_search_decode',
-    'conv2d_transpose',
-    'conv3d_transpose',
-    'sequence_expand',
-    'lstm_unit',
-    'reduce_sum',
-    'reduce_mean',
-    'reduce_max',
-    'reduce_min',
-    'reduce_prod',
-    'sequence_first_step',
-    'sequence_last_step',
-    'dropout',
-    'split',
-    'ctc_greedy_decoder',
-    'edit_distance',
-    'l2_normalize',
-    'matmul',
-    'topk',
-    'warpctc',
-    'sequence_reshape',
-    'transpose',
-    'im2sequence',
-    'nce',
-    'hsigmoid',
-    'beam_search',
-    'row_conv',
-    'multiplex',
-    'layer_norm',
-    'softmax_with_cross_entropy',
-    'smooth_l1',
-    'one_hot',
-    'autoincreased_step_counter',
-    'reshape',
-    'lod_reset',
-    'lrn',
-    'pad',
-    'label_smooth',
-    'roi_pool',
-    'dice_loss',
-    'image_resize',
-    'image_resize_short',
-    'resize_bilinear',
-    'gather',
-    'scatter',
-    'random_crop',
-    'mean_iou',
-    'relu',
-    'log',
-    'crop',
-    'rank_loss',
-    'prelu',
-    'flatten',
+    'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
+    'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
+    'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', 'conv3d',
+    'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'pool3d',
+    'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose',
+    'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max',
+    'reduce_min', 'reduce_prod', 'sequence_first_step', 'sequence_last_step',
+    'dropout', 'split', 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize',
+    'matmul', 'topk', 'warpctc', 'sequence_reshape', 'transpose', 'im2sequence',
+    'nce', 'hsigmoid', 'beam_search', 'row_conv', 'multiplex', 'layer_norm',
+    'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
+    'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad',
+    'label_smooth', 'roi_pool', 'dice_loss', 'image_resize',
+    'image_resize_short', 'resize_bilinear', 'gather', 'scatter', 'random_crop',
+    'mean_iou', 'relu', 'log', 'crop', 'rank_loss', 'prelu', 'flatten', 'stack'
 ]
 
 
@@ -5517,3 +5458,32 @@ def flatten(x, axis=1, name=None):
         outputs={'Out': out},
         attrs={"axis": axis})
     return out
+
+
+def stack(x, axis=0):
+    """
+    Stack the inputs into one variable along a new axis.
+
+    Args:
+        x (Variable|list(Variable)|tuple(Variable)): Inputs to stack. A single
+            Variable is wrapped into a one-element list. The stack op requires
+            all inputs to have the same dims.
+        axis (int|None): Position of the new axis in the output; may be
+            negative. None is treated as 0. Default 0.
+
+    Returns:
+        Variable: The stacked variable, created with the dtype of x[0].
+    """
+    helper = LayerHelper('stack', **locals())
+    axis = 0 if axis is None else axis
+
+    if not isinstance(x, (list, tuple)):
+        x = [x]
+
+    out = helper.create_tmp_variable(x[0].dtype)
+    helper.append_op(
+        type='stack',
+        inputs={'X': x},
+        outputs={'Y': out},
+        attrs={'axis': axis})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..defdeb5d70df4c39ed8e23247270e6eb3dd14a7a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_stack_op.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from op_test import OpTest
+import numpy as np
+import unittest
+
+
+class TestStackOpBase(OpTest):
+    """Checks the `stack` op's output and gradient against `np.stack`."""
+
+    def initDefaultParameters(self):
+        self.num_inputs = 4
+        self.input_dim = (5, 6, 7)
+        self.axis = 0
+        self.dtype = 'float32'
+
+    def initParameters(self):
+        """Hook for subclasses to override the defaults set above."""
+
+    def get_x_names(self):
+        """Returns the input names 'x0', 'x1', ... one per input."""
+        return ['x{}'.format(i) for i in range(self.num_inputs)]
+
+    def setUp(self):
+        self.initDefaultParameters()
+        self.initParameters()
+        self.op_type = 'stack'
+        self.x = [
+            np.random.random(size=self.input_dim).astype(self.dtype)
+            for _ in range(self.num_inputs)
+        ]
+
+        # X is a duplicable input, given as a list of (name, value) pairs.
+        self.inputs = {'X': list(zip(self.get_x_names(), self.x))}
+        self.outputs = {'Y': np.stack(self.x, axis=self.axis)}
+        self.attrs = {'axis': self.axis}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(self.get_x_names(), 'Y')
+
+
+class TestStackOp1(TestStackOpBase):
+    # 16 inputs: the largest count the CUDA functor keeps in framework::Array.
+    def initParameters(self):
+        self.num_inputs = 16
+
+
+class TestStackOp2(TestStackOpBase):
+    # 20 inputs: above the CUDA functor's limit of 16 (device_vector path).
+    def initParameters(self):
+        self.num_inputs = 20
+
+
+class TestStackOp3(TestStackOpBase):
+    def initParameters(self):
+        self.axis = -1
+
+
+class TestStackOp4(TestStackOpBase):
+    # axis = -(rank + 1): the smallest axis StackOp::InferShape accepts.
+    def initParameters(self):
+        self.axis = -4
+
+
+class TestStackOp5(TestStackOpBase):
+    def initParameters(self):
+        self.axis = 1
+
+
+class TestStackOp6(TestStackOpBase):
+    # axis = rank: the largest accepted axis, appending a trailing dim.
+    def initParameters(self):
+        self.axis = 3
+
+
+if __name__ == '__main__':
+    unittest.main()