提交 b72aec53 编写于 作者: X Xinghai Sun

Enable min_batch_num in train.py and update train info print.

上级 b1e2b239
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
...@@ -143,11 +143,13 @@ def train(): ...@@ -143,11 +143,13 @@ def train():
train_batch_reader = train_generator.batch_reader_creator( train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=args.trainer_count,
sortagrad=args.use_sortagrad if args.init_model_path is None else False, sortagrad=args.use_sortagrad if args.init_model_path is None else False,
batch_shuffle=True) batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator( test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=1, # must be 1, but will have errors.
sortagrad=False, sortagrad=False,
batch_shuffle=False) batch_shuffle=False)
...@@ -157,11 +159,11 @@ def train(): ...@@ -157,11 +159,11 @@ def train():
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
cost_sum += event.cost cost_sum += event.cost
cost_counter += 1 cost_counter += 1
if event.batch_id % 50 == 0: if (event.batch_id + 1) % 100 == 0:
print("\nPass: %d, Batch: %d, TrainCost: %f" % print("\nPass: %d, Batch: %d, TrainCost: %f" % (
(event.pass_id, event.batch_id, cost_sum / cost_counter)) event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0 cost_sum, cost_counter = 0.0, 0
with gzip.open("params_tmp.tar.gz", 'w') as f: with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f) parameters.to_tar(f)
else: else:
sys.stdout.write('.') sys.stdout.write('.')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册