patch 1.6 KB
Newer Older
M
Meiyim 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index 1341f7a..e3df999 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -65,7 +65,7 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
 parser.add_argument(
     '--save_dir', type=Path, required=True, help='model output directory')
 parser.add_argument(
-    '--wd', type=int, default=0.01, help='weight decay, aka L2 regularizer')
+    '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
 parser.add_argument(
     '--init_checkpoint',
     type=str,
@@ -110,7 +110,7 @@ def map_fn(seg_a, seg_b, label):
     return sentence, segments, label
 
 train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'),
-                                            shuffle=False, repeat=True, use_gz=False, shard=True) \
+                                            shuffle=True, repeat=True, use_gz=False, shard=True) \
                                .map(map_fn) \
                                .padded_batch(args.bsz, (0, 0, 0))
 
diff --git a/propeller/data/functional.py b/propeller/data/functional.py
index 600a139..7c43812 100644
--- a/propeller/data/functional.py
+++ b/propeller/data/functional.py
@@ -94,7 +94,7 @@ def _cache_shuffle_shard_func(dataset, num_shards, index, seed, drop_last,
         len_per_shard = len(data_list) // num_shards
         rng = np.random.RandomState(seed)
         cnt = 0
-        while cnt < repeat:
+        while cnt != repeat:
             cnt += 1
             random.shuffle(data_list, rng.uniform)