diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index e362d3486487dd0b55e3e40d1c1358f2e5604ac5..fff03ffa67454388e867cf66ab88c7724a4dbabc 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -116,6 +116,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
 paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
 paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
 paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
+paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
 paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
diff --git a/paddle/fluid/operators/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_expand_as_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..33c1e1c973c80ba3943924331380d35b225ac800
--- /dev/null
+++ b/paddle/fluid/operators/sequence_expand_as_op.cc
@@ -0,0 +1,168 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/sequence_expand_as_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::LoDTensor;
+
+class SequenceExpandAsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceExpandAsOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of SequenceExpandAsOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceExpandAsOp should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = x_dims;
+
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
+                      "Dimension number of Input(X) should be at least 2.");
+
+    if (ctx->IsRuntime()) {
+      framework::Variable* x_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
+      framework::Variable* y_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
+
+      auto& x_dim = x_var->Get<LoDTensor>().dims();
+      auto& y_lod = y_var->Get<LoDTensor>().lod();
+
+      PADDLE_ENFORCE_EQ(y_lod.size(), 1,
+                        "Level number of Input(Y)'s lod should be 1.");
+
+      PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dim[0]), y_lod[0].size() - 1,
+                        "The first dimension of Input(X) should be equal "
+                        "to the number of sequences in Input(Y)'s level-0 "
+                        "lod.");
+
+      int64_t out_first_dim = 0;
+      if (y_lod[0].size() <= 1) {
+        out_first_dim = x_dims[0];
+      } else {
+        for (size_t i = 1; i < y_lod[0].size(); ++i) {
+          out_first_dim += (y_lod[0][i] - y_lod[0][i - 1]);
+        }
+      }
+      out_dims[0] = out_first_dim;
+    } else {
+      out_dims[0] = -1;
+    }
+
+    ctx->SetOutputDim("Out", out_dims);
+    ctx->ShareLoD("Y", /*->*/ "Out");
+  }
+};
+
+class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose "
+             "lod level is at most 1.");
+    AddInput("Y",
+             "(LoDTensor, default LoDTensor<float>) The reference LoDTensor "
+             "whose level-0 lod specifies how Input(X) is expanded.");
+    AddOutput("Out",
+              "(LoDTensor, default LoDTensor<float>) Output LoDTensor that "
+              "is generated from Input(X) by referring to the lod of "
+              "Input(Y).");
+    AddComment(R"DOC(
+Sequence Expand As Operator.
+
+This operator expands `X` according to the level-0 lod of `Y`. The current
+implementation requires that the level number of Input(Y)'s lod be exactly 1,
+and that the first dimension of Input(X) equal the number of sequences in
+Input(Y)'s level-0 lod; the lod of Input(X), if any, is ignored.
+
+The following cases explain how this works:
+
+Case 1:
+
+Given a 1-level LoDTensor input(X)
+    X.data = [[a], [b], [c], [d]]
+    X.dims = [4, 1]
+and input(Y)
+    Y.lod = [[0, 3, 6, 7, 8]]
+then we get a 1-level LoDTensor
+    Out.lod = [[0, 3, 6, 7, 8]]
+    Out.data = [[a], [a], [a], [b], [b], [b], [c], [d]]
+    Out.dims = [8, 1]
+
+Case 2:
+
+Given a common Tensor input(X)
+    X.data = [[a, b], [c, d], [e, f]]
+    X.dims = [3, 2]
+and input(Y)
+    Y.lod = [[0, 2, 3, 6]]
+then we get a common LoDTensor
+    Out.lod = [[0, 2, 3, 6]]
+    Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
+    Out.dims = [6, 2]
+
+)DOC");
+  }
+};
+
+class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+      ctx->ShareLoD("X", x_grad_name);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
+                  ops::SequenceExpandAsOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    sequence_expand_as,
+    ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_expand_as_grad,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext,
+                                    double>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext,
+                                    int64_t>);
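For readers skimming the patch, the expansion rule in the DOC block above can be mocked up in a few lines of NumPy. This is an illustrative sketch, not part of the patch; the helper name and the offset-based lod argument are assumptions for the example:

    import numpy as np

    def sequence_expand_as_ref(x, y_lod0):
        # Row i of x is repeated once per element of sequence i of Y,
        # where y_lod0 is an offset-based level-0 lod.
        repeats = [y_lod0[i + 1] - y_lod0[i] for i in range(len(y_lod0) - 1)]
        return np.repeat(x, repeats, axis=0)

    # Case 1 above: four rows, sequence lengths 3, 3, 1 and 1.
    x = np.array([['a'], ['b'], ['c'], ['d']])
    print(sequence_expand_as_ref(x, [0, 3, 6, 7, 8]).ravel().tolist())
    # -> ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'd']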
diff --git a/paddle/fluid/operators/sequence_expand_as_op.cu b/paddle/fluid/operators/sequence_expand_as_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7357f5ae6e732f28307af65d1f1b6b3cbed1f640
--- /dev/null
+++ b/paddle/fluid/operators/sequence_expand_as_op.cu
@@ -0,0 +1,134 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+
+#include "paddle/fluid/operators/sequence_expand_as_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+static __global__ void sequence_expand_as_kernel(const T *in_data,
+                                                 const size_t *expand_offset,
+                                                 const size_t src_height,
+                                                 const size_t src_width,
+                                                 T *out_data) {
+  for (int h_id = blockIdx.x; h_id < src_height; h_id += gridDim.x) {
+    int span = expand_offset[h_id + 1] - expand_offset[h_id];
+    if (span == 0) continue;
+    const T *src = in_data + h_id * src_width;
+    for (int w_id = threadIdx.x; w_id < src_width; w_id += blockDim.x) {
+      T ele = src[w_id];
+      int offset = expand_offset[h_id] * src_width;
+      for (int k = 0; k < span; ++k) {
+        out_data[offset + k * src_width + w_id] = ele;
+      }
+    }
+  }
+}
+
+template <typename T>
+static __global__ void sequence_expand_as_grad_kernel(
+    const T *dout_data, const size_t *expand_offset, const size_t dst_height,
+    const size_t dst_width, T *dx_data) {
+  for (int h_id = blockIdx.x; h_id < dst_height; h_id += gridDim.x) {
+    T *dst = dx_data + h_id * dst_width;
+    int span = expand_offset[h_id + 1] - expand_offset[h_id];
+
+    for (int w_id = threadIdx.x; w_id < dst_width; w_id += blockDim.x) {
+      T result = 0;
+      for (int k = 0; k < span; ++k) {
+        int offset = (expand_offset[h_id] + k) * dst_width;
+        const T *src = dout_data + offset;
+        result += src[w_id];
+      }
+      dst[w_id] = result;
+    }
+  }
+}
+
+template <typename T>
+struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
+  void operator()(
+      const platform::CUDADeviceContext &context, const LoDTensor &x,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      LoDTensor *out) {
+    int height = x.dims()[0];
+    int width = framework::product(x.dims()) / height;
+
+    const int kThreadsPerBlock = 1024;
+    int thread_x = kThreadsPerBlock;
+    if (width < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      thread_x = ((width + 31) >> 5) << 5;
+    }
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int block_x = std::max(max_threads / thread_x, 1);
+
+    dim3 block_size(thread_x);
+    dim3 grid_size(block_x);
+    sequence_expand_as_kernel<<<grid_size, block_size, 0, context.stream()>>>(
+        x.data<T>(), ref_lod.CUDAData(context.GetPlace()), height, width,
+        out->mutable_data<T>(context.GetPlace()));
+  }
+};
+
+template <typename T>
+struct SequenceExpandAsGradFunctor<platform::CUDADeviceContext, T> {
+  void operator()(
+      const platform::CUDADeviceContext &context, const LoDTensor &dout,
+      const framework::Vector<size_t> &ref_lod, /*expand based lod*/
+      LoDTensor *dx) {
+    int height = dx->dims()[0];
+    int width = framework::product(dx->dims()) / height;
+
+    const int kThreadsPerBlock = 1024;
+    int thread_x = kThreadsPerBlock;
+    if (width < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      thread_x = ((width + 31) >> 5) << 5;
+    }
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int block_x = std::max(max_threads / thread_x, 1);
+
+    dim3 block_size(thread_x);
+    dim3 grid_size(block_x);
+    sequence_expand_as_grad_kernel<<<grid_size, block_size, 0,
+                                     context.stream()>>>(
+        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()), height, width,
+        dx->mutable_data<T>(context.GetPlace()));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_expand_as,
+    ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_expand_as_grad,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
+                                    float>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
+                                    double>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
+                                    int64_t>);
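The CUDA grad kernel above and the CPU grad functor in sequence_expand_as_op.h below implement the same reduction: the gradient for row i of X is the sum of the Grad(Out) rows that row i was expanded into. A NumPy sketch of that rule, illustrative only and not part of the patch:

    import numpy as np

    def sequence_expand_as_grad_ref(dout, y_lod0):
        # y_lod0 is an offset-based level-0 lod; empty segments get zero grad.
        n = len(y_lod0) - 1
        dx = np.zeros((n,) + dout.shape[1:], dtype=dout.dtype)
        for i in range(n):
            dx[i] = dout[y_lod0[i]:y_lod0[i + 1]].sum(axis=0)
        return dx

    # The example from the comment in sequence_expand_as_op.h:
    dout = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(6, 1)
    print(sequence_expand_as_grad_ref(dout, [0, 3, 6]).ravel())  # [0.6 1.5]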
diff --git a/paddle/fluid/operators/sequence_expand_as_op.h b/paddle/fluid/operators/sequence_expand_as_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..42c90d01c05e369efc276498aa94debb367a6bfa
--- /dev/null
+++ b/paddle/fluid/operators/sequence_expand_as_op.h
@@ -0,0 +1,148 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <numeric>  // std::iota
+#include <sstream>
+#include <vector>
+
+#include "glog/logging.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+struct SequenceExpandFunctor {
+  void operator()(
+      const DeviceContext &ctx, const framework::LoDTensor &x,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      framework::LoDTensor *out);
+};
+
+template <typename DeviceContext, typename T>
+struct SequenceExpandAsGradFunctor {
+  void operator()(
+      const DeviceContext &ctx, const framework::LoDTensor &dout,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      framework::LoDTensor *dx);
+};
+
+template <typename T>
+struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
+  void operator()(
+      const platform::CPUDeviceContext &context, const framework::LoDTensor &x,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      framework::LoDTensor *out) {
+    int64_t height = x.dims()[0];
+    int64_t width = framework::product(x.dims()) / height;
+
+    const T *in_data = x.data<T>();
+    T *out_data = out->mutable_data<T>(context.GetPlace());
+
+    for (int h_id = 0; h_id < height; ++h_id) {
+      size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
+      if (span == 0) continue;
+      const T *src = in_data + h_id * width;
+      for (int64_t w_id = 0; w_id < width; ++w_id) {
+        T ele = src[w_id];
+        size_t offset = ref_lod[h_id] * width;
+        for (size_t k = 0; k < span; ++k) {
+          out_data[offset + k * width + w_id] = ele;
+        }
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SequenceExpandAsKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *x = context.Input<framework::LoDTensor>("X");
+    auto *y = context.Input<framework::LoDTensor>("Y");
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto &y_lod = y->lod();
+    PADDLE_ENFORCE_EQ(y_lod.size(), 1,
+                      "Level number of Input(Y)'s lod should be 1.");
+    PADDLE_ENFORCE_GT(y_lod[0].size(), 1,
+                      "Level-0 lod of Input(Y) should have more than one "
+                      "element.");
+
+    out->mutable_data<T>(context.GetPlace());
+
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    SequenceExpandFunctor<DeviceContext, T> seq_expand_functor;
+    seq_expand_functor(dev_ctx, *x, y_lod[0], out);
+  }
+};
+
+/*
+ * Given Grad(Out)
+ *
+ *    Grad(Out).lod = [[0, 3, 6]]
+ *    Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+ * Then
+ *    Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
+ *                 = [0.6, 1.5]
+ *    Grad(X).lod = Input(X).lod
+ */
+template <typename T>
+struct SequenceExpandAsGradFunctor<platform::CPUDeviceContext, T> {
+  void operator()(
+      const platform::CPUDeviceContext &context,
+      const framework::LoDTensor &dout,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      framework::LoDTensor *dx) {
+    int64_t height = dx->dims()[0];
+    int64_t width = framework::product(dx->dims()) / height;
+
+    const T *dout_data = dout.data<T>();
+    T *dx_data = dx->mutable_data<T>(context.GetPlace());
+
+    for (int64_t h_id = 0; h_id < height; ++h_id) {
+      T *dst = dx_data + h_id * width;
+      size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
+      for (int64_t w_id = 0; w_id < width; ++w_id) {
+        T result = 0;
+        for (size_t k = 0; k < span; ++k) {
+          size_t offset = (ref_lod[h_id] + k) * width;
+          result += dout_data[offset + w_id];
+        }
+        dst[w_id] = result;
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SequenceExpandAsGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *g_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto *y = context.Input<framework::LoDTensor>("Y");
+    auto *g_x =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+
+    g_x->mutable_data<T>(context.GetPlace());
+
+    SequenceExpandAsGradFunctor<DeviceContext, T> functor;
+    functor(context.template device_context<DeviceContext>(), *g_out,
+            y->lod()[0], g_x);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 3ae0fac4bef5c47964f9a9cd8dd45b57e705e1f8..3bc3acabee8a5143d53d8fccdf149962d2ad25a2 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -54,6 +54,7 @@ __all__ = [
     'conv2d_transpose',
     'conv3d_transpose',
     'sequence_expand',
+    'sequence_expand_as',
     'sequence_pad',
     'lstm_unit',
     'reduce_sum',
@@ -2666,6 +2667,71 @@ def sequence_expand(x, y, ref_level=-1, name=None):
     return tmp
 
 
+def sequence_expand_as(x, y, name=None):
+    """Sequence Expand As Layer. This layer expands the input variable **x**
+    according to the zeroth level lod of **y**. The current implementation
+    requires that the level number of **y**'s lod be exactly 1, and that the
+    first dimension of **x** equal the number of sequences in **y**'s zeroth
+    level lod; the lod of **x**, if any, is ignored.
+
+    The following examples explain how sequence_expand_as works:
+
+    .. code-block:: text
+
+        * Case 1:
+
+            Given a 1-level LoDTensor input(X)
+                X.data = [[a], [b], [c], [d]]
+                X.dims = [4, 1]
+            and input(Y)
+                Y.lod = [[0, 3, 6, 7, 8]]
+            then we get a 1-level LoDTensor
+                Out.lod = [[0, 3, 6, 7, 8]]
+                Out.data = [[a], [a], [a], [b], [b], [b], [c], [d]]
+                Out.dims = [8, 1]
+
+        * Case 2:
+
+            Given a common Tensor input(X)
+                X.data = [[a, b], [c, d], [e, f]]
+                X.dims = [3, 2]
+            and input(Y)
+                Y.lod = [[0, 2, 3, 6]]
+            then we get a common LoDTensor
+                Out.lod = [[0, 2, 3, 6]]
+                Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
+                Out.dims = [6, 2]
+
+    Args:
+        x (Variable): The input variable which is a Tensor or LoDTensor.
+        y (Variable): The input variable which is a LoDTensor with lod
+            level 1.
+        name (str|None): A name for this layer (optional). If set to None,
+            the layer will be named automatically.
+
+    Returns:
+        Variable: The expanded variable which is a LoDTensor.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[10], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[10, 20],
+                                  dtype='float32', lod_level=1)
+            out = fluid.layers.sequence_expand_as(x=x, y=y)
+    """
+    helper = LayerHelper('sequence_expand_as', input=x, **locals())
+    dtype = helper.input_dtype()
+    tmp = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type='sequence_expand_as',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': tmp})
+    return tmp
+
+
 @templatedoc()
 def sequence_pad(x, pad_value, maxlen=None):
     """
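Note that the unit test below specifies lod in the length-based convention expected by OpTest (one entry per sequence, e.g. [[1, 3, 4]]), whereas the C++ DOC examples use offset-based lod (e.g. [[0, 1, 4, 8]]). A small conversion helper, purely for illustration and not part of the patch:

    def lengths_to_offsets(lengths):
        # Length-based lod -> offset-based lod, e.g. [1, 3, 4] -> [0, 1, 4, 8].
        offsets = [0]
        for n in lengths:
            offsets.append(offsets[-1] + n)
        return offsets

    print(lengths_to_offsets([1, 3, 4]))  # -> [0, 1, 4, 8]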
diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py b/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac97f7ed49fa7e6537efad134ab1320639dce9d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSequenceExpandAs(OpTest):
+    def setUp(self):
+        self.op_type = 'sequence_expand_as'
+        self.set_data()
+        self.compute()
+
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
+        y_lod = [[1, 3, 4]]
+        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
+
+    def compute(self):
+        x = self.inputs['X']
+        x_data, x_lod = x if type(x) == tuple else (x, None)
+        y_data, y_lod = self.inputs['Y']
+
+        assert len(y_lod) == 1 and len(y_lod[0]) == x_data.shape[0]
+
+        repeats = []
+        for i in range(len(y_lod[0])):
+            repeat_num = y_lod[0][i]
+            if repeat_num == 0:
+                continue
+            repeats.extend([i for _ in range(repeat_num)])
+
+        out_data = x_data[repeats]
+        self.outputs = {'Out': (out_data, y_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestSequenceExpandAsCase1(TestSequenceExpandAs):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
+        x_lod = [[2, 3]]
+        y_data = np.random.uniform(0.1, 1, [10, 1]).astype('float32')
+        y_lod = [[2, 2, 0, 3, 3]]
+        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+
+
+class TestSequenceExpandAsCase2(TestSequenceExpandAs):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
+        x_lod = [[1]]
+        y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
+        y_lod = [[2]]
+        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()
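To close, a hypothetical end-to-end use of the new layer, assuming the fluid 1.x API of the time; names, shapes, and the feeding helpers are illustrative assumptions, not part of the patch:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32', lod_level=1)
    out = fluid.layers.sequence_expand_as(x=x, y=y)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    x_np = np.array([[1.], [2.], [3.]], dtype='float32')
    y_np = np.zeros((8, 1), dtype='float32')
    # Length-based level-0 lod: three sequences of lengths 1, 3 and 4,
    # matching the first dimension of x.
    y_t = fluid.create_lod_tensor(y_np, [[1, 3, 4]], place)

    out_t, = exe.run(feed={'x': x_np, 'y': y_t},
                     fetch_list=[out],
                     return_numpy=False)
    print(np.array(out_t))  # rows of x repeated 1, 3 and 4 times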