提交 6c4a0850 编写于 作者: Y yelrose

fixed graph_wrapper for dtype inference

上级 03cb3621
...@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
num_edges): num_edges):
"""Recv message from given msg to dst nodes. """Recv message from given msg to dst nodes.
""" """
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype="float32")
if reduce_function == "sum": if reduce_function == "sum":
if isinstance(msg, dict): if isinstance(msg, dict):
raise TypeError("The message for build-in function" raise TypeError("The message for build-in function"
...@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try: try:
out_dim = msg.shape[-1] out_dim = msg.shape[-1]
init_output = fluid.layers.fill_constant( init_output = fluid.layers.fill_constant(
shape=[num_nodes, out_dim], value=0, dtype="float32") shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
init_output.stop_gradient = False init_output.stop_gradient = False
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
msg = msg * empty_msg_flag msg = msg * empty_msg_flag
output = paddle_helper.scatter_add(init_output, dst, msg) output = paddle_helper.scatter_add(init_output, dst, msg)
return output return output
...@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
bucketed_msg = op.nested_lod_reset(msg, bucketing_index) bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
output = reduce_function(bucketed_msg) output = reduce_function(bucketed_msg)
output_dim = output.shape[-1] output_dim = output.shape[-1]
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
output = output * empty_msg_flag output = output * empty_msg_flag
init_output = fluid.layers.fill_constant( init_output = fluid.layers.fill_constant(
shape=[num_nodes, output_dim], value=0, dtype="float32") shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
init_output.stop_gradient = True init_output.stop_gradient = True
final_output = fluid.layers.scatter(init_output, uniq_dst, output) final_output = fluid.layers.scatter(init_output, uniq_dst, output)
return final_output return final_output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册