未验证 提交 fd1ee0ef 编写于 作者: 王肖 提交者: GitHub

Fix some bugs running on the GPU of dygraph/similarity_net (#4334)

* Update README.md (#4267)

* test=develop (#4269)

* 3d use new api (#4275)

* PointNet++ and PointRCNN use new API

* Update Readme of Dygraph BERT (#4277)

Fix some typos.

* Update run_classifier_multi_gpu.sh (#4279)

remove the CUDA_VISIBLE_DEVICES

* Update README.md (#4280)

* add similarity_net dygraph

* fix similarity_net dygraph

* fix bugs of dygraph/similarity_net

* Fix some bugs running on the GPU of dygraph/similarity_net
Co-authored-by: Npkpk <xiyzhouang@gmail.com>
Co-authored-by: NKaipeng Deng <dengkaipeng@baidu.com>
上级 4d28bb18
......@@ -10,7 +10,8 @@
| 模型 | 百度知道 | ECOM |QQSIM | UNICOM |
|:-----------:|:-------------:|:-------------:|:-------------:|:-------------:|
| | AUC | AUC | AUC|正逆序比|
|BOW_Pairwise|0.6815|0.7331|0.7638|1.5566|
|BOW_Pairwise|0.6815|0.7331|0.7638|1.5565|
#### 测试集说明
| 数据集 | 来源 | 垂类 |
......@@ -25,7 +26,6 @@
本项目依赖于 Paddlepaddle Fluid 1.7,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。
python版本依赖python 2.7
#### 安装代码
克隆工具集代码库到本地
```shell
......
......@@ -43,6 +43,7 @@ import io
import logging
from utils import ArgConfig
from utils import load_dygraph
from model_check import check_version
from model_check import check_cuda
......@@ -103,7 +104,7 @@ def train(conf_dict, args):
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
if args.init_checkpoint is not "":
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
model, _ = load_dygraph(args.init_checkpoint)
net.set_dict(model)
# Load loss function dynamically
loss = utils.import_class("./nets/losses",
......@@ -135,13 +136,13 @@ def train(conf_dict, args):
losses = []
start_time = time.time()
train_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
train_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True)
get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
train_pyreader.decorate_sample_list_generator(
paddle.batch(get_train_examples, batch_size=args.batch_size),
place)
if args.do_valid:
valid_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
valid_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True)
get_valid_examples = simnet_process.get_reader("valid")
valid_pyreader.decorate_sample_list_generator(
paddle.batch(get_valid_examples, batch_size=args.batch_size),
......@@ -269,7 +270,7 @@ def train(conf_dict, args):
if args.do_test:
# Get Feeder and Reader
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True)
get_test_examples = simnet_process.get_reader("test")
test_pyreader.decorate_sample_list_generator(
paddle.batch(get_test_examples, batch_size=args.batch_size),
......@@ -307,7 +308,7 @@ def test(conf_dict, args):
vocab = utils.load_vocab(args.vocab_path)
simnet_process = reader.SimNetProcessor(args, vocab)
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True)
get_test_examples = simnet_process.get_reader("test")
test_pyreader.decorate_sample_list_generator(
paddle.batch(get_test_examples, batch_size=args.batch_size),
......@@ -321,7 +322,7 @@ def test(conf_dict, args):
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
model, _ = load_dygraph(args.init_checkpoint)
net.set_dict(model)
metric = fluid.metrics.Auc(name="auc")
pred_list = []
......@@ -390,7 +391,7 @@ def infer(conf_dict, args):
vocab = utils.load_vocab(args.vocab_path)
simnet_process = reader.SimNetProcessor(args, vocab)
get_infer_examples = simnet_process.get_infer_reader
infer_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
infer_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=True)
infer_pyreader.decorate_sample_list_generator(
paddle.batch(get_infer_examples, batch_size=args.batch_size),
place)
......@@ -401,7 +402,7 @@ def infer(conf_dict, args):
net = utils.import_class("./nets",
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
model, _ = load_dygraph(args.init_checkpoint)
net.set_dict(model)
pred_list = []
......
......@@ -26,6 +26,9 @@ import logging
import logging.handlers
import paddle.fluid as fluid
import io
import pickle
import warnings
from functools import partial
"""
******functions for file processing******
"""
......@@ -365,3 +368,22 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path))
def load_dygraph(model_path, keep_name_table=False):
"""
To load python2 saved models in python3.
"""
try:
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
return para_dict, opti_dict
except UnicodeDecodeError:
warnings.warn(
"An UnicodeDecodeError is catched, which might be caused by loading "
"a python2 saved model. Encoding of pickle.load would be set and "
"load again automatically.")
if six.PY3:
load_bak = pickle.load
pickle.load = partial(load_bak, encoding="latin1")
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
pickle.load = load_bak
return para_dict, opti_dict
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册