Commit 74b7049e authored by Dilyar, committed by Yibing Liu

bug fix on simnet (#3962)

Parent: bb48a52c
@@ -33,6 +33,21 @@ def check_cuda(use_cuda, err = \
     except Exception as e:
         pass
 
 
+def check_version():
+    """
+    Log error and exit when the installed version of paddlepaddle is
+    not satisfied.
+    """
+    err = "PaddlePaddle version 1.6 or higher is required, " \
+          "or a suitable develop version is satisfied as well. \n" \
+          "Please make sure the version is good with your code." \
+
+    try:
+        fluid.require_version('1.6.0')
+    except Exception as e:
+        print(err)
+        sys.exit(1)
+
 def check_version():
     """
...
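For reference, a minimal sketch of how startup guards like the one added above are typically wired in at the entry point. The module path `utils` and the `main` wrapper are illustrative assumptions; only `check_cuda`, `check_version`, and `fluid.require_version` come from the diff itself.

# Illustrative only: entry-point wiring for the device/version guards.
# `utils` as the module name and `main` are assumptions, not part of this commit.
import paddle.fluid as fluid

from utils import check_cuda, check_version  # helpers shown in the diff above


def main(use_cuda=False):
    check_cuda(use_cuda)   # exits if use_cuda=True on a CPU-only paddlepaddle build
    check_version()        # exits if the installed paddlepaddle is older than 1.6.0
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # ... build the SimNet programs and run training / inference here ...


if __name__ == "__main__":
    main()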
@@ -28,7 +28,7 @@ class SimNetProcessor(object):
         self.valid_label = np.array([])
         self.test_label = np.array([])
 
-    def get_reader(self, mode):
+    def get_reader(self, mode, epoch=0):
         """
         Get Reader
         """
@@ -85,34 +85,35 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             else:
+                for idx in range(epoch):
                     with io.open(self.args.train_data_dir, "r",
                                  encoding="utf8") as file:
                         for line in file:
                             query, pos_title, neg_title = line.strip().split("\t")
                             if len(query) == 0 or len(pos_title) == 0 or len(
                                     neg_title) == 0:
                                 logging.warning(
                                     "line not match format in test file")
                                 continue
                             query = [
                                 self.vocab[word] for word in query.split(" ")
                                 if word in self.vocab
                             ]
                             pos_title = [
                                 self.vocab[word] for word in pos_title.split(" ")
                                 if word in self.vocab
                             ]
                             neg_title = [
                                 self.vocab[word] for word in neg_title.split(" ")
                                 if word in self.vocab
                             ]
                             if len(query) == 0:
                                 query = [0]
                             if len(pos_title) == 0:
                                 pos_title = [0]
                             if len(neg_title) == 0:
                                 neg_title = [0]
                             yield [query, pos_title, neg_title]
 
         def reader_with_pointwise():
             """
@@ -166,30 +167,31 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             else:
+                for idx in range(epoch):
                     with io.open(self.args.train_data_dir, "r",
                                  encoding="utf8") as file:
                         for line in file:
                             query, title, label = line.strip().split("\t")
                             if len(query) == 0 or len(title) == 0 or len(
                                     label) == 0 or not label.isdigit() or int(
                                     label) not in [0, 1]:
                                 logging.warning(
                                     "line not match format in test file")
                                 continue
                             query = [
                                 self.vocab[word] for word in query.split(" ")
                                 if word in self.vocab
                             ]
                             title = [
                                 self.vocab[word] for word in title.split(" ")
                                 if word in self.vocab
                             ]
                             label = int(label)
                             if len(query) == 0:
                                 query = [0]
                             if len(title) == 0:
                                 title = [0]
                             yield [query, title, label]
 
         if self.args.task_mode == "pairwise":
             return reader_with_pairwise
...
@@ -140,7 +140,7 @@ def train(conf_dict, args):
                 optimizer.ops(avg_cost)
 
         # Get Reader
-        get_train_examples = simnet_process.get_reader("train")
+        get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
         if args.do_valid:
             test_prog = fluid.Program()
             with fluid.program_guard(test_prog, startup_prog):
@@ -164,7 +164,7 @@ def train(conf_dict, args):
                 optimizer.ops(avg_cost)
 
         # Get Feeder and Reader
-        get_train_examples = simnet_process.get_reader("train")
+        get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
        if args.do_valid:
             test_prog = fluid.Program()
             with fluid.program_guard(test_prog, startup_prog):
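Note that `epoch` defaults to 0 in the new `get_reader` signature, so the training branch iterates `range(0)` and yields nothing unless the caller forwards an explicit epoch count, which is why both call sites above now pass `args.epoch`. A small illustration, assuming `simnet_process` and `args` are the objects from the surrounding training script:

# With the default epoch=0 the "train" reader produces no samples,
# so train() must pass epoch=args.epoch explicitly (valid/test readers ignore epoch).
empty_reader = simnet_process.get_reader("train")            # epoch defaults to 0
assert sum(1 for _ in empty_reader()) == 0                   # range(0): no passes over the file

get_train_examples = simnet_process.get_reader("train", epoch=args.epoch)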
@@ -218,63 +218,63 @@ def train(conf_dict, args):
     global_step = 0
     ce_info = []
     train_exe = exe
-    for epoch_id in range(args.epoch):
+    #for epoch_id in range(args.epoch):
     train_batch_data = fluid.io.batch(
         fluid.io.shuffle(
             get_train_examples, buf_size=10000),
         args.batch_size,
         drop_last=False)
     train_pyreader.decorate_paddle_reader(train_batch_data)
     train_pyreader.start()
     exe.run(startup_prog)
     losses = []
     start_time = time.time()
     while True:
         try:
             global_step += 1
             fetch_list = [avg_cost.name]
             avg_loss = train_exe.run(program=train_program, fetch_list = fetch_list)
             if args.do_valid and global_step % args.validation_steps == 0:
                 get_valid_examples = simnet_process.get_reader("valid")
                 valid_result = valid_and_test(test_prog,test_pyreader,get_valid_examples,simnet_process,"valid",exe,[pred.name])
                 if args.compute_accuracy:
                     valid_auc, valid_acc = valid_result
                     logging.info(
-                        "global_steps: %d, valid_auc: %f, valid_acc: %f" %
-                        (global_step, valid_auc, valid_acc))
+                        "global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
+                        (global_step, valid_auc, valid_acc, np.mean(losses)))
                 else:
                     valid_auc = valid_result
-                    logging.info("global_steps: %d, valid_auc: %f" %
-                                 (global_step, valid_auc))
+                    logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
+                                 (global_step, valid_auc, np.mean(losses)))
             if global_step % args.save_steps == 0:
                 model_save_dir = os.path.join(args.output_dir,
                                               conf_dict["model_path"])
                 model_path = os.path.join(model_save_dir, str(global_step))
                 if not os.path.exists(model_save_dir):
                     os.makedirs(model_save_dir)
                 if args.task_mode == "pairwise":
                     feed_var_names = [left.name, pos_right.name]
                     target_vars = [left_feat, pos_score]
                 else:
                     feed_var_names = [
                         left.name,
                         right.name,
                     ]
                     target_vars = [left_feat, pred]
                 fluid.io.save_inference_model(model_path, feed_var_names,
                                               target_vars, exe,
                                               test_prog)
                 logging.info("saving infer model in %s" % model_path)
             losses.append(np.mean(avg_loss[0]))
         except fluid.core.EOFException:
             train_pyreader.reset()
             break
     end_time = time.time()
-    logging.info("epoch: %d, loss: %f, used time: %d sec" %
-                 (epoch_id, np.mean(losses), end_time - start_time))
+    #logging.info("epoch: %d, loss: %f, used time: %d sec" %
+    #(epoch_id, np.mean(losses), end_time - start_time))
     ce_info.append([np.mean(losses), end_time - start_time])
     #final save
     logging.info("the final step is %s" % global_step)
     model_save_dir = os.path.join(args.output_dir,
...
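Because the reader now spans every epoch, the outer `for epoch_id` loop above is gone and a single `while True` drains the PyReader until it raises `EOFException`. A condensed sketch of that control flow; all names follow the diff, and the validation, logging, and checkpoint logic is trimmed away.

# Condensed view of the new single-loop training flow from the diff above
# (train_pyreader, train_batch_data, exe, startup_prog, train_program,
#  train_exe, avg_cost, np, fluid all come from the surrounding script).
train_pyreader.decorate_paddle_reader(train_batch_data)
train_pyreader.start()
exe.run(startup_prog)

losses = []
global_step = 0
while True:
    try:
        global_step += 1
        avg_loss = train_exe.run(program=train_program,
                                 fetch_list=[avg_cost.name])
        losses.append(np.mean(avg_loss[0]))
    except fluid.core.EOFException:
        # Raised once the reader has replayed the training file args.epoch times.
        train_pyreader.reset()
        break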