From 86d8659c8de7e91c066935e723da29f31ffd6364 Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 1 Jun 2018 15:14:08 +0800 Subject: [PATCH] Add python wrapper for gather op. (#11033) * Add python wrapper for gather op. * Add unitest for 'rank==1' and fix comments. * Fix comments. --- doc/fluid/api/layers.rst | 6 +++ paddle/fluid/operators/gather_op.cc | 1 - python/paddle/fluid/layers/nn.py | 51 ++++++++++++++++++- .../fluid/tests/unittests/test_gather_op.py | 15 +++++- 4 files changed, 69 insertions(+), 4 deletions(-) diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst index f53da4d194..dbb99d3c03 100644 --- a/doc/fluid/api/layers.rst +++ b/doc/fluid/api/layers.rst @@ -1009,3 +1009,9 @@ ____ .. autofunction:: paddle.fluid.layers.upsampling_bilinear2d :noindex: +gather +____ + +.. autofunction:: paddle.fluid.layers.gather + :noindex: + diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index e21b572589..aa3e05b83b 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel { auto index_dims = ctx->GetInputDim("Index"); PADDLE_ENFORCE(index_dims.size() == 1); int batch_size = ctx->GetInputDim("Index")[0]; - PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); framework::DDim output_dims(ctx->GetInputDim("X")); output_dims[0] = batch_size; ctx->SetOutputDim("Out", output_dims); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cb87653c47..56f5c6b4be 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -82,6 +82,7 @@ __all__ = [ 'roi_pool', 'dice_loss', 'upsampling_bilinear2d', + 'gather', 'random_crop', ] @@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): def dice_loss(input, label, epsilon=0.00001): """ - **Dice loss Layer** Dice loss for comparing the similarity of two batch of data, usually is used for binary image segmentation i.e. labels are binary. The dice loss can be defined as below equation: @@ -3999,6 +3999,55 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): return out +def gather(input, index): + """ + Output is obtained by gathering entries of the outer-most dimension + of X indexed by `index` and concatenate them together. + + .. math:: + + Out = X[Index] + + + .. code-block:: text + + + Given: + + X = [[1, 2], + [3, 4], + [5, 6]] + + Index = [1, 2] + + Then: + + Out = [[3, 4], + [5, 6]] + + Args: + input (Variable): The source input with rank>=1. + index (Variable): The index input with rank=1. + + Returns: + output (Variable): The output is a tensor with the same rank as input. + + Examples: + .. code-block:: python + + output = fluid.layers.gather(x, index) + """ + helper = LayerHelper('gather', **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="gather", + inputs={"X": input, + "Index": index}, + outputs={"Out": out}) + return out + + def random_crop(input, shape, seed=1): helper = LayerHelper("random_crop", **locals()) dtype = helper.input_dtype() diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 6fd043c27e..4ae9086480 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -20,8 +20,9 @@ from op_test import OpTest class TestGatherOp(OpTest): def setUp(self): self.op_type = "gather" - xnp = np.random.random((10, 20)).astype("float32") - self.inputs = {'X': xnp, 'Index': np.array([1, 3, 5]).astype("int32")} + self.config() + xnp = np.random.random(self.x_shape).astype("float32") + self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")} self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} def test_check_output(self): @@ -30,6 +31,16 @@ class TestGatherOp(OpTest): def test_check_grad(self): self.check_grad(['X'], 'Out') + def config(self): + self.x_shape = (10, 20) + self.index = [1, 3, 5] + + +class TestCase1(TestGatherOp): + def config(self): + self.x_shape = (10) + self.index = [1, 3, 5] + if __name__ == "__main__": unittest.main() -- GitLab