未验证 提交 aec1fec6 编写于 作者: T Taylor Robie 提交者: GitHub

Fix/ncf eval default (#5438)

* improve default handling for eval_batch_size

* return eval_batch_size default to None

* fix syntax error
上级 505cad95
......@@ -128,8 +128,9 @@ def run_ncf(_):
batch_size = distribution_utils.per_device_batch_size(
int(FLAGS.batch_size), num_gpus)
eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size)
eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1
eval_batch_size = int(FLAGS.eval_batch_size or
max([FLAGS.batch_size, eval_per_user]))
if eval_batch_size % eval_per_user:
eval_batch_size = eval_batch_size // eval_per_user * eval_per_user
tf.logging.warning(
......@@ -365,7 +366,8 @@ def define_ncf_flags():
@flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
.format(rconst.NUM_EVAL_NEGATIVES + 1))
def eval_size_check(eval_batch_size):
return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES
return (eval_batch_size is None or
int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册