From 693083a461c8d60dc0554edee828d41d3eaeb1be Mon Sep 17 00:00:00 2001 From: Chengmo Date: Wed, 1 Jul 2020 19:34:04 +0800 Subject: [PATCH] add index sample (#25260) * test=release/1.8, add index sample --- python/paddle/fluid/layers/nn.py | 82 +++++++++++++++++++ .../tests/unittests/test_index_sample_op.py | 28 +++++++ 2 files changed, 110 insertions(+) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 05cdacb0cd..d76d4e6bc0 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -212,6 +212,7 @@ __all__ = [ 'flip', 'roll', 'log_softmax', + 'index_sample', ] @@ -16555,3 +16556,84 @@ def log_softmax(input, axis=None, dtype=None, name=None): type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log}) return outs_log + + +def index_sample(x, index): + """ + **IndexSample Layer** + IndexSample OP returns the element of the specified location of X, + and the location is specified by Index. + + .. code-block:: text + + Args: + x (Variable): The source input tensor with 2-D shape. Supported data type is + int32, int64, float32, float64. + index (Variable): The index input tensor with 2-D shape, first dimension should be same with X. + Data type is int32 or int64. + + Returns: + Variable: A tensor with the same shape as `index` . + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + data = np.array([[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]]).astype('float32') + + data_index = np.array([[0, 1, 2], + [1, 2, 3], + [0, 0, 0]]).astype('int32') + + target_data = np.array([[100, 200, 300, 400], + [500, 600, 700, 800], + [900, 1000, 1100, 1200]]).astype('int32') + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(data) + index = fluid.dygraph.to_variable(data_index) + target = fluid.dygraph.to_variable(target_data) + + out_z1 = fluid.layers.index_sample(x, index) + print(out_z1.numpy()) + #[[1. 2. 3.] + # [6. 7. 8.] + # [9. 9. 9.]] + + # Use the index of the maximum value by topk op + # get the value of the element of the corresponding index in other tensors + top_value, top_index = fluid.layers.topk(x, k=2) + out_z2 = fluid.layers.index_sample(target, top_index) + print(top_value.numpy()) + #[[ 4. 3.] + # [ 8. 7.] + # [12. 11.]] + + print(top_index.numpy()) + #[[3 2] + # [3 2] + # [3 2]] + + print(out_z2.numpy()) + #[[ 400 300] + # [ 800 700] + # [1200 1100]] + """ + helper = LayerHelper("index_sample", **locals()) + + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'fluid.layers.index_sample') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], + 'fluid.layers.index_sample') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='index_sample', + inputs={'X': x, + 'Index': index}, + outputs={'Out': out}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index 750084e881..dc851fe3cf 100644 --- a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -96,5 +96,33 @@ class TestCase4(TestIndexSampleOp): self.index_type = "int64" +class TestIndexSampleShape(unittest.TestCase): + def test_shape(self): + import paddle.fluid as fluid + import paddle + + # create x value + x_shape = (2, 5) + x_type = "float64" + x_np = np.random.random(x_shape).astype(x_type) + + # create index value + index_shape = (2, 3) + index_type = "int32" + index_np = np.random.randint( + low=0, high=x_shape[1], size=index_shape).astype(index_type) + + x = fluid.data(name='x', shape=[-1, 5], dtype='float64') + index = fluid.data(name='index', shape=[-1, 3], dtype='int32') + output = fluid.layers.index_sample(x=x, index=index) + + place = fluid.CPUPlace() + exe = fluid.Executor(place=place) + exe.run(fluid.default_startup_program()) + + feed = {'x': x_np, 'index': index_np} + res = exe.run(feed=feed, fetch_list=[output]) + + if __name__ == "__main__": unittest.main() -- GitLab