提交 74b7049e 编写于 作者: D Dilyar 提交者: Yibing Liu

bug fix on simnet (#3962)

上级 bb48a52c
...@@ -33,6 +33,21 @@ def check_cuda(use_cuda, err = \ ...@@ -33,6 +33,21 @@ def check_cuda(use_cuda, err = \
except Exception as e: except Exception as e:
pass pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
def check_version(): def check_version():
""" """
......
...@@ -28,7 +28,7 @@ class SimNetProcessor(object): ...@@ -28,7 +28,7 @@ class SimNetProcessor(object):
self.valid_label = np.array([]) self.valid_label = np.array([])
self.test_label = np.array([]) self.test_label = np.array([])
def get_reader(self, mode): def get_reader(self, mode, epoch=0):
""" """
Get Reader Get Reader
""" """
...@@ -85,6 +85,7 @@ class SimNetProcessor(object): ...@@ -85,6 +85,7 @@ class SimNetProcessor(object):
title = [0] title = [0]
yield [query, title] yield [query, title]
else: else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r", with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file: encoding="utf8") as file:
for line in file: for line in file:
...@@ -166,6 +167,7 @@ class SimNetProcessor(object): ...@@ -166,6 +167,7 @@ class SimNetProcessor(object):
title = [0] title = [0]
yield [query, title] yield [query, title]
else: else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r", with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file: encoding="utf8") as file:
for line in file: for line in file:
......
...@@ -140,7 +140,7 @@ def train(conf_dict, args): ...@@ -140,7 +140,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost) optimizer.ops(avg_cost)
# Get Reader # Get Reader
get_train_examples = simnet_process.get_reader("train") get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
if args.do_valid: if args.do_valid:
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
...@@ -164,7 +164,7 @@ def train(conf_dict, args): ...@@ -164,7 +164,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost) optimizer.ops(avg_cost)
# Get Feeder and Reader # Get Feeder and Reader
get_train_examples = simnet_process.get_reader("train") get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
if args.do_valid: if args.do_valid:
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
...@@ -218,7 +218,7 @@ def train(conf_dict, args): ...@@ -218,7 +218,7 @@ def train(conf_dict, args):
global_step = 0 global_step = 0
ce_info = [] ce_info = []
train_exe = exe train_exe = exe
for epoch_id in range(args.epoch): #for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch( train_batch_data = fluid.io.batch(
fluid.io.shuffle( fluid.io.shuffle(
get_train_examples, buf_size=10000), get_train_examples, buf_size=10000),
...@@ -240,12 +240,12 @@ def train(conf_dict, args): ...@@ -240,12 +240,12 @@ def train(conf_dict, args):
if args.compute_accuracy: if args.compute_accuracy:
valid_auc, valid_acc = valid_result valid_auc, valid_acc = valid_result
logging.info( logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f" % "global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
(global_step, valid_auc, valid_acc)) (global_step, valid_auc, valid_acc, np.mean(losses)))
else: else:
valid_auc = valid_result valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f" % logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
(global_step, valid_auc)) (global_step, valid_auc, np.mean(losses)))
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir, model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"]) conf_dict["model_path"])
...@@ -272,8 +272,8 @@ def train(conf_dict, args): ...@@ -272,8 +272,8 @@ def train(conf_dict, args):
train_pyreader.reset() train_pyreader.reset()
break break
end_time = time.time() end_time = time.time()
logging.info("epoch: %d, loss: %f, used time: %d sec" % #logging.info("epoch: %d, loss: %f, used time: %d sec" %
(epoch_id, np.mean(losses), end_time - start_time)) #(epoch_id, np.mean(losses), end_time - start_time))
ce_info.append([np.mean(losses), end_time - start_time]) ce_info.append([np.mean(losses), end_time - start_time])
#final save #final save
logging.info("the final step is %s" % global_step) logging.info("the final step is %s" % global_step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册