Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
16c41716
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16c41716
编写于
10月 25, 2019
作者:
L
liweibin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update pgl
上级
0bdc0da9
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
165 addition
and
220 deletion
+165
-220
pgl/__init__.py
pgl/__init__.py
+1
-1
pgl/contrib/heter_graph.py
pgl/contrib/heter_graph.py
+76
-91
pgl/contrib/heter_graph_wrapper.py
pgl/contrib/heter_graph_wrapper.py
+38
-124
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+1
-2
pgl/layers/__init__.py
pgl/layers/__init__.py
+2
-0
pgl/sample.py
pgl/sample.py
+46
-1
pgl/utils/paddle_helper.py
pgl/utils/paddle_helper.py
+1
-1
未找到文件。
pgl/__init__.py
浏览文件 @
16c41716
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Generate pgl apis
"""Generate pgl apis
"""
"""
__version__
=
"1.0.
0
"
__version__
=
"1.0.
1
"
from
pgl
import
layers
from
pgl
import
layers
from
pgl
import
graph_wrapper
from
pgl
import
graph_wrapper
from
pgl
import
graph
from
pgl
import
graph
...
...
pgl/contrib/heter_graph.py
浏览文件 @
16c41716
...
@@ -14,11 +14,12 @@
...
@@ -14,11 +14,12 @@
"""
"""
This package implement Heterogeneous Graph structure for handling Heterogeneous graph data.
This package implement Heterogeneous Graph structure for handling Heterogeneous graph data.
"""
"""
import
time
import
numpy
as
np
import
numpy
as
np
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
import
pgl.graph_kernel
as
graph_kernel
import
pgl.graph_kernel
as
graph_kernel
from
pgl
import
g
raph
from
pgl
.graph
import
G
raph
__all__
=
[
'HeterGraph'
]
__all__
=
[
'HeterGraph'
]
...
@@ -31,123 +32,111 @@ def _hide_num_nodes(shape):
...
@@ -31,123 +32,111 @@ def _hide_num_nodes(shape):
return
shape
return
shape
class
HeterGraph
(
object
):
class
NodeGraph
(
Graph
):
"""Implementation of graph structure in pgl
"""Implementation of a graph that has multple node types.
This is a simple implementation of heterogeneous graph structure in pgl
Args:
Args:
num_nodes_every_type: dict, number of nodes for every node type
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
edges_every_type: dict, every element is a list of (u, v) tuples.
def
__init__
(
self
,
num_nodes
,
edges
,
node_types
=
None
,
node_feat
=
None
,
edge_feat
=
None
):
super
(
NodeGraph
,
self
).
__init__
(
num_nodes
,
edges
,
node_feat
,
edge_feat
)
if
isinstance
(
node_types
,
list
):
self
.
_node_types
=
np
.
array
(
node_types
,
dtype
=
object
)[:,
1
]
else
:
self
.
_node_types
=
node_types
node_feat_every_type: features for every node type.
class
HeterGraph
(
object
):
"""Implementation of heterogeneous graph structure in pgl
This is a simple implementation of heterogeneous graph structure in pgl.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
Examples:
Examples:
.. code-block:: python
.. code-block:: python
import numpy as np
import numpy as np
num_nodes_every_type = {'type1':3,'type2':4, 'type3':2}
num_nodes = 4
edges_every_type = {
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
('type1','type2', 'edges_type1'): [(0,1), (1,2)],
edges = {
('type1', 'type3', 'edges_type2'): [(1,2), (3,1)],
'edges_type1': [(0,1), (3,2)],
}
'edges_type2': [(1,2), (3,1)],
node_feat_every_type = {
'type1': {'features1': np.random.randn(3, 4),
'features2': np.random.randn(3, 4)},
'type2': {'features3': np.random.randn(4, 4)},
'type3': {'features1': np.random.randn(2, 4),
'features2': np.random.randn(2, 4)}
}
}
edges_feat_every_type = {
node_feat = {'feature': np.random.randn(4, 16)}
('type1','type2','edges_type1'): {'h': np.random.randn(2, 4)},
edges_feat = {
('type1', 'type3', 'edges_type2'): {'h':np.random.randn(2, 4)},
'edges_type1': {'h': np.random.randn(2, 16)},
'edges_type2': {'h': np.random.randn(2, 16)},
}
}
g = heter_graph.HeterGraph(
g = heter_graph.HeterGraph(
num_nodes
_every_type=num_nodes_every_type,
num_nodes
=num_nodes,
edges
_every_type=edges_every_type
,
edges
=edges
,
node_
feat_every_type=node_feat_every_type
,
node_
types=node_types
,
edge_feat_every_type=edges_feat_every_type)
node_feat=node_feat,
edge_feat=edges_feat)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_nodes_every_type
,
num_nodes
,
edges_every_type
,
edges
,
node_feat_every_type
=
None
,
node_types
=
None
,
edge_feat_every_type
=
None
):
node_feat
=
None
,
edge_feat
=
None
):
self
.
_num_nodes_dict
=
num_nodes_every_type
self
.
_num_nodes
=
num_nodes
self
.
_edges_dict
=
edges_every_type
self
.
_edges_dict
=
edges
if
node_feat_every_type
is
not
None
:
self
.
_node_feat
=
node_feat_every_type
if
node_feat
is
not
None
:
self
.
_node_feat
=
node_feat
else
:
else
:
self
.
_node_feat
=
{}
self
.
_node_feat
=
{}
if
edge_feat
_every_type
is
not
None
:
if
edge_feat
is
not
None
:
self
.
_edge_feat
=
edge_feat
_every_type
self
.
_edge_feat
=
edge_feat
else
:
else
:
self
.
_edge_feat
=
{}
self
.
_edge_feat
=
{}
self
.
_multi_graph
=
{}
self
.
_multi_graph
=
{}
for
key
,
value
in
self
.
_edges_dict
.
items
():
for
key
,
value
in
self
.
_edges_dict
.
items
():
if
not
self
.
_node_feat
:
node_feat
=
None
else
:
node_feat
=
self
.
_node_feat
[
key
[
0
]]
if
not
self
.
_edge_feat
:
if
not
self
.
_edge_feat
:
edge_feat
=
None
edge_feat
=
None
else
:
else
:
edge_feat
=
self
.
_edge_feat
[
key
]
edge_feat
=
self
.
_edge_feat
[
key
]
self
.
_multi_graph
[
key
]
=
graph
.
Graph
(
self
.
_multi_graph
[
key
]
=
Node
Graph
(
num_nodes
=
self
.
_num_nodes
_dict
[
key
[
1
]]
,
num_nodes
=
self
.
_num_nodes
,
edges
=
value
,
edges
=
value
,
node_feat
=
node_feat
,
node_types
=
node_types
,
node_feat
=
self
.
_node_feat
,
edge_feat
=
edge_feat
)
edge_feat
=
edge_feat
)
@
property
def
num_nodes
(
self
):
"""Return the number of nodes.
"""
return
self
.
_num_nodes
def
__getitem__
(
self
,
edge_type
):
def
__getitem__
(
self
,
edge_type
):
"""__getitem__
"""__getitem__
"""
"""
return
self
.
_multi_graph
[
edge_type
]
return
self
.
_multi_graph
[
edge_type
]
def
meta_path_random_walk
(
self
,
start_node
,
edge_types
,
meta_path
,
max_depth
):
"""Meta path random walk sampling.
Args:
start_nodes: int, node to begin random walk.
edge_types: list, the edge types to be sampled.
meta_path: 'user-item-user'
max_depth: the max length of every walk.
"""
edges_type_list
=
[]
node_type_list
=
meta_path
.
split
(
'-'
)
for
i
in
range
(
1
,
len
(
node_type_list
)):
edges_type_list
.
append
(
(
node_type_list
[
i
-
1
],
node_type_list
[
i
],
edge_types
[
i
-
1
]))
no_neighbors_flag
=
False
walk
=
[
start_node
]
for
i
in
range
(
max_depth
):
for
e_type
in
edges_type_list
:
cur_node
=
[
walk
[
-
1
]]
nxt_node
=
self
.
_multi_graph
[
e_type
].
sample_successor
(
cur_node
,
max_degree
=
1
)
# list of np.array
nxt_node
=
nxt_node
[
0
]
if
len
(
nxt_node
)
==
0
:
no_neighbors_flag
=
True
break
else
:
walk
.
append
(
nxt_node
.
tolist
()[
0
])
if
no_neighbors_flag
:
break
return
walk
def
node_feat_info
(
self
):
def
node_feat_info
(
self
):
"""Return the information of node feature for HeterGraphWrapper.
"""Return the information of node feature for HeterGraphWrapper.
...
@@ -155,17 +144,13 @@ class HeterGraph(object):
...
@@ -155,17 +144,13 @@ class HeterGraph(object):
function is used to help constructing HeterGraphWrapper
function is used to help constructing HeterGraphWrapper
Return:
Return:
A
dict of
list of tuple (name, shape, dtype) for all given node feature.
A list of tuple (name, shape, dtype) for all given node feature.
"""
"""
node_feat_info
=
{}
node_feat_info
=
[]
for
node_type_name
,
feat_dict
in
self
.
_node_feat
.
items
():
for
feat_name
,
feat
in
self
.
_node_feat
.
items
():
tmp_node_feat_info
=
[]
node_feat_info
.
append
(
for
feat_name
,
feat
in
feat_dict
.
items
():
(
feat_name
,
_hide_num_nodes
(
feat
.
shape
),
feat
.
dtype
))
full_name
=
feat_name
tmp_node_feat_info
.
append
(
(
full_name
,
_hide_num_nodes
(
feat
.
shape
),
feat
.
dtype
))
node_feat_info
[
node_type_name
]
=
tmp_node_feat_info
return
node_feat_info
return
node_feat_info
...
@@ -193,7 +178,7 @@ class HeterGraph(object):
...
@@ -193,7 +178,7 @@ class HeterGraph(object):
"""Return the information of all edge types.
"""Return the information of all edge types.
Return:
Return:
A list of
tuple ('srctype','dsttype', 'edges_type') for
all edge types.
A list of all edge types.
"""
"""
edge_types_info
=
[]
edge_types_info
=
[]
...
...
pgl/contrib/heter_graph_wrapper.py
浏览文件 @
16c41716
...
@@ -26,6 +26,7 @@ from pgl.utils.logger import log
...
@@ -26,6 +26,7 @@ from pgl.utils.logger import log
from
pgl.graph_wrapper
import
GraphWrapper
from
pgl.graph_wrapper
import
GraphWrapper
ALL
=
"__ALL__"
ALL
=
"__ALL__"
__all__
=
[
"HeterGraphWrapper"
]
def
is_all
(
arg
):
def
is_all
(
arg
):
...
@@ -34,89 +35,6 @@ def is_all(arg):
...
@@ -34,89 +35,6 @@ def is_all(arg):
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
class
BipartiteGraphWrapper
(
GraphWrapper
):
"""Implement a bipartite graph wrapper that creates a graph data holders.
"""
def
__init__
(
self
,
name
,
place
,
node_feat
=
[],
edge_feat
=
[]):
super
(
BipartiteGraphWrapper
,
self
).
__init__
(
name
,
place
,
node_feat
,
edge_feat
)
def
send
(
self
,
message_func
,
src_nfeat_list
=
None
,
dst_nfeat_list
=
None
,
efeat_list
=
None
):
"""Send message from all src nodes to dst nodes.
The UDF message function should has the following format.
.. code-block:: python
def message_func(src_feat, dst_feat, edge_feat):
'''
Args:
src_feat: the node feat dict attached to the src nodes.
dst_feat: the node feat dict attached to the dst nodes.
edge_feat: the edge feat dict attached to the
corresponding (src, dst) edges.
Return:
It should return a tensor or a dictionary of tensor. And each tensor
should have a shape of (num_edges, dims).
'''
pass
Args:
message_func: UDF function.
src_nfeat_list: a list of tuple (name, tensor) for src nodes
dst_nfeat_list: a list of tuple (name, tensor) for dst nodes
efeat_list: a list of names or tuple (name, tensor)
Return:
A dictionary of tensor representing the message. Each of the values
in the dictionary has a shape (num_edges, dim) which should be collected
by :code:`recv` function.
"""
if
efeat_list
is
None
:
efeat_list
=
{}
if
src_nfeat_list
is
None
:
src_nfeat_list
=
{}
if
dst_nfeat_list
is
None
:
dst_nfeat_list
=
{}
src
,
dst
=
self
.
edges
src_feat
=
{}
for
feat
in
src_nfeat_list
:
if
isinstance
(
feat
,
str
):
src_feat
[
feat
]
=
self
.
node_feat
[
feat
]
else
:
name
,
tensor
=
feat
src_feat
[
name
]
=
tensor
dst_feat
=
{}
for
feat
in
dst_nfeat_list
:
if
isinstance
(
feat
,
str
):
dst_feat
[
feat
]
=
self
.
node_feat
[
feat
]
else
:
name
,
tensor
=
feat
dst_feat
[
name
]
=
tensor
efeat
=
{}
for
feat
in
efeat_list
:
if
isinstance
(
feat
,
str
):
efeat
[
feat
]
=
self
.
edge_feat
[
feat
]
else
:
name
,
tensor
=
feat
efeat
[
name
]
=
tensor
src_feat
=
op
.
read_rows
(
src_feat
,
src
)
dst_feat
=
op
.
read_rows
(
dst_feat
,
dst
)
msg
=
message_func
(
src_feat
,
dst_feat
,
efeat
)
return
msg
class
HeterGraphWrapper
(
object
):
class
HeterGraphWrapper
(
object
):
"""Implement a heterogeneous graph wrapper that creates a graph data holders
"""Implement a heterogeneous graph wrapper that creates a graph data holders
that attributes and features in the heterogeneous graph.
that attributes and features in the heterogeneous graph.
...
@@ -146,33 +64,30 @@ class HeterGraphWrapper(object):
...
@@ -146,33 +64,30 @@ class HeterGraphWrapper(object):
import paddle.fluid as fluid
import paddle.fluid as fluid
import numpy as np
import numpy as np
num_nodes_every_type = {'type1':3,'type2':4, 'type3':2}
from pgl.contrib import heter_graph
edges_every_type = {
from pgl.contrib import heter_graph_wrapper
('type1','type2', 'edges_type1'): [(0,1), (1,2)],
num_nodes = 4
('type1', 'type3', 'edges_type2'): [(1,2), (3,1)],
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
}
edges = {
node_feat_every_type = {
'edges_type1': [(0,1), (3,2)],
'type1': {'features1': np.random.randn(3, 4),
'edges_type2': [(1,2), (3,1)],
'features2': np.random.randn(3, 4)},
'type2': {'features3': np.random.randn(4, 4)},
'type3': {'features1': np.random.randn(2, 4),
'features2': np.random.randn(2, 4)}
}
}
edges_feat_every_type = {
node_feat = {'feature': np.random.randn(4, 16)}
('type1','type2','edges_type1'): {'h': np.random.randn(2, 4)},
edges_feat = {
('type1', 'type3', 'edges_type2'): {'h':np.random.randn(2, 4)},
'edges_type1': {'h': np.random.randn(2, 16)},
'edges_type2': {'h': np.random.randn(2, 16)},
}
}
g = heter_graph.HeterGraph(
g = heter_graph.HeterGraph(
num_nodes_every_type=num_nodes_every_type,
num_nodes=num_nodes,
edges_every_type=edges_every_type
,
edges=edges
,
node_feat_every_type=node_feat_every_type
,
node_types=node_types
,
edge_feat_every_type=edges_feat_every_type)
node_feat=node_feat,
edge_feat=edges_feat)
place = fluid.CPUPlace()
place = fluid.CPUPlace()
gw = pgl.
heter_graph_wrapper.HeterGraphWrapper(
gw =
heter_graph_wrapper.HeterGraphWrapper(
name='heter_graph',
name='heter_graph',
place = place,
place = place,
edge_types = g.edge_types_info(),
edge_types = g.edge_types_info(),
...
@@ -186,10 +101,9 @@ class HeterGraphWrapper(object):
...
@@ -186,10 +101,9 @@ class HeterGraphWrapper(object):
self
.
_edge_types
=
edge_types
self
.
_edge_types
=
edge_types
self
.
_multi_gw
=
{}
self
.
_multi_gw
=
{}
for
edge_type
in
self
.
_edge_types
:
for
edge_type
in
self
.
_edge_types
:
type_name
=
self
.
__data_name_prefix
+
'/'
+
edge_type
[
type_name
=
self
.
__data_name_prefix
+
'/'
+
edge_type
0
]
+
'_'
+
edge_type
[
1
]
if
node_feat
:
if
node_feat
:
n_feat
=
node_feat
[
edge_type
[
0
]]
n_feat
=
node_feat
else
:
else
:
n_feat
=
{}
n_feat
=
{}
...
@@ -198,7 +112,7 @@ class HeterGraphWrapper(object):
...
@@ -198,7 +112,7 @@ class HeterGraphWrapper(object):
else
:
else
:
e_feat
=
{}
e_feat
=
{}
self
.
_multi_gw
[
edge_type
]
=
Bipartite
GraphWrapper
(
self
.
_multi_gw
[
edge_type
]
=
GraphWrapper
(
name
=
type_name
,
name
=
type_name
,
place
=
self
.
_place
,
place
=
self
.
_place
,
node_feat
=
n_feat
,
node_feat
=
n_feat
,
...
...
pgl/graph_wrapper.py
浏览文件 @
16c41716
...
@@ -596,8 +596,7 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -596,8 +596,7 @@ class GraphWrapper(BaseGraphWrapper):
feed_dict
[
self
.
__data_name_prefix
+
'/edges_src'
]
=
src
feed_dict
[
self
.
__data_name_prefix
+
'/edges_src'
]
=
src
feed_dict
[
self
.
__data_name_prefix
+
'/edges_dst'
]
=
dst
feed_dict
[
self
.
__data_name_prefix
+
'/edges_dst'
]
=
dst
feed_dict
[
self
.
__data_name_prefix
+
'/num_nodes'
]
=
np
.
array
(
feed_dict
[
self
.
__data_name_prefix
+
'/num_nodes'
]
=
np
.
array
(
graph
.
num_nodes
)
graph
.
num_nodes
)
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
__data_name_prefix
+
'/node_ids'
]
=
graph
.
nodes
feed_dict
[
self
.
__data_name_prefix
+
'/node_ids'
]
=
graph
.
nodes
...
...
pgl/layers/__init__.py
浏览文件 @
16c41716
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
from
pgl.layers
import
conv
from
pgl.layers
import
conv
from
pgl.layers.conv
import
*
from
pgl.layers.conv
import
*
from
pgl.layers.set2set
import
*
__all__
=
[]
__all__
=
[]
__all__
+=
conv
.
__all__
__all__
+=
conv
.
__all__
__all__
+=
set2set
.
__all__
pgl/sample.py
浏览文件 @
16c41716
...
@@ -22,7 +22,10 @@ import pgl
...
@@ -22,7 +22,10 @@ import pgl
from
pgl.utils.logger
import
log
from
pgl.utils.logger
import
log
from
pgl
import
graph_kernel
from
pgl
import
graph_kernel
__all__
=
[
'graphsage_sample'
,
'node2vec_sample'
,
'deepwalk_sample'
]
__all__
=
[
'graphsage_sample'
,
'node2vec_sample'
,
'deepwalk_sample'
,
'metapath_randomwalk'
]
def
edge_hash
(
src
,
dst
):
def
edge_hash
(
src
,
dst
):
...
@@ -251,3 +254,45 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
...
@@ -251,3 +254,45 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
prev_nodes
,
prev_succs
=
cur_nodes
,
cur_succs
prev_nodes
,
prev_succs
=
cur_nodes
,
cur_succs
cur_nodes
=
nxt_nodes
cur_nodes
=
nxt_nodes
return
walk
return
walk
def
metapath_randomwalk
(
graph
,
start_node
,
metapath
,
walk_length
):
"""Implementation of metapath random walk in heterogeneous graph.
Args:
graph: instance of pgl heterogeneous graph
start_node: start node to generate walk
metapath: meta path for sample nodes.
e.g: "user-item-user"
walk_length: the walk length
Return:
a list of metapath walk, each element is a node id.
"""
np
.
random
.
seed
()
walk
=
[]
metapath
=
metapath
.
split
(
'-'
)
assert
metapath
[
0
]
==
metapath
[
-
1
],
"The last meta path item should be the same as the first one"
mp_len
=
len
(
metapath
)
-
1
walk
.
append
(
start_node
)
for
i
in
range
(
1
,
walk_length
):
cur_node
=
walk
[
-
1
]
succs
=
graph
.
successor
(
cur_node
)
if
succs
.
size
>
0
:
succs_node_types
=
graph
.
_node_types
[
succs
]
else
:
# no successor of current node
break
succs_nodes
=
succs
[
np
.
where
(
succs_node_types
==
metapath
[
i
%
mp_len
])[
0
]]
if
succs_nodes
.
size
>
0
:
walk
.
append
(
np
.
random
.
choice
(
succs_nodes
))
else
:
# no successor of such node type
break
return
walk
pgl/utils/paddle_helper.py
浏览文件 @
16c41716
...
@@ -226,7 +226,6 @@ def scatter_add(input, index, updates):
...
@@ -226,7 +226,6 @@ def scatter_add(input, index, updates):
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'add'
)
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'add'
)
return
output
return
output
def
scatter_max
(
input
,
index
,
updates
):
def
scatter_max
(
input
,
index
,
updates
):
"""Scatter max updates to input by given index.
"""Scatter max updates to input by given index.
...
@@ -245,3 +244,4 @@ def scatter_max(input, index, updates):
...
@@ -245,3 +244,4 @@ def scatter_max(input, index, updates):
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'max'
)
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'max'
)
return
output
return
output
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录