提交 01b3d9f4 编写于 作者: W wangxiao1021

fix predict_steps

上级 d71b37d0
...@@ -388,8 +388,9 @@ class Trainer(object): ...@@ -388,8 +388,9 @@ class Trainer(object):
reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone') reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
elif phase == 'predict': elif phase == 'predict':
self._predict_reader = reader self._predict_reader = reader
tail = self._num_examples % batch_size > 0 # tail = self._num_examples % batch_size > 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0 # self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size
shape_and_dtypes = self._pred_shape_and_dtypes shape_and_dtypes = self._pred_shape_and_dtypes
name_to_position = self._pred_name_to_position name_to_position = self._pred_name_to_position
net_inputs = self._pred_net_inputs net_inputs = self._pred_net_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册