未验证 提交 3d11d018 编写于 作者: Q Qingsheng Li 提交者: GitHub

Fix scatter_op python API (#12742)

* Fix scatter_op python API and remove inconsistency between implementation and doc

* API spec change

* Change as review comment
上级 3ae97aab
......@@ -153,6 +153,7 @@ paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'n
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
......@@ -250,7 +251,6 @@ paddle.fluid.layers.logical_not ArgSpec(args=[], varargs='args', keywords='kwarg
paddle.fluid.layers.uniform_random_batch_size_like ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.gaussian_random ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.gaussian_random_batch_size_like ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.sum ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.slice ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.shape ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
......
......@@ -81,8 +81,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of updates op");
AddOutput("Out", "The output of add op");
AddInput("Updates", "The updated value of scatter op");
AddOutput("Out", "The output of scatter op");
AddComment(R"DOC(
Scatter Operator.
......@@ -90,7 +90,7 @@ This operator obtains output by updating the input on selected indices on the fi
$$
Out = X \\
Out[Ids] = X[Ids] + Updates
Out[Ids] = Updates
$$
)DOC");
......
......@@ -34,9 +34,9 @@ class ScatterOpKernel : public framework::OpKernel<T> {
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
// In place output: Out = X, Out[Ids] += Updates
// In place output: Out = X, Out[Ids] = Updates
framework::TensorCopySync(*X, ctx.GetPlace(), Out);
// Apply ScatterUpdate: Out[index] += Updates[:]
// Apply ScatterUpdate: Out[index] = Updates[:]
ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
}
};
......@@ -55,7 +55,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Ids]
// Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
};
......
......@@ -94,6 +94,7 @@ __all__ = [
'image_resize_short',
'resize_bilinear',
'gather',
'scatter',
'random_crop',
'mean_iou',
'relu',
......@@ -5036,6 +5037,47 @@ def gather(input, index):
return out
def scatter(input, index, updates, name=None):
"""
**Scatter Layer**
Output is obtained by updating the input on selected indices on the first
axis.
.. math::
Out = X
Out[Ids] = Updates
Args:
input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1. Its dtype should be
int32 or int64 as it is used as indexes.
updates (Variable): The updated value of scatter op.
name (str|None): The output variable name. Default None.
Returns:
output (Variable): The output is a tensor with the same shape as input.
Examples:
.. code-block:: python
output = fluid.layers.scatter(input, index, updates)
"""
helper = LayerHelper('scatter', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
helper.append_op(
type="scatter",
inputs={"X": input,
"Ids": index,
"Updates": updates},
outputs={"Out": out})
return out
@templatedoc()
def random_crop(x, shape, seed=None):
"""
......
......@@ -65,7 +65,6 @@ __all__ = [
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'scatter',
'sum',
'slice',
'shape',
......
......@@ -347,6 +347,25 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(loss)
print(str(program))
def test_scatter(self):
program = Program()
with program_guard(program):
x = layers.data(
name='x',
shape=[3, 3],
append_batch_size=False,
dtype='float32')
idx = layers.data(
name='idx', shape=[2], append_batch_size=False, dtype='int32')
updates = layers.data(
name='updates',
shape=[2, 3],
append_batch_size=False,
dtype='float32')
out = layers.scatter(input=x, index=idx, updates=updates)
self.assertIsNotNone(out)
print(str(program))
def test_lod_reset(self):
program = Program()
with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册