diff --git a/examples/graphsage/train_multi.py b/examples/graphsage/train_multi.py index 8a1e799312e25ff2314de326ea3880ca0970b708..6655db9f57024b168e502c3ff02a9018226c6fc7 100644 --- a/examples/graphsage/train_multi.py +++ b/examples/graphsage/train_multi.py @@ -195,7 +195,7 @@ def run_epoch(batch_iter, if num_trainer > 1: num_samples = sum( - [len(batch["node_index"]) for batch in batch_feed_dict]) + [len(_batch["node_index"]) for _batch in batch_feed_dict]) else: num_samples = len(batch_feed_dict["node_index"]) total_loss += batch_loss * num_samples