Fix truncated `steps_per_execution` which failed in distributed training.
The problem: for a dataset with e.g. 14 elements and `steps_per_execution=5`, the `DataAdapter.steps` iterator does the following:

1. Yield `0`.
2. Yield `5`.
3. Set `steps_per_execution` to `4`, yield `10`.
4. Set `steps_per_execution` back to `5`.

In distributed training, the steps are only enqueued at this point, not executed. So even though `steps_per_execution` is adjusted to `4` and holds that value when the final task is enqueued, it is set back to `5` before the task actually runs. As a result, 15 steps are computed instead of 14.

This change makes the number of steps a parameter of the internal `train_function`, `predict_function`, and `test_function` functions, and passes a copy of the value of `steps_per_execution` at the time the task is enqueued, i.e. between steps 3 and 4 above.

PiperOrigin-RevId: 395042946
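The failure mode can be sketched in plain Python. This is a hypothetical, simplified model, not the actual Keras `DataAdapter` code: a one-element list stands in for the shared `steps_per_execution` variable, and deferred lambdas stand in for the enqueued train functions that read it only when they finally run.

```python
def run_epoch_buggy(total_steps, spe_holder):
    """Deferred tasks read the shared variable when they run, not when enqueued."""
    queued = []
    original = spe_holder[0]
    step = 0
    while step < total_steps:
        remaining = total_steps - step
        if remaining < original:
            spe_holder[0] = remaining  # truncate for the final partial chunk
        # Bug: the task reads the shared holder at *execution* time.
        queued.append(lambda: spe_holder[0])
        step += original
    spe_holder[0] = original  # restored before the deferred tasks execute
    return sum(fn() for fn in queued)


def run_epoch_fixed(total_steps, spe_holder):
    """Snapshot the step count at enqueue time and pass it as a parameter."""
    queued = []
    original = spe_holder[0]
    step = 0
    while step < total_steps:
        n = min(original, total_steps - step)
        # Fix: bind a copy of the current value into the task itself.
        queued.append(lambda n=n: n)
        step += original
    return sum(fn() for fn in queued)


print(run_epoch_buggy(14, [5]))  # 15: the last task sees the restored 5
print(run_epoch_fixed(14, [5]))  # 14: the last task carries its own 4
```

With 14 elements and a nominal chunk of 5, the buggy variant computes 5 + 5 + 5 = 15 steps because the final task observes the restored value, while the fixed variant computes 5 + 5 + 4 = 14.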