diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index b4bed0a188a5f721b99363ca2580c95595f62b94..b8658b17be2167e5959a751d7cc9ef99613eafca 100755
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -196,6 +196,8 @@ paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale',
 paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
 paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9'))
 paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
+paddle.fluid.layers.scatter_nd_add (ArgSpec(args=['ref', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c2fa5ee7484b52b95a28abf1d8827cd0'))
+paddle.fluid.layers.scatter_nd (ArgSpec(args=['index', 'updates', 'shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '14b5449ce42f8ff4ac4ce79b41c86cc5'))
 paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))
 paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', '042af0b8abea96b40c22f6e70d99e042'))
 paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e714b4aa7993dfe9c1a38886875dbaac'))
diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc
index 43699f57b6c8d857684efcaca8a1cd91dd5aecff..aed0f824e6966b2d15e50bddbef4f782566420c4 100644
--- a/paddle/fluid/operators/gather_nd_op.cc
+++ b/paddle/fluid/operators/gather_nd_op.cc
@@ -38,8 +38,9 @@ class GatherNdOp : public framework::OperatorWithKernel {
     auto index_dims = ctx->GetInputDim("Index");
     auto index_dims_size = index_dims.size();
 
-    PADDLE_ENFORCE_LE(index_dims[index_dims_size - 1], x_dims_size,
-                      "Input(Index).shape[-1] <= Input(X).rank");
+    PADDLE_ENFORCE_LE(
+        index_dims[index_dims_size - 1], x_dims_size,
+        "Input(Index).shape[-1] should be no greater than Input(X).rank");
     PADDLE_ENFORCE_GE(index_dims_size, 2UL,
                       "The rank of Input(Index) should be greater than 1");
 
diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h
index f4aabd4618742174b0cc977fc79a4a6bb046d30b..f8d08b2e44c9626f92337e0f20c7517432125349 100644
--- a/paddle/fluid/operators/scatter.cu.h
+++ b/paddle/fluid/operators/scatter.cu.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
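Reviewer note: a minimal numpy sketch (not part of this patch) of the constraint behind the reworded gather_nd message above; each row of `Index` addresses one element or slice of `X`, so its length cannot exceed the rank of `X`:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)    # rank 3
index = np.array([[0, 2, 1]])         # index.shape[-1] == 3 <= x.ndim: picks x[0, 2, 1]
print(x[tuple(index[0])])             # 9
# a row such as [0, 2, 1, 0] would have index.shape[-1] == 4 > x.ndim,
# which is exactly what the PADDLE_ENFORCE_LE above rejects
```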
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..41f18eaeaf8bd894282929321a483ef5859c5895
--- /dev/null
+++ b/paddle/fluid/operators/scatter_nd_add_op.cc
@@ -0,0 +1,186 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/scatter_nd_add_op.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+class ScatterNdAddOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of ScatterNdAddOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
+                      "Input(Index) of ScatterNdAddOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
+                      "Input(Updates) of ScatterNdAddOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of ScatterNdAddOp should not be null.");
+
+    auto ref_dims = ctx->GetInputDim("X");
+    auto ref_dims_size = ref_dims.size();
+    auto index_dims = ctx->GetInputDim("Index");
+    auto index_dims_size = index_dims.size();
+    auto updates_dims = ctx->GetInputDim("Updates");
+    auto updates_dims_size = updates_dims.size();
+
+    PADDLE_ENFORCE_LE(
+        index_dims[index_dims_size - 1], ref_dims_size,
+        "Input(Index).shape[-1] should be no greater than Input(X).rank");
+    PADDLE_ENFORCE_GE(index_dims_size, 2UL,
+                      "The rank of Input(Index) should be greater than 1");
+
+    // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
+    std::vector<int64_t> r_updates_dims;
+    for (int64_t i = 0; i < index_dims_size - 1; ++i) {
+      r_updates_dims.emplace_back(index_dims[i]);
+    }
+    for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
+      r_updates_dims.emplace_back(ref_dims[i]);
+    }
+
+    PADDLE_ENFORCE_EQ(r_updates_dims.size(), updates_dims_size,
+                      "Updates has wrong shape");
+
+    for (int64_t i = 0; i < updates_dims_size; ++i) {
+      PADDLE_ENFORCE_EQ(r_updates_dims[i], updates_dims[i],
+                        "Updates has wrong shape");
+    }
+    ctx->SetOutputDim("Out", ref_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->type(),
+                      ctx.Input<Tensor>("Updates")->type(),
+                      "Ref and Updates must have same type");
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class ScatterNdAddGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (ctx->HasOutput(framework::GradVarName("Updates"))) {
+      ctx->SetOutputDim(framework::GradVarName("Updates"),
+                        ctx->GetInputDim("Updates"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"),
+                        ctx->GetInputDim(framework::GradVarName("Out")));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
+  }
+};
+
+class ScatterNdAddOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The source input of scatter_nd_add op");
+    AddInput("Index",
+             "The index input of scatter_nd_add op where X will be updated");
+    AddInput("Updates", "The updated value of scatter_nd_add op");
+    AddOutput("Out", "The output of scatter_nd_add op");
+    AddComment(R"DOC(
+Scatter_nd_add Operator.
+
+Output is obtained by applying sparse addition to a single value or slice in a Variable.
+
+    Given:
+
+    * Case 1:
+        ref = [0, 1, 2, 3, 4, 5]
+        index = [[1], [2], [3], [1]]
+        updates = [9, 10, 11, 12]
+
+      we get:
+
+        output = [0, 22, 12, 14, 4, 5]
+
+    * Case 2:
+        ref = [[65, 17], [-14, -25]]
+        index = [[], []]
+        updates = [[[-1, -2], [1, 2]],
+                   [[3, 4], [-3, -4]]]
+        ref.shape = (2, 2)
+        index.shape = (2, 0)
+        updates.shape = (2, 2, 2)
+
+      we get:
+
+        output = [[67, 19], [-16, -27]]
+)DOC");
+  }
+};
+
+class ScatterNdAddGradDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("scatter_nd_add_grad");
+    op->SetInput("Index", Input("Index"));
+    op->SetInput("Updates", Input("Updates"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ScatterNdAddGradNoNeedBufferVarsInference,
+                                      "Updates");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(scatter_nd_add, ops::ScatterNdAddOp, ops::ScatterNdAddOpMaker,
+                  ops::ScatterNdAddGradDescMaker);
+
+REGISTER_OPERATOR(scatter_nd_add_grad, ops::ScatterNdAddGradOp,
+                  ops::ScatterNdAddGradNoNeedBufferVarsInference);
+
+REGISTER_OP_CPU_KERNEL(scatter_nd_add, ops::ScatterNdAddOpKernel<float>,
+                       ops::ScatterNdAddOpKernel<double>,
+                       ops::ScatterNdAddOpKernel<int64_t>,
+                       ops::ScatterNdAddOpKernel<int>,
+                       ops::ScatterNdAddOpKernel<uint8_t>);
+
+REGISTER_OP_CPU_KERNEL(scatter_nd_add_grad,
+                       ops::ScatterNdAddGradientOpKernel<float>,
+                       ops::ScatterNdAddGradientOpKernel<double>,
+                       ops::ScatterNdAddGradientOpKernel<int64_t>,
+                       ops::ScatterNdAddGradientOpKernel<int>,
+                       ops::ScatterNdAddGradientOpKernel<uint8_t>);
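Reviewer note: a minimal numpy sketch (not part of this patch) of the accumulation rule documented in the DOC block above, using Case 1; `np.add.at` mirrors how duplicate rows in `index` accumulate:

```python
import numpy as np

ref = np.array([0, 1, 2, 3, 4, 5], dtype=np.float32)
index = np.array([[1], [2], [3], [1]], dtype=np.int32)  # row [1] appears twice
updates = np.array([9, 10, 11, 12], dtype=np.float32)

out = ref.copy()
# np.add.at accumulates over repeated indices, matching scatter_nd_add
np.add.at(out, tuple(index.T), updates)
print(out)  # [ 0. 22. 12. 14.  4.  5.]
```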
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cu b/paddle/fluid/operators/scatter_nd_add_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ecd9beb10cf9d73e014510ff8c628e5d5b6a2a73
--- /dev/null
+++ b/paddle/fluid/operators/scatter_nd_add_op.cu
@@ -0,0 +1,98 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/gather.cu.h"
+#include "paddle/fluid/operators/gather_op.h"
+#include "paddle/fluid/operators/scatter.cu.h"
+#include "paddle/fluid/operators/scatter_nd_add_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      "This kernel only runs on GPU device.");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Index");
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+    framework::TensorCopySync(*X, ctx.GetPlace(), Out);
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+    } else {
+      GPUScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      "This kernel only runs on GPU device.");
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Ids = ctx.Input<Tensor>("Index");
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    if (dX) {
+      // In place gradient: dX = dO
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+    }
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
+      const auto &index_type = Ids->type();
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUGatherNd<T, int32_t>(ctx, *dOut, *Ids, dUpdates);
+      } else {
+        GPUGatherNd<T, int64_t>(ctx, *dOut, *Ids, dUpdates);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(scatter_nd_add,
+                        ops::ScatterNdAddOpCUDAKernel<CUDA, float>,
+                        ops::ScatterNdAddOpCUDAKernel<CUDA, double>,
+                        ops::ScatterNdAddOpCUDAKernel<CUDA, int64_t>,
+                        ops::ScatterNdAddOpCUDAKernel<CUDA, int>,
+                        ops::ScatterNdAddOpCUDAKernel<CUDA, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(scatter_nd_add_grad,
+                        ops::ScatterNdAddGradOpCUDAKernel<CUDA, float>,
+                        ops::ScatterNdAddGradOpCUDAKernel<CUDA, double>,
+                        ops::ScatterNdAddGradOpCUDAKernel<CUDA, int64_t>,
+                        ops::ScatterNdAddGradOpCUDAKernel<CUDA, int>,
+                        ops::ScatterNdAddGradOpCUDAKernel<CUDA, plat::float16>);
diff --git a/paddle/fluid/operators/scatter_nd_add_op.h b/paddle/fluid/operators/scatter_nd_add_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b90fa1cf50b003fe32148ccb65185bae71c5fa0
--- /dev/null
+++ b/paddle/fluid/operators/scatter_nd_add_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/gather.h"
+#include "paddle/fluid/operators/scatter.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ScatterNdAddOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      "This kernel only runs on CPU.");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Index");
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+    // In place output: Out = X
+    framework::TensorCopySync(*X, ctx.GetPlace(), Out);
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+
+    if (index_type == framework::proto::VarType::INT32) {
+      ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+    } else {
+      ScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+    }
+  }
+};
+
+template <typename T>
+class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      "This kernel only runs on CPU.");
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Ids = ctx.Input<Tensor>("Index");
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    if (dX) {
+      // In place gradient: dX = dO
+      framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+    }
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
+      const auto &index_type = Ids->type();
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUGatherNd<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      } else {
+        CPUGatherNd<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
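Reviewer note: since scatter_nd_add is linear in both X and Updates, the CPU and CUDA gradient kernels above reduce to the same two identities: dX is a plain copy of dOut, and dUpdates is a gather_nd of dOut at Index. A hedged numpy restatement (illustration only, not part of this patch):

```python
import numpy as np

dout = np.arange(6, dtype=np.float32)                   # stands in for dOut
index = np.array([[1], [2], [3], [1]], dtype=np.int32)  # stands in for Index

d_x = dout.copy()                  # dX = dOut (the TensorCopy above)
d_updates = dout[tuple(index.T)]   # dUpdates = gather_nd(dOut, Index)
print(d_updates)                   # [1. 2. 3. 1.]
```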
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 932054dbf269537376056131dda30150770df843..d78d94454e65b306b3c7c4bddef05c08dd7a4d74 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -28,7 +28,7 @@ from ..framework import Variable, OpProtoHolder, in_dygraph_mode
 from ..dygraph import base
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
-from .tensor import concat, assign, fill_constant
+from .tensor import concat, assign, fill_constant, zeros
 from . import utils
 from .. import unique_name
 from functools import reduce
@@ -124,6 +124,8 @@ __all__ = [
     'gather',
     'gather_nd',
     'scatter',
+    'scatter_nd_add',
+    'scatter_nd',
     'sequence_scatter',
     'random_crop',
     'mean_iou',
@@ -8686,6 +8688,127 @@ def scatter(input, index, updates, name=None, overwrite=True):
     return out
 
 
+def scatter_nd_add(ref, index, updates, name=None):
+    """
+    **Scatter_nd_add Layer**
+
+    Output is obtained by applying sparse addition to a single value
+    or slice in a Variable. :attr:`ref` is a Tensor with rank :math:`R`
+    and :attr:`index` is a Tensor with rank :math:`K` . Thus, :attr:`index`
+    has shape :math:`[i_0, i_1, ..., i_{K-2}, Q]` where :math:`Q \leq R` . :attr:`updates`
+    is a Tensor with rank :math:`K - 1 + R - Q` and its
+    shape is :math:`index.shape[:-1] + ref.shape[index.shape[-1]:]` .
+    For each index vector :math:`[i_0, i_1, ..., i_{K-2}]` in :attr:`index` ,
+    the corresponding :attr:`updates` slice is added to the :attr:`ref` slice
+    selected by the last dimension of :attr:`index` .
+
+    .. code-block:: text
+
+        Given:
+
+        * Case 1:
+            ref = [0, 1, 2, 3, 4, 5]
+            index = [[1], [2], [3], [1]]
+            updates = [9, 10, 11, 12]
+
+          we get:
+
+            output = [0, 22, 12, 14, 4, 5]
+
+        * Case 2:
+            ref = [[65, 17], [-14, -25]]
+            index = [[], []]
+            updates = [[[-1, -2], [1, 2]],
+                       [[3, 4], [-3, -4]]]
+            ref.shape = (2, 2)
+            index.shape = (2, 0)
+            updates.shape = (2, 2, 2)
+
+          we get:
+
+            output = [[67, 19], [-16, -27]]
+
+    Args:
+        ref (Variable): The ref input.
+        index (Variable): The index input with rank > 1 and index.shape[-1] <= ref.rank.
+                          Its dtype should be int32 or int64 as it is used as indexes.
+        updates (Variable): The updated value of scatter_nd_add op, and it must have the same
+                            type as ref. It must have the shape
+                            index.shape[:-1] + ref.shape[index.shape[-1]:].
+        name (str|None): The output variable name. Default None.
+
+    Returns:
+        output (Variable): The output is a tensor with the same shape and type as ref.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            ref = fluid.layers.data(name='ref', shape=[3, 5, 9, 10], dtype='float32', append_batch_size=False)
+            index = fluid.layers.data(name='index', shape=[3, 2], dtype='int32', append_batch_size=False)
+            updates = fluid.layers.data(name='update', shape=[3, 9, 10], dtype='float32', append_batch_size=False)
+
+            output = fluid.layers.scatter_nd_add(ref, index, updates)
+    """
+    if ref.dtype != updates.dtype:
+        raise ValueError("ref and updates must have the same data type.")
+
+    helper = LayerHelper('scatter_nd_add', **locals())
+    dtype = helper.input_dtype()
+    if name is None:
+        output = helper.create_variable_for_type_inference(dtype)
+    else:
+        output = helper.create_variable(
+            name=name, dtype=dtype, persistable=False)
+    helper.append_op(
+        type="scatter_nd_add",
+        inputs={"X": ref,
+                "Index": index,
+                "Updates": updates},
+        outputs={"Out": output})
+    return output
+
+
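Reviewer note: a runnable end-to-end sketch of the layer above (not part of this patch), reproducing Case 1 through an executor; the variable names are arbitrary:

```python
import numpy as np
import paddle.fluid as fluid

ref = fluid.layers.data(name='ref', shape=[6], dtype='float32', append_batch_size=False)
index = fluid.layers.data(name='index', shape=[4, 1], dtype='int32', append_batch_size=False)
updates = fluid.layers.data(name='update', shape=[4], dtype='float32', append_batch_size=False)
out = fluid.layers.scatter_nd_add(ref, index, updates)

exe = fluid.Executor(fluid.CPUPlace())
result, = exe.run(fluid.default_main_program(),
                  feed={'ref': np.arange(6).astype('float32'),
                        'index': np.array([[1], [2], [3], [1]], dtype='int32'),
                        'update': np.array([9, 10, 11, 12], dtype='float32')},
                  fetch_list=[out])
print(result)  # [ 0. 22. 12. 14.  4.  5.]
```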
+def scatter_nd(index, updates, shape, name=None):
+    """
+    **Scatter_nd Layer**
+
+    Output is obtained by scattering the :attr:`updates` in a new tensor according
+    to :attr:`index` . This op is similar to :code:`scatter_nd_add`, except that
+    the tensor of shape :attr:`shape` is zero-initialized first. Accordingly,
+    :code:`scatter_nd(index, updates, shape)` is equivalent to
+    :code:`scatter_nd_add(fluid.layers.zeros(shape, updates.dtype), index, updates)` .
+    If :attr:`index` has repeated elements, the corresponding updates are accumulated.
+    Because floating-point addition is order-dependent, a different order of the
+    repeated elements in :attr:`index` may produce slightly different results.
+    See :code:`scatter_nd_add` for the detailed calculation. This op is the
+    inverse of the :code:`gather_nd` op.
+
+    Args:
+        index (Variable): The index input with rank > 1 and index.shape[-1] <= len(shape).
+                          Its dtype should be int32 or int64 as it is used as indexes.
+        updates (Variable): The updated value of scatter_nd op.
+                            It must have the shape index.shape[:-1] + shape[index.shape[-1]:].
+        shape(tuple|list): Shape of output tensor.
+        name (str|None): The output variable name. Default None.
+
+    Returns:
+        output (Variable): The output is a tensor with the same type as :attr:`updates` .
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            index = fluid.layers.data(name='index', shape=[3, 2], dtype='int64', append_batch_size=False)
+            updates = fluid.layers.data(name='update', shape=[3, 9, 10], dtype='float32', append_batch_size=False)
+            shape = [3, 5, 9, 10]
+
+            output = fluid.layers.scatter_nd(index, updates, shape)
+    """
+    return scatter_nd_add(zeros(shape, updates.dtype), index, updates, name)
+
+
 def sequence_scatter(input, index, updates, name=None):
     """
     **Sequence Scatter Layer**
diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py
index 3264b2aff44ebf6bab9e2cc0c19bd904a082b1e5..357b2dc060742e545ce896b60373fe39badf4e3d 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py
@@ -158,7 +158,7 @@ class TestGatherNdOpRaise(OpTest):
                 output = fluid.layers.gather_nd(x, index)
             except Exception as e:
                 t = \
-                "Input(Index).shape[-1] <= Input(X).rank"
+                "Input(Index).shape[-1] should be no greater than Input(X).rank"
                 if t in str(e):
                     raise IndexError
 
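Reviewer note: the zeros-then-add equivalence that `scatter_nd` relies on can be sketched in plain numpy (illustration only; `scatter_nd_ref` is a hypothetical helper, not part of this patch):

```python
import numpy as np

def scatter_nd_ref(index, updates, shape):
    # scatter_nd(index, updates, shape) == scatter_nd_add(zeros(shape), index, updates)
    out = np.zeros(shape, dtype=updates.dtype)
    k = index.shape[-1]
    flat_idx = index.reshape(-1, k)
    flat_upd = updates.reshape((flat_idx.shape[0],) + tuple(shape[k:]))
    np.add.at(out, tuple(flat_idx.T), flat_upd)  # repeated indices accumulate
    return out

print(scatter_nd_ref(np.array([[1], [2], [1]]), np.array([1., 2., 3.]), [4]))
# [0. 4. 2. 0.]
```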
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec9bfa43005cc11af5ff58e10d171ec69125bd9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
@@ -0,0 +1,291 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid as fluid
+
+
+def numpy_scatter_nd(ref, index, updates, fun):
+    """Reference implementation: apply `fun` (e.g. add) slice by slice."""
+    ref_shape = ref.shape
+    index_shape = index.shape
+
+    end_size = index_shape[-1]
+    # number of index rows
+    remain_numel = 1
+    for i in range(len(index_shape) - 1):
+        remain_numel *= index_shape[i]
+
+    # number of elements in each updated slice
+    slice_size = 1
+    for i in range(end_size, len(ref_shape)):
+        slice_size *= ref_shape[i]
+
+    flat_index = index.reshape([remain_numel] + list(index_shape[-1:]))
+    flat_updates = updates.reshape((remain_numel, slice_size))
+    flat_output = ref.reshape(list(ref_shape[:end_size]) + [slice_size])
+
+    for i_up, i_out in enumerate(flat_index):
+        i_out = tuple(i_out)
+        flat_output[i_out] = fun(flat_output[i_out], flat_updates[i_up])
+    return flat_output.reshape(ref.shape)
+
+
+def numpy_scatter_nd_add(ref, index, updates):
+    return numpy_scatter_nd(ref, index, updates, lambda x, y: x + y)
+
+
+def judge_update_shape(ref, index):
+    """update.shape = index.shape[:-1] + ref.shape[index.shape[-1]:]"""
+    ref_shape = ref.shape
+    index_shape = index.shape
+    update_shape = []
+    for i in range(len(index_shape) - 1):
+        update_shape.append(index_shape[i])
+    for i in range(index_shape[-1], len(ref_shape), 1):
+        update_shape.append(ref_shape[i])
+    return update_shape
+
+
+class TestScatterNdAddSimpleOp(OpTest):
+    """
+    A simple example
+    """
+
+    def setUp(self):
+        self.op_type = "scatter_nd_add"
+        ref_np = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]).astype("float32")
+        index_np = np.array([[1], [2], [3], [5], [1]]).astype("int32")
+        updates_np = np.array([9, 10, 11, 12, 13]).astype("float32")
+        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+        # expect_np = [0. 23. 12. 14. 4. 17. 6. 7. 8.]
+
+        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': expect_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+class TestScatterNdAddWithEmptyIndex(OpTest):
+    """
+    Index has empty element
+    """
+
+    def setUp(self):
+        self.op_type = "scatter_nd_add"
+        ref_np = np.array([[65, 17], [-14, -25]]).astype("float32")
+        index_np = np.array([[], []]).astype("int32")
+        updates_np = np.array([[[-1, -2], [1, 2]],
+                               [[3, 4], [-3, -4]]]).astype("float32")
+
+        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+        # expect_np = [[67, 19], [-16, -27]]
+
+        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': expect_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', in_place=True)
+
+
+class TestScatterNdAddWithHighRankSame(OpTest):
+    """
+    Both Index and X have high rank, and Rank(Index) = Rank(X)
+    """
+
+    def setUp(self):
+        self.op_type = "scatter_nd_add"
+        shape = (10, 9, 8, 1, 15)
+        ref_np = np.random.rand(*shape).astype("float32")
+        index_np = np.vstack(
+            [np.random.randint(
+                0, s, size=150) for s in shape]).T.astype("int32")
+        update_shape = judge_update_shape(ref_np, index_np)
+        updates_np = np.random.rand(*update_shape).astype("float32")
+        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+
+        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': expect_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+class TestScatterNdAddWithHighRankDiff(OpTest):
+    """
+    Both Index and X have high rank, and Rank(Index) < Rank(X)
+    """
+
+    def setUp(self):
+        self.op_type = "scatter_nd_add"
+        shape = (10, 9, 8, 1, 15)
+        ref_np = np.random.rand(*shape).astype("double")
+        index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
+        index_np = index.reshape([10, 5, 10, 5]).astype("int64")
+        update_shape = judge_update_shape(ref_np, index_np)
+        updates_np = np.random.rand(*update_shape).astype("double")
+        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+
+        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': expect_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+# Test Python API
+class TestScatterNdOpAPI(OpTest):
+    """
+    Test the scatter_nd_add and scatter_nd APIs
+    """
+
+    def testcase1(self):
+        ref1 = fluid.layers.data(
+            name='ref1',
+            shape=[10, 9, 8, 1, 3],
+            dtype='float32',
+            append_batch_size=False)
+        index1 = fluid.layers.data(
+            name='index1',
+            shape=[5, 5, 8, 5],
+            dtype='int32',
+            append_batch_size=False)
+        updates1 = fluid.layers.data(
+            name='update1',
+            shape=[5, 5, 8],
+            dtype='float32',
+            append_batch_size=False)
+        output1 = fluid.layers.scatter_nd_add(ref1, index1, updates1)
+
+    def testcase2(self):
+        ref2 = fluid.layers.data(
+            name='ref2',
+            shape=[10, 9, 8, 1, 3],
+            dtype='double',
+            append_batch_size=False)
+        index2 = fluid.layers.data(
+            name='index2',
+            shape=[5, 8, 5],
+            dtype='int32',
+            append_batch_size=False)
+        updates2 = fluid.layers.data(
+            name='update2',
+            shape=[5, 8],
+            dtype='double',
+            append_batch_size=False)
+        output2 = fluid.layers.scatter_nd_add(
+            ref2, index2, updates2, name="scatter_nd_add")
+
+    def testcase3(self):
+        shape3 = [10, 9, 8, 1, 3]
+        index3 = fluid.layers.data(
+            name='index3',
+            shape=[5, 5, 8, 5],
+            dtype='int32',
+            append_batch_size=False)
+        updates3 = fluid.layers.data(
+            name='update3',
+            shape=[5, 5, 8],
+            dtype='float32',
+            append_batch_size=False)
+        output3 = fluid.layers.scatter_nd(index3, updates3, shape3)
+
+    def testcase4(self):
+        shape4 = [10, 9, 8, 1, 3]
+        index4 = fluid.layers.data(
+            name='index4',
+            shape=[5, 5, 8, 5],
+            dtype='int32',
+            append_batch_size=False)
+        updates4 = fluid.layers.data(
+            name='update4',
+            shape=[5, 5, 8],
+            dtype='double',
+            append_batch_size=False)
+        output4 = fluid.layers.scatter_nd(
+            index4, updates4, shape4, name='scatter_nd')
+
+
+# Test raising errors
+class TestScatterNdOpRaise(OpTest):
+    def test_check_raise(self):
+        def check_raise_is_test():
+            try:
+                ref5 = fluid.layers.data(
+                    name='ref5', shape=[3, 4, 5], dtype='float32')
+                index5 = fluid.layers.data(
+                    name='index5', shape=[2, 10], dtype='int32')
+                updates5 = fluid.layers.data(
+                    name='updates5', shape=[2, 10], dtype='float32')
+                output5 = fluid.layers.scatter_nd_add(ref5, index5, updates5)
+            except Exception as e:
+                t = \
+                "Input(Index).shape[-1] should be no greater than Input(X).rank"
+                if t in str(e):
+                    raise IndexError
+
+        self.assertRaises(IndexError, check_raise_is_test)
+
+    def test_check_raise2(self):
+        with self.assertRaises(ValueError):
+            ref6 = fluid.layers.data(
+                name='ref6',
+                shape=[10, 9, 8, 1, 3],
+                dtype='double',
+                append_batch_size=False)
+            index6 = fluid.layers.data(
+                name='index6',
+                shape=[5, 8, 5],
+                dtype='int32',
+                append_batch_size=False)
+            updates6 = fluid.layers.data(
+                name='update6',
+                shape=[5, 8],
+                dtype='float32',
+                append_batch_size=False)
+            output6 = fluid.layers.scatter_nd_add(ref6, index6, updates6)
+
+    def test_check_raise3(self):
+        def check_raise_is_test():
+            try:
+                shape = [3, 4, 5]
+                index7 = fluid.layers.data(
+                    name='index7', shape=[2, 1], dtype='int32')
+                updates7 = fluid.layers.data(
+                    name='updates7', shape=[2, 4, 5, 20], dtype='float32')
+                output7 = fluid.layers.scatter_nd(index7, updates7, shape)
+            except Exception as e:
+                t = \
+                "Updates has wrong shape"
+                if t in str(e):
+                    raise ValueError
+
+        self.assertRaises(ValueError, check_raise_is_test)
+
+
+if __name__ == "__main__":
+    unittest.main()
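Reviewer note: finally, a quick sanity check (not part of this patch) of the empty-index corner case, i.e. Case 2 from the op documentation, against the `numpy_scatter_nd_add` reference defined in the test file above:

```python
import numpy as np
# assumes numpy_scatter_nd_add from test_scatter_nd_op.py is importable
ref = np.array([[65, 17], [-14, -25]], dtype='float32')
index = np.zeros((2, 0), dtype='int32')  # index.shape[-1] == 0: each update slice covers all of ref
updates = np.array([[[-1, -2], [1, 2]],
                    [[3, 4], [-3, -4]]], dtype='float32')
print(numpy_scatter_nd_add(ref.copy(), index, updates))
# [[ 67.  19.]
#  [-16. -27.]]
```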