提交 6c4a0850 编写于 作者: Y yelrose

fixed graph_wrapper for dtype inference

上级 03cb3621
......@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
num_edges):
"""Recv message from given msg to dst nodes.
"""
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype="float32")
if reduce_function == "sum":
if isinstance(msg, dict):
raise TypeError("The message for build-in function"
......@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try:
out_dim = msg.shape[-1]
init_output = fluid.layers.fill_constant(
shape=[num_nodes, out_dim], value=0, dtype="float32")
shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
init_output.stop_gradient = False
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
msg = msg * empty_msg_flag
output = paddle_helper.scatter_add(init_output, dst, msg)
return output
......@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
output = reduce_function(bucketed_msg)
output_dim = output.shape[-1]
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
output = output * empty_msg_flag
init_output = fluid.layers.fill_constant(
shape=[num_nodes, output_dim], value=0, dtype="float32")
shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
init_output.stop_gradient = True
final_output = fluid.layers.scatter(init_output, uniq_dst, output)
return final_output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册