提交 74b7049e 编写于 作者: D Dilyar 提交者: Yibing Liu

bug fix on simnet (#3962)

上级 bb48a52c
......@@ -33,6 +33,21 @@ def check_cuda(use_cuda, err = \
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
def check_version():
"""
......
......@@ -28,7 +28,7 @@ class SimNetProcessor(object):
self.valid_label = np.array([])
self.test_label = np.array([])
def get_reader(self, mode):
def get_reader(self, mode, epoch=0):
"""
Get Reader
"""
......@@ -85,6 +85,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
......@@ -166,6 +167,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
......
......@@ -140,7 +140,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost)
# Get Reader
get_train_examples = simnet_process.get_reader("train")
get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
if args.do_valid:
test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog):
......@@ -164,7 +164,7 @@ def train(conf_dict, args):
optimizer.ops(avg_cost)
# Get Feeder and Reader
get_train_examples = simnet_process.get_reader("train")
get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
if args.do_valid:
test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog):
......@@ -218,7 +218,7 @@ def train(conf_dict, args):
global_step = 0
ce_info = []
train_exe = exe
for epoch_id in range(args.epoch):
#for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
......@@ -240,12 +240,12 @@ def train(conf_dict, args):
if args.compute_accuracy:
valid_auc, valid_acc = valid_result
logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f" %
(global_step, valid_auc, valid_acc))
"global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
(global_step, valid_auc, valid_acc, np.mean(losses)))
else:
valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f" %
(global_step, valid_auc))
logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
(global_step, valid_auc, np.mean(losses)))
if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
......@@ -272,8 +272,8 @@ def train(conf_dict, args):
train_pyreader.reset()
break
end_time = time.time()
logging.info("epoch: %d, loss: %f, used time: %d sec" %
(epoch_id, np.mean(losses), end_time - start_time))
#logging.info("epoch: %d, loss: %f, used time: %d sec" %
#(epoch_id, np.mean(losses), end_time - start_time))
ce_info.append([np.mean(losses), end_time - start_time])
#final save
logging.info("the final step is %s" % global_step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册