Commit 74b7049e authored by Dilyar, committed by Yibing Liu

bug fix on simnet (#3962)

Parent bb48a52c
......@@ -33,6 +33,21 @@ def check_cuda(use_cuda, err = \
except Exception as e:
pass
def check_version():
"""
Log an error and exit when the installed PaddlePaddle version does
not satisfy the requirement.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
def check_version():
"""
......
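The new check_version() guard is meant to run once at program start, before any program construction, so an incompatible PaddlePaddle install fails fast with a clear message. A minimal sketch of a call site, assuming the helpers live in a utils module (the module path and the argparse flag are illustrative, not part of this commit):

import argparse

import utils  # hypothetical module exposing check_cuda / check_version

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_cuda", type=int, default=0)
    args = parser.parse_args()

    utils.check_cuda(bool(args.use_cuda))  # warn early if CUDA was requested but is unavailable
    utils.check_version()                  # exits with code 1 if PaddlePaddle < 1.6
    # ... build programs and start training ...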
......@@ -28,7 +28,7 @@ class SimNetProcessor(object):
self.valid_label = np.array([])
self.test_label = np.array([])
def get_reader(self, mode):
def get_reader(self, mode, epoch=0):
"""
Get Reader
"""
......@@ -85,34 +85,35 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, pos_title, neg_title = line.strip().split("\t")
if len(query) == 0 or len(pos_title) == 0 or len(
neg_title) == 0:
logging.warning(
"line not match format in test file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
pos_title = [
self.vocab[word] for word in pos_title.split(" ")
if word in self.vocab
]
neg_title = [
self.vocab[word] for word in neg_title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(pos_title) == 0:
pos_title = [0]
if len(neg_title) == 0:
neg_title = [0]
yield [query, pos_title, neg_title]
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, pos_title, neg_title = line.strip().split("\t")
if len(query) == 0 or len(pos_title) == 0 or len(
neg_title) == 0:
logging.warning(
    "line does not match the expected format in train file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
pos_title = [
self.vocab[word] for word in pos_title.split(" ")
if word in self.vocab
]
neg_title = [
self.vocab[word] for word in neg_title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(pos_title) == 0:
pos_title = [0]
if len(neg_title) == 0:
neg_title = [0]
yield [query, pos_title, neg_title]
def reader_with_pointwise():
"""
......@@ -166,30 +167,31 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
label = int(label)
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title, label]
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning(
    "line does not match the expected format in train file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
label = int(label)
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title, label]
if self.args.task_mode == "pairwise":
return reader_with_pairwise
......
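The point of the new epoch argument is to fold the outer epoch loop into the reader itself: the generator replays the whole training file epoch times, so the downstream pyreader only has to be started once. A simplified sketch of that pattern with illustrative names (not the actual SimNetProcessor code):

def repeat_reader(make_reader, epoch):
    """Return a generator function that replays make_reader() `epoch` times."""
    def reader():
        for _ in range(epoch):
            for sample in make_reader():
                yield sample
    return reader

# e.g. train_reader = repeat_reader(lambda: read_tsv(train_path), args.epoch)
# train_reader() then streams the file end to end `epoch` times.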
......@@ -140,7 +140,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost)
# Get Reader
get_train_examples = simnet_process.get_reader("train")
get_train_examples = simnet_process.get_reader("train", epoch=args.epoch)
if args.do_valid:
test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog):
......@@ -164,7 +164,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost)
# Get Feeder and Reader
get_train_examples = simnet_process.get_reader("train")
get_train_examples = simnet_process.get_reader("train", epoch=args.epoch)
if args.do_valid:
test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog):
......@@ -218,63 +218,63 @@ def train(conf_dict, args):
global_step = 0
ce_info = []
train_exe = exe
for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
args.batch_size,
drop_last=False)
train_pyreader.decorate_paddle_reader(train_batch_data)
train_pyreader.start()
exe.run(startup_prog)
losses = []
start_time = time.time()
while True:
try:
global_step += 1
fetch_list = [avg_cost.name]
avg_loss = train_exe.run(program=train_program, fetch_list = fetch_list)
if args.do_valid and global_step % args.validation_steps == 0:
get_valid_examples = simnet_process.get_reader("valid")
valid_result = valid_and_test(test_prog,test_pyreader,get_valid_examples,simnet_process,"valid",exe,[pred.name])
if args.compute_accuracy:
valid_auc, valid_acc = valid_result
logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f" %
(global_step, valid_auc, valid_acc))
else:
valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f" %
(global_step, valid_auc))
if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, str(global_step))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
if args.task_mode == "pairwise":
feed_var_names = [left.name, pos_right.name]
target_vars = [left_feat, pos_score]
else:
feed_var_names = [
left.name,
right.name,
]
target_vars = [left_feat, pred]
fluid.io.save_inference_model(model_path, feed_var_names,
target_vars, exe,
test_prog)
logging.info("saving infer model in %s" % model_path)
losses.append(np.mean(avg_loss[0]))
except fluid.core.EOFException:
train_pyreader.reset()
break
end_time = time.time()
logging.info("epoch: %d, loss: %f, used time: %d sec" %
(epoch_id, np.mean(losses), end_time - start_time))
ce_info.append([np.mean(losses), end_time - start_time])
#for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
args.batch_size,
drop_last=False)
train_pyreader.decorate_paddle_reader(train_batch_data)
train_pyreader.start()
exe.run(startup_prog)
losses = []
start_time = time.time()
while True:
try:
global_step += 1
fetch_list = [avg_cost.name]
avg_loss = train_exe.run(program=train_program, fetch_list=fetch_list)
if args.do_valid and global_step % args.validation_steps == 0:
get_valid_examples = simnet_process.get_reader("valid")
valid_result = valid_and_test(test_prog, test_pyreader, get_valid_examples, simnet_process, "valid", exe, [pred.name])
if args.compute_accuracy:
valid_auc, valid_acc = valid_result
logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
(global_step, valid_auc, valid_acc, np.mean(losses)))
else:
valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
(global_step, valid_auc, np.mean(losses)))
if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, str(global_step))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
if args.task_mode == "pairwise":
feed_var_names = [left.name, pos_right.name]
target_vars = [left_feat, pos_score]
else:
feed_var_names = [
left.name,
right.name,
]
target_vars = [left_feat, pred]
fluid.io.save_inference_model(model_path, feed_var_names,
target_vars, exe,
test_prog)
logging.info("saving infer model in %s" % model_path)
losses.append(np.mean(avg_loss[0]))
except fluid.core.EOFException:
train_pyreader.reset()
break
end_time = time.time()
#logging.info("epoch: %d, loss: %f, used time: %d sec" %
#(epoch_id, np.mean(losses), end_time - start_time))
ce_info.append([np.mean(losses), end_time - start_time])
#final save
logging.info("the final step is %s" % global_step)
model_save_dir = os.path.join(args.output_dir,
......
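With the epoch-aware reader, shuffling and batching are applied once over the multi-epoch stream and the training loop simply drains the pyreader until EOFException. One consequence worth noting: the shuffle buffer and the final non-dropped batch can now mix samples from adjacent epochs. A small self-contained sketch of the composition, assuming PaddlePaddle 1.6 (the toy reader and batch size are illustrative only):

import paddle.fluid as fluid

def toy_reader():
    # mimics get_reader("train", epoch=2): two passes over three samples
    for _ in range(2):
        for i in range(3):
            yield [[i], [i + 1], [i + 2]]

batched = fluid.io.batch(
    fluid.io.shuffle(toy_reader, buf_size=10000),
    batch_size=4,
    drop_last=False)

for batch in batched():
    print(len(batch))  # prints 4 then 2: the first batch spans the epoch boundary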