From 2a167ec9657303cfbd9fdd87c601c0c1f837ab9d Mon Sep 17 00:00:00 2001 From: hutuxian Date: Sat, 12 Oct 2019 19:19:08 +0800 Subject: [PATCH] add check version (#3530) Add version checking for DIN and SR-GNN. --- PaddleRec/din/README.md | 1 + PaddleRec/din/cluster_train.py | 18 +++++++++++++++++- PaddleRec/din/infer.py | 16 ++++++++++++++++ PaddleRec/din/train.py | 15 +++++++++++++++ PaddleRec/gnn/README.md | 1 + PaddleRec/gnn/infer.py | 17 +++++++++++++++++ PaddleRec/gnn/train.py | 15 +++++++++++++++ 7 files changed, 82 insertions(+), 1 deletion(-) diff --git a/PaddleRec/din/README.md b/PaddleRec/din/README.md index 3538ba76..75f21e5d 100644 --- a/PaddleRec/din/README.md +++ b/PaddleRec/din/README.md @@ -29,6 +29,7 @@ DIN通过一个兴趣激活模块(Activation Unit),用预估目标Candidate AD 最后我们将这相关的用户兴趣表达、用户静态特征和上下文相关特征,以及ad相关的特征拼接起来,输入到后续的多层DNN网络,最后预测得到用户对当前目标ADs的点击概率。 +**目前模型库下模型均要求使用PaddlePaddle 1.6及以上版本或适当的develop版本。** ## 数据下载及预处理 diff --git a/PaddleRec/din/cluster_train.py b/PaddleRec/din/cluster_train.py index 6b327236..1683a78c 100644 --- a/PaddleRec/din/cluster_train.py +++ b/PaddleRec/din/cluster_train.py @@ -9,6 +9,7 @@ import time import network import reader import random +import sys logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger("fluid") @@ -85,7 +86,7 @@ def train(): #data_reader, max_len = reader.prepare_reader(train_path, args.batch_size) logger.info("reading data completes") - avg_cost, pred = network.network(item_count, cat_count, 433) + avg_cost, pred = network.network(item_count, cat_count) #fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)) base_lr = args.base_lr boundaries = [410000] @@ -167,6 +168,21 @@ def train(): logger.info("run trainer") train_loop(t.get_trainer_program()) +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) if __name__ == "__main__": + check_version() train() diff --git a/PaddleRec/din/infer.py b/PaddleRec/din/infer.py index fc1484e5..b6fb972f 100644 --- a/PaddleRec/din/infer.py +++ b/PaddleRec/din/infer.py @@ -19,6 +19,7 @@ import os import paddle import paddle.fluid as fluid import reader +import sys logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger("fluid") @@ -99,6 +100,21 @@ def infer(): auc = calc_auc(score) logger.info("TEST --> loss: {}, auc: {}".format(loss_sum / count, auc)) +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) if __name__ == '__main__': + check_version() infer() diff --git a/PaddleRec/din/train.py b/PaddleRec/din/train.py index 9c6190c3..0c865f7d 100644 --- a/PaddleRec/din/train.py +++ b/PaddleRec/din/train.py @@ -173,6 +173,21 @@ def get_cards(args): else: return args.num_devices +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) if __name__ == "__main__": + check_version() train() diff --git a/PaddleRec/gnn/README.md b/PaddleRec/gnn/README.md index 1996d9e0..0d97d07e 100644 --- a/PaddleRec/gnn/README.md +++ b/PaddleRec/gnn/README.md @@ -31,6 +31,7 @@ SR-GNN模型的介绍可以参阅论文[Session-based Recommendation with Graph 我们复现了论文效果,在DIGINETICA数据集上P@20可以达到50.7 +**目前模型库下模型均要求使用PaddlePaddle 1.6及以上版本或适当的develop版本。** ## 数据下载及预处理 diff --git a/PaddleRec/gnn/infer.py b/PaddleRec/gnn/infer.py index 21138dfc..20125f72 100644 --- a/PaddleRec/gnn/infer.py +++ b/PaddleRec/gnn/infer.py @@ -20,6 +20,7 @@ import paddle import paddle.fluid as fluid import reader import network +import sys logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger("fluid") @@ -87,6 +88,22 @@ def infer(args): logger.info("TEST --> error: there is no model in " + model_path) +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + if __name__ == "__main__": + check_version() args = parse_args() infer(args) diff --git a/PaddleRec/gnn/train.py b/PaddleRec/gnn/train.py index e292da20..d82db868 100644 --- a/PaddleRec/gnn/train.py +++ b/PaddleRec/gnn/train.py @@ -170,6 +170,21 @@ def get_cards(args): num = len(cards.split(",")) return num +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) if __name__ == "__main__": + check_version() train() -- GitLab