未验证 提交 21113d81 编写于 作者: 王肖 提交者: GitHub

Update trainer.py

上级 b293afa7
......@@ -373,7 +373,7 @@ class Trainer(object):
self._num_epochs = reader.num_epochs
if phase == 'train':
self._train_reader = reader
self._steps_pur_epoch = reader.num_examples // batch_size // gpu_dev_count
self._steps_pur_epoch = reader.num_examples // batch_size
shape_and_dtypes = self._shape_and_dtypes
name_to_position = self._name_to_position
if self._task_id is not None:
......@@ -387,7 +387,7 @@ class Trainer(object):
elif phase == 'predict':
self._predict_reader = reader
tail = self._num_examples % batch_size > 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 // gpu_dev_count if tail else 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
shape_and_dtypes = self._pred_shape_and_dtypes
name_to_position = self._pred_name_to_position
net_inputs = self._pred_net_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册