提交 38745a62 编写于 作者: D dzhwinter

"exported scatter to python"

上级 cbfd15f9
...@@ -32,6 +32,7 @@ __all__ = [ ...@@ -32,6 +32,7 @@ __all__ = [
'fill_constant', 'fill_constant',
'ones', 'ones',
'zeros', 'zeros',
'scatter',
] ]
...@@ -364,6 +365,27 @@ def zeros(shape, dtype, force_cpu=False): ...@@ -364,6 +365,27 @@ def zeros(shape, dtype, force_cpu=False):
return fill_constant(value=0.0, **locals()) return fill_constant(value=0.0, **locals())
def scatter(input, index, updates):
"""
Scatter input through the index
Out[Index] = Ref[Index] + Updates
Args:
input(variable): The Tensor/LoDTensor to be scatterd.
index(variable): The index input of scatter op where Ref will be updated.
updates(variable): The updated value to be added to the output.
"""
helper = LayerHelper("scatter", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='scatter',
inputs={'Ref': input,
'Index': index,
'Updates': updates},
outputs={'Out': [out]})
return out
def save(x, file_path, overwrite=True): def save(x, file_path, overwrite=True):
""" """
Saves a variable as a file. Saves a variable as a file.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册