未验证 提交 21113d81 编写于 作者: 王肖 提交者: GitHub

Update trainer.py

上级 b293afa7
...@@ -373,7 +373,7 @@ class Trainer(object): ...@@ -373,7 +373,7 @@ class Trainer(object):
self._num_epochs = reader.num_epochs self._num_epochs = reader.num_epochs
if phase == 'train': if phase == 'train':
self._train_reader = reader self._train_reader = reader
self._steps_pur_epoch = reader.num_examples // batch_size // gpu_dev_count self._steps_pur_epoch = reader.num_examples // batch_size
shape_and_dtypes = self._shape_and_dtypes shape_and_dtypes = self._shape_and_dtypes
name_to_position = self._name_to_position name_to_position = self._name_to_position
if self._task_id is not None: if self._task_id is not None:
...@@ -387,7 +387,7 @@ class Trainer(object): ...@@ -387,7 +387,7 @@ class Trainer(object):
elif phase == 'predict': elif phase == 'predict':
self._predict_reader = reader self._predict_reader = reader
tail = self._num_examples % batch_size > 0 tail = self._num_examples % batch_size > 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 // gpu_dev_count if tail else 0 self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
shape_and_dtypes = self._pred_shape_and_dtypes shape_and_dtypes = self._pred_shape_and_dtypes
name_to_position = self._pred_name_to_position name_to_position = self._pred_name_to_position
net_inputs = self._pred_net_inputs net_inputs = self._pred_net_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册