diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index dc162e248a8b3701c157f79155f4cee232144d9b..f61d1254fd1419490c483532ff15257e5d8f4507 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -154,6 +154,7 @@ paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', ' paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/operators/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_scatter_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..adb81bffccb50069b3a2e5f391f3fdfde231b2be --- /dev/null +++ b/paddle/fluid/operators/sequence_scatter_op.cc @@ -0,0 +1,156 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_scatter_op.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather.h" +#include "paddle/fluid/operators/scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class SequenceScatterOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The source input of sequence scatter op"); + AddInput("Ids", + "(LoDTensor) The index input of sequence scatter op where X" + " will be updated, must be a LoDTensor"); + AddInput("Updates", + "(LoDTensor) The values to scatter to the input tensor " + "X, must be a LoDTensor with the same LoD information as Ids"); + AddOutput("Out", + "(Tensor) The output tensor of sequence scatter op, which " + "has the same dims as X"); + AddComment(R"DOC( +Sequence Scatter Operator. + +This operator scatters the Updates tensor to the input X. It uses the LoD +information of Ids to select the rows to update, and use the values in Ids as +the columns to update in each row of X. + +Following are cases to better explain how this works: + +Example 1: +Given an all-ones Tensor input(X) + X.data = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + X.dims = [3, 6] +a LoDTensor input(Ids) + Ids.data = [[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], [5], [4]] + Ids.lod = [[0, 3, 8, 12]] +and a Tensor input(Updates) + Updates.data = [[0.3], [0.3], [0.4], [0.1], [0.2], [0.3], [0.4], [0.0], [0.2], [0.3], [0.1], [0.4]] + Updates.lod = [[ 0, 3, 8, 12]] +then we get an output Tensor + Out.data = [[1.3, 1.3, 1.4, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.4, 1.3, 1.2, 1.1], + [1.0, 1.0, 1.3, 1.2, 1.4, 1.1]] + Out.dims = X.dims = [3, 6] +)DOC"); + } +}; + +class SequenceScatterOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + // Enforce has inputs and outputs + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of SequenceScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Updates"), + "Input(Updates) of SequenceScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceScatterOp should not be null."); + + // Set output dim the same as input + auto ref_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", ref_dims); + + // Enforce the Updates and Ids are the same shape + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0], + ctx->GetInputDim("Ids")[0], + "Updates and Ids should have same shape."); + + // Enforce LoD of ids and updates be the same + if (ctx->IsRuntime()) { + framework::Variable* ids_var = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); + framework::Variable* updates_var = + boost::get(ctx->GetInputVarPtrs("Updates")[0]); + + auto& ids_lod = ids_var->Get().lod(); + auto& updates_lod = updates_var->Get().lod(); + PADDLE_ENFORCE_EQ(ids_lod.size(), 1, + "Currently only level 1 LoD could be" + " processed by sequence scatter op."); + PADDLE_ENFORCE_EQ(updates_lod.size(), 1, + "Currently only level 1 LoD " + "could be processed by sequence scatter op."); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +class SequenceScatterGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp, + ops::SequenceScatterOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp); +REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel, + ops::SequenceScatterOpKernel, + ops::SequenceScatterOpKernel, + ops::SequenceScatterOpKernel); +REGISTER_OP_CPU_KERNEL(sequence_scatter_grad, + ops::SequenceScatterGradientOpKernel, + ops::SequenceScatterGradientOpKernel, + ops::SequenceScatterGradientOpKernel, + ops::SequenceScatterGradientOpKernel); diff --git a/paddle/fluid/operators/sequence_scatter_op.h b/paddle/fluid/operators/sequence_scatter_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d9b681b7aa76849a40d50e3348418d7604641c10 --- /dev/null +++ b/paddle/fluid/operators/sequence_scatter_op.h @@ -0,0 +1,122 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather.h" +#include "paddle/fluid/operators/scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class SequenceScatterOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* ids = ctx.Input("Ids"); + auto* updates = ctx.Input("Updates"); + auto* out = ctx.Output("Out"); + + auto& ids_lod = ids->lod(); + + // Initialize out as same as x + out->mutable_data(ctx.GetPlace()); + framework::TensorCopySync(*x, ctx.GetPlace(), out); + + auto x_dims = x->dims(); + auto out_dims = out->dims(); + + for (int i = 0; i < x_dims.size(); ++i) + PADDLE_ENFORCE(x_dims[i] == out_dims[i], + "Input and output shape of " + "sequence scatter op must exactly be the same."); + + size_t slice_size = 1; + for (int i = 1; i < x_dims.size(); ++i) slice_size *= x_dims[i]; + + auto lod_vec = ids_lod[0]; + unsigned int seg = 0; + for (int i = 0; i < ids->dims()[0]; ++i) { + PADDLE_ENFORCE_LT(seg, lod_vec.size() - 1, + "Segment num must not exceed batch size.\n"); + int lower_bound = lod_vec[seg]; + int upper_bound = lod_vec[seg + 1]; + if (i >= lower_bound && i < upper_bound) { + T* p_out = out->data(); + const T* p_updates = updates->data(); + const int64_t* p_index = ids->data(); + p_out[seg * slice_size + p_index[i]] += p_updates[i]; + } else { + ++seg; + --i; + } + } + } +}; + +template +class SequenceScatterGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dUpdates = ctx.Output(framework::GradVarName("Updates")); + auto* ids = ctx.Input("Ids"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + + auto& ids_lod = ids->lod(); + + dX->mutable_data(ctx.GetPlace()); + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); + dUpdates->mutable_data(ctx.GetPlace()); + + auto dx_dims = dX->dims(); + auto dout_dims = dOut->dims(); + + for (int i = 0; i < dx_dims.size(); ++i) + PADDLE_ENFORCE(dx_dims[i] == dout_dims[i], + "Input and output shape of " + "sequence scatter grad op must exactly be the same."); + + size_t slice_size = 1; + for (int i = 1; i < dx_dims.size(); ++i) slice_size *= dx_dims[i]; + + auto lod_vec = ids_lod[0]; + unsigned int seg = 0; + + for (int i = 0; i < ids->dims()[0]; ++i) { + PADDLE_ENFORCE_LT(seg, lod_vec.size() - 1, + "Segment num must not exceed batch size.\n"); + int lower_bound = lod_vec[seg]; + int upper_bound = lod_vec[seg + 1]; + if (i >= lower_bound && i < upper_bound) { + const T* p_dOut = dOut->data(); + const int64_t* p_index = ids->data(); + T* p_dUpdates = dUpdates->data(); + p_dUpdates[i] = p_dOut[seg * slice_size + p_index[i]]; + } else { + ++seg; + --i; + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c7df815175c912d3e0e476a53c370d42cf45e5e0..f896cfa04b3e0d89daaa1bd7fd893b5892a09a4e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -100,6 +100,7 @@ __all__ = [ 'resize_bilinear', 'gather', 'scatter', + 'sequence_scatter', 'random_crop', 'mean_iou', 'relu', @@ -5425,6 +5426,66 @@ def scatter(input, index, updates, name=None): return out +def sequence_scatter(input, index, updates, name=None): + """ + **Sequence Scatter Layer** + + This operator scatters the Updates tensor to the input X. It uses the LoD + information of Ids to select the rows to update, and use the values in Ids as + the columns to update in each row of X. + + Here is an example: + Given the following input: + .. code-block:: text + input.data = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + input.dims = [3, 6] + + index.data = [[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], [5], [4]] + index.lod = [[0, 3, 8, 12]] + + updates.data = [[0.3], [0.3], [0.4], [0.1], [0.2], [0.3], [0.4], [0.0], [0.2], [0.3], [0.1], [0.4]] + updates.lod = [[ 0, 3, 8, 12]] + + Then we have the output: + .. code-block:: text + out.data = [[1.3, 1.3, 1.4, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.4, 1.3, 1.2, 1.1], + [1.0, 1.0, 1.3, 1.2, 1.4, 1.1]] + out.dims = X.dims = [3, 6] + + Args: + input (Variable): The source input with rank>=1. + index (Variable): A LoD Tensor. The index input of sequence scatter op + where input will be updated. The index input with rank=1. Its dtype + should be int32 or int64 as it is used as indexes. + updates (Variable): A LoD Tensor. The values to scatter to the input + tensor X, must be a LoDTensor with the same LoD information as index. + name (str|None): The output variable name. Default None. + + Returns: + output (Variable): The output is a tensor with the same shape as input. + + Examples: + + .. code-block:: python + + output = fluid.layers.sequence_scatter(input, index, updates) + + """ + helper = LayerHelper('sequence_scatter', **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="sequence_scatter", + inputs={"X": input, + "Ids": index, + "Updates": updates}, + outputs={"Out": out}) + return out + + @templatedoc() def random_crop(x, shape, seed=None): """ diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 7a97d907f42eba3f39c7366ce0aaa29f2c3270b1..9a17d3213c902e0c18123097025349434d271d7f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -382,6 +382,30 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_sequence_scatter(self): + program = Program() + with program_guard(program): + x = layers.data( + name='x', + shape=[3, 6], + append_batch_size=False, + dtype='float32') + idx = layers.data( + name='idx', + shape=[12, 1], + append_batch_size=False, + dtype='int32', + lod_level=1) + updates = layers.data( + name='updates', + shape=[12, 1], + append_batch_size=False, + dtype='float32', + lod_level=1) + out = layers.sequence_scatter(input=x, index=idx, updates=updates) + self.assertIsNotNone(out) + print(str(program)) + def test_lod_reset(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py b/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d239e9c798745cbb3dda9df56dbd717aab74ed --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py @@ -0,0 +1,51 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestSequenceScatterOp(OpTest): + def setUp(self): + self.op_type = "sequence_scatter" + + X_data = np.random.uniform(0.1, 1.0, [3, 6]).astype('float32') + Ids_data = np.array([[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], + [5], [4]]).astype('int64') + Ids_lod = [[3, 5, 4]] + Updates_data = np.random.uniform(0.1, 1.0, [12, 1]).astype('float32') + Updates_lod = Ids_lod + + Out_data = np.copy(X_data) + Out_data[0][Ids_data[0:3]] += Updates_data[0:3] + Out_data[1][Ids_data[3:8]] += Updates_data[3:8] + Out_data[2][Ids_data[8:]] += Updates_data[8:] + + self.inputs = { + 'X': X_data, + 'Ids': (Ids_data, Ids_lod), + 'Updates': (Updates_data, Updates_lod) + } + self.outputs = {'Out': Out_data} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Updates'], 'Out', in_place=True) + + +if __name__ == "__main__": + unittest.main()