提交 2b37298d 编写于 作者: Q Qiao Longfei

optimize infer.py

上级 0d4693bd
......@@ -32,6 +32,11 @@ def parse_args():
type=int,
default=10,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
......@@ -43,7 +48,7 @@ def infer():
inference_scope = fluid.core.Scope()
dataset = reader.Dataset()
test_reader = paddle.batch(dataset.train([args.data_path]), batch_size=1000)
test_reader = paddle.batch(dataset.train([args.data_path]), batch_size=args.batch_size)
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
......@@ -71,7 +76,7 @@ def infer():
feed=feeder.feed(data),
fetch_list=fetch_targets)
if batch_id % 100 == 0:
logger.info("TEST --> batch: {} loss: {} auc: {}".format(batch_id, loss_val, auc_val))
logger.info("TEST --> batch: {} loss: {} auc: {}".format(batch_id, loss_val/args.batch_size, auc_val))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册