未验证 提交 86d8659c 编写于 作者: W whs 提交者: GitHub

Add python wrapper for gather op. (#11033)

* Add python wrapper for gather op.

* Add unitest for 'rank==1' and fix comments.

* Fix comments.
上级 28dc9ba3
...@@ -1009,3 +1009,9 @@ ____ ...@@ -1009,3 +1009,9 @@ ____
.. autofunction:: paddle.fluid.layers.upsampling_bilinear2d .. autofunction:: paddle.fluid.layers.upsampling_bilinear2d
:noindex: :noindex:
gather
____
.. autofunction:: paddle.fluid.layers.gather
:noindex:
...@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel {
auto index_dims = ctx->GetInputDim("Index"); auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1); PADDLE_ENFORCE(index_dims.size() == 1);
int batch_size = ctx->GetInputDim("Index")[0]; int batch_size = ctx->GetInputDim("Index")[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx->GetInputDim("X")); framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims); ctx->SetOutputDim("Out", output_dims);
......
...@@ -82,6 +82,7 @@ __all__ = [ ...@@ -82,6 +82,7 @@ __all__ = [
'roi_pool', 'roi_pool',
'dice_loss', 'dice_loss',
'upsampling_bilinear2d', 'upsampling_bilinear2d',
'gather',
'random_crop', 'random_crop',
] ]
...@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): ...@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
def dice_loss(input, label, epsilon=0.00001): def dice_loss(input, label, epsilon=0.00001):
""" """
**Dice loss Layer**
Dice loss for comparing the similarity of two batch of data, Dice loss for comparing the similarity of two batch of data,
usually is used for binary image segmentation i.e. labels are binary. usually is used for binary image segmentation i.e. labels are binary.
The dice loss can be defined as below equation: The dice loss can be defined as below equation:
...@@ -3999,6 +3999,55 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): ...@@ -3999,6 +3999,55 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
return out return out
def gather(input, index):
"""
Output is obtained by gathering entries of the outer-most dimension
of X indexed by `index` and concatenate them together.
.. math::
Out = X[Index]
.. code-block:: text
Given:
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [1, 2]
Then:
Out = [[3, 4],
[5, 6]]
Args:
input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1.
Returns:
output (Variable): The output is a tensor with the same rank as input.
Examples:
.. code-block:: python
output = fluid.layers.gather(x, index)
"""
helper = LayerHelper('gather', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
helper.append_op(
type="gather",
inputs={"X": input,
"Index": index},
outputs={"Out": out})
return out
def random_crop(input, shape, seed=1): def random_crop(input, shape, seed=1):
helper = LayerHelper("random_crop", **locals()) helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
......
...@@ -20,8 +20,9 @@ from op_test import OpTest ...@@ -20,8 +20,9 @@ from op_test import OpTest
class TestGatherOp(OpTest): class TestGatherOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "gather" self.op_type = "gather"
xnp = np.random.random((10, 20)).astype("float32") self.config()
self.inputs = {'X': xnp, 'Index': np.array([1, 3, 5]).astype("int32")} xnp = np.random.random(self.x_shape).astype("float32")
self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self): def test_check_output(self):
...@@ -30,6 +31,16 @@ class TestGatherOp(OpTest): ...@@ -30,6 +31,16 @@ class TestGatherOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def config(self):
self.x_shape = (10, 20)
self.index = [1, 3, 5]
class TestCase1(TestGatherOp):
def config(self):
self.x_shape = (10)
self.index = [1, 3, 5]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册