diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index 1341f7a..e3df999 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -65,7 +65,7 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
 parser.add_argument(
     '--save_dir', type=Path, required=True, help='model output directory')
 parser.add_argument(
-    '--wd', type=int, default=0.01, help='weight decay, aka L2 regularizer')
+    '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
 parser.add_argument(
     '--init_checkpoint',
     type=str,
@@ -110,7 +110,7 @@ def map_fn(seg_a, seg_b, label):
         return sentence, segments, label
 
     train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'),
-                                   shuffle=False, repeat=True, use_gz=False, shard=True) \
+                                   shuffle=True, repeat=True, use_gz=False, shard=True) \
             .map(map_fn) \
             .padded_batch(args.bsz, (0, 0, 0))
 
diff --git a/propeller/data/functional.py b/propeller/data/functional.py
index 600a139..7c43812 100644
--- a/propeller/data/functional.py
+++ b/propeller/data/functional.py
@@ -94,7 +94,7 @@ def _cache_shuffle_shard_func(dataset, num_shards, index, seed, drop_last,
     len_per_shard = len(data_list) // num_shards
     rng = np.random.RandomState(seed)
     cnt = 0
-    while cnt < repeat:
+    while cnt != repeat:
         cnt += 1
         random.shuffle(data_list, rng.uniform)
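
Note on the propeller/data/functional.py hunk: a minimal sketch of why `while cnt != repeat:` differs from `while cnt < repeat:`, assuming (not confirmed by this diff) that a negative sentinel such as repeat=-1 is meant to signal "repeat indefinitely". With `<`, a negative value exits the loop immediately; with `!=`, it never matches and the loop runs forever, while positive values still yield exactly that many passes. The helper below is purely illustrative and not part of the patch.

    import itertools

    def epochs(repeat):
        # Mirrors the loop structure changed in _cache_shuffle_shard_func:
        # count passes until cnt equals `repeat` (never, if repeat is negative).
        cnt = 0
        while cnt != repeat:
            cnt += 1
            yield cnt

    print(list(itertools.islice(epochs(3), 10)))   # [1, 2, 3]        -- finite repeat unchanged
    print(list(itertools.islice(epochs(-1), 5)))   # [1, 2, 3, 4, 5]  -- would be [] with `cnt < repeat`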