提交 6eccaa82 编写于 作者: Z zhangwenhui03

fix gru4rec style

上级 5fed50e8
...@@ -249,10 +249,10 @@ model saved in model_recall20/epoch_1 ...@@ -249,10 +249,10 @@ model saved in model_recall20/epoch_1
``` ```
## 预测 ## 预测
运行命令 开始预测. 运行命令 全词表运行infer.py, 负采样运行infer_sample_neg.py。
``` ```
CUDA_VISIBLE_DEVICES=0 python infer.py --test_dir test_data/ --model_dir model_recall20/ --start_index 1 --last_index 10 --use_cuda 1 CUDA_VISIBLE_DEVICES=0 python infer.py --test_dir test_data/ --model_dir model_output/ --start_index 1 --last_index 10 --use_cuda 1
``` ```
## 预测结果示例 ## 预测结果示例
......
...@@ -63,7 +63,8 @@ def infer(test_reader, use_cuda, model_path): ...@@ -63,7 +63,8 @@ def infer(test_reader, use_cuda, model_path):
accum_num_sum += (data_length) accum_num_sum += (data_length)
accum_num_recall += (data_length * acc_) accum_num_recall += (data_length * acc_)
if step_id % 1 == 0: if step_id % 1 == 0:
print("step:%d " % (step_id), accum_num_recall / accum_num_sum) print("step:%d recall@20:%.4f" %
(step_id, accum_num_recall / accum_num_sum))
t1 = time.time() t1 = time.time()
print("model:%s recall@20:%.3f time_cost(s):%.2f" % print("model:%s recall@20:%.3f time_cost(s):%.2f" %
(model_path, accum_num_recall / accum_num_sum, t1 - t0)) (model_path, accum_num_recall / accum_num_sum, t1 - t0))
......
...@@ -21,7 +21,7 @@ def parse_args(): ...@@ -21,7 +21,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--last_index', type=int, default='3', help='last index') '--last_index', type=int, default='3', help='last index')
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default='model_bpr_recall20', help='model dir') '--model_dir', type=str, default='model_neg_recall20', help='model dir')
parser.add_argument( parser.add_argument(
'--use_cuda', type=int, default='0', help='whether use cuda') '--use_cuda', type=int, default='0', help='whether use cuda')
parser.add_argument( parser.add_argument(
...@@ -76,8 +76,8 @@ def infer(args, vocab_size, test_reader, use_cuda): ...@@ -76,8 +76,8 @@ def infer(args, vocab_size, test_reader, use_cuda):
accum_num_sum += (data_length) accum_num_sum += (data_length)
accum_num_recall += (data_length * acc_) accum_num_recall += (data_length * acc_)
if step_id % 1 == 0: if step_id % 1 == 0:
print("step:%d " % (step_id), print("step:%d recall@20:%.4f" %
accum_num_recall / accum_num_sum) (step_id, accum_num_recall / accum_num_sum))
t1 = time.time() t1 = time.time()
print("model:%s recall@20:%.4f time_cost(s):%.2f" % print("model:%s recall@20:%.4f time_cost(s):%.2f" %
(model_path, accum_num_recall / accum_num_sum, t1 - t0)) (model_path, accum_num_recall / accum_num_sum, t1 - t0))
......
...@@ -29,7 +29,7 @@ def parse_args(): ...@@ -29,7 +29,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--loss', type=str, default="bpr", help='loss: bpr/cross_entropy') '--loss', type=str, default="bpr", help='loss: bpr/cross_entropy')
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default='model_bpr_recall20', help='model dir') '--model_dir', type=str, default='model_neg_recall20', help='model dir')
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size') '--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册