From 38745a626c3f937bec836c92c98a76deadf0a03d Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 13 Mar 2018 19:10:28 -0700 Subject: [PATCH] "exported scatter to python" --- python/paddle/fluid/layers/tensor.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index da066c34bde..3ce0adbb77c 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -32,6 +32,7 @@ __all__ = [ 'fill_constant', 'ones', 'zeros', + 'scatter', ] @@ -364,6 +365,27 @@ def zeros(shape, dtype, force_cpu=False): return fill_constant(value=0.0, **locals()) +def scatter(input, index, updates): + """ + Scatter input through the index + Out[Index] = Ref[Index] + Updates + + Args: + input(variable): The Tensor/LoDTensor to be scatterd. + index(variable): The index input of scatter op where Ref will be updated. + updates(variable): The updated value to be added to the output. + """ + helper = LayerHelper("scatter", **locals()) + out = helper.create_tmp_variable(dtype=dtype) + helper.append_op( + type='scatter', + inputs={'Ref': input, + 'Index': index, + 'Updates': updates}, + outputs={'Out': [out]}) + return out + + def save(x, file_path, overwrite=True): """ Saves a variable as a file. -- GitLab