diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py
index 9545a8210567e55dbadf37251ddf475a5ac9f5f9..6b9a54fb2f24516db2f09fcc4252935785567081 100644
--- a/pgl/graph_wrapper.py
+++ b/pgl/graph_wrapper.py
@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
          num_edges):
     """Recv message from given msg to dst nodes.
     """
-    empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype="float32")
     if reduce_function == "sum":
         if isinstance(msg, dict):
             raise TypeError("The message for build-in function"
@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
         try:
             out_dim = msg.shape[-1]
             init_output = fluid.layers.fill_constant(
-                shape=[num_nodes, out_dim], value=0, dtype="float32")
+                shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
             init_output.stop_gradient = False
+            empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
             msg = msg * empty_msg_flag
             output = paddle_helper.scatter_add(init_output, dst, msg)
             return output
@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
     bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
     output = reduce_function(bucketed_msg)
     output_dim = output.shape[-1]
+
+    empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
     output = output * empty_msg_flag
     init_output = fluid.layers.fill_constant(
-        shape=[num_nodes, output_dim], value=0, dtype="float32")
+        shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
     init_output.stop_gradient = True
     final_output = fluid.layers.scatter(init_output, uniq_dst, output)
     return final_output
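
The change above replaces the hard-coded "float32" with the dtype of the incoming message (or of the reduced output), and moves the empty-message flag so it is cast to that same dtype before the masking multiply. Below is a minimal sketch, not part of the patch, of this dtype-consistent masking pattern; it assumes the PaddlePaddle 1.x static fluid API, and the names msg, num_edges, the float64 dtype, and num_nodes = 10 are illustrative placeholders rather than values from the patch.

# Sketch only: demonstrates following msg.dtype instead of hard-coding "float32".
import paddle.fluid as fluid

num_nodes = 10

# Hypothetical float64 message tensor; fluid.layers.data prepends a batch dim,
# so the static shape is [-1, 16].
msg = fluid.layers.data(name="msg", shape=[16], dtype="float64")

# A 1-element edge count, standing in for num_edges in recv().
num_edges = fluid.layers.fill_constant(shape=[1], value=3, dtype="int64")

# Zero-initialised output buffer that follows msg.dtype, mirroring the
# fill_constant change in the diff above.
init_output = fluid.layers.fill_constant(
    shape=[num_nodes, msg.shape[-1]], value=0, dtype=msg.dtype)

# The empty-message flag is cast to msg.dtype so the elementwise multiply
# below does not mix float32 with the (here float64) message dtype.
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
masked_msg = msg * empty_msg_flag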