From fd1ee0ef107de0073484811b00289d104a26edff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=82=96?= Date: Tue, 25 Feb 2020 11:56:32 +0800 Subject: [PATCH] Fix some bugs running on the GPU of dygraph/similarity_net (#4334) * Update README.md (#4267) * test=develop (#4269) * 3d use new api (#4275) * PointNet++ and PointRCNN use new API * Update Readme of Dygraph BERT (#4277) Fix some typos. * Update run_classifier_multi_gpu.sh (#4279) remove the CUDA_VISIBLE_DEVICES * Update README.md (#4280) * add similarity_net dygraph * fix similarity_net dygraph * fix bugs of dygraph/similarity_net * Fix some bugs running on the GPU of dygraph/similarity_net Co-authored-by: pkpk Co-authored-by: Kaipeng Deng --- dygraph/similarity_net/README.md | 4 ++-- dygraph/similarity_net/run_classifier.py | 17 +++++++++-------- dygraph/similarity_net/utils.py | 22 ++++++++++++++++++++++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/dygraph/similarity_net/README.md b/dygraph/similarity_net/README.md index 0d75e04a..4f7270b0 100644 --- a/dygraph/similarity_net/README.md +++ b/dygraph/similarity_net/README.md @@ -10,7 +10,8 @@ | 模型 | 百度知道 | ECOM |QQSIM | UNICOM | |:-----------:|:-------------:|:-------------:|:-------------:|:-------------:| | | AUC | AUC | AUC|正逆序比| -|BOW_Pairwise|0.6815|0.7331|0.7638|1.5566| +|BOW_Pairwise|0.6815|0.7331|0.7638|1.5565| + #### 测试集说明 | 数据集 | 来源 | 垂类 | @@ -25,7 +26,6 @@ 本项目依赖于 Paddlepaddle Fluid 1.7,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。 -python版本依赖python 2.7 #### 安装代码 克隆工具集代码库到本地 ```shell diff --git a/dygraph/similarity_net/run_classifier.py b/dygraph/similarity_net/run_classifier.py index b790927e..a0695dd4 100644 --- a/dygraph/similarity_net/run_classifier.py +++ b/dygraph/similarity_net/run_classifier.py @@ -43,6 +43,7 @@ import io import logging from utils import ArgConfig +from utils import load_dygraph from model_check import check_version from model_check import check_cuda @@ -103,7 +104,7 @@ def train(conf_dict, args): conf_dict["net"]["module_name"], conf_dict["net"]["class_name"])(conf_dict) if args.init_checkpoint is not "": - model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint) + model, _ = load_dygraph(args.init_checkpoint) net.set_dict(model) # Load loss function dynamically loss = utils.import_class("./nets/losses", @@ -135,13 +136,13 @@ def train(conf_dict, args): losses = [] start_time = time.time() - train_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False) + train_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True) get_train_examples = simnet_process.get_reader("train",epoch=args.epoch) train_pyreader.decorate_sample_list_generator( paddle.batch(get_train_examples, batch_size=args.batch_size), place) if args.do_valid: - valid_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False) + valid_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True) get_valid_examples = simnet_process.get_reader("valid") valid_pyreader.decorate_sample_list_generator( paddle.batch(get_valid_examples, batch_size=args.batch_size), @@ -269,7 +270,7 @@ def train(conf_dict, args): if args.do_test: # Get Feeder and Reader - test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False) + test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True) get_test_examples = simnet_process.get_reader("test") test_pyreader.decorate_sample_list_generator( paddle.batch(get_test_examples, batch_size=args.batch_size), @@ -307,7 +308,7 @@ def test(conf_dict, args): vocab = utils.load_vocab(args.vocab_path) simnet_process = reader.SimNetProcessor(args, vocab) - test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False) + test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True) get_test_examples = simnet_process.get_reader("test") test_pyreader.decorate_sample_list_generator( paddle.batch(get_test_examples, batch_size=args.batch_size), @@ -321,7 +322,7 @@ def test(conf_dict, args): conf_dict["net"]["module_name"], conf_dict["net"]["class_name"])(conf_dict) - model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint) + model, _ = load_dygraph(args.init_checkpoint) net.set_dict(model) metric = fluid.metrics.Auc(name="auc") pred_list = [] @@ -390,7 +391,7 @@ def infer(conf_dict, args): vocab = utils.load_vocab(args.vocab_path) simnet_process = reader.SimNetProcessor(args, vocab) get_infer_examples = simnet_process.get_infer_reader - infer_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False) + infer_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True) infer_pyreader.decorate_sample_list_generator( paddle.batch(get_infer_examples, batch_size=args.batch_size), place) @@ -401,7 +402,7 @@ def infer(conf_dict, args): net = utils.import_class("./nets", conf_dict["net"]["module_name"], conf_dict["net"]["class_name"])(conf_dict) - model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint) + model, _ = load_dygraph(args.init_checkpoint) net.set_dict(model) pred_list = [] diff --git a/dygraph/similarity_net/utils.py b/dygraph/similarity_net/utils.py index afb87b59..77b122f7 100644 --- a/dygraph/similarity_net/utils.py +++ b/dygraph/similarity_net/utils.py @@ -26,6 +26,9 @@ import logging import logging.handlers import paddle.fluid as fluid import io +import pickle +import warnings +from functools import partial """ ******functions for file processing****** """ @@ -365,3 +368,22 @@ def init_checkpoint(exe, init_checkpoint_path, main_program): predicate=existed_persitables) print("Load model from {}".format(init_checkpoint_path)) + +def load_dygraph(model_path, keep_name_table=False): + """ + To load python2 saved models in python3. + """ + try: + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + return para_dict, opti_dict + except UnicodeDecodeError: + warnings.warn( + "An UnicodeDecodeError is catched, which might be caused by loading " + "a python2 saved model. Encoding of pickle.load would be set and " + "load again automatically.") + if six.PY3: + load_bak = pickle.load + pickle.load = partial(load_bak, encoding="latin1") + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + pickle.load = load_bak + return para_dict, opti_dict \ No newline at end of file -- GitLab