提交 f8243f06 编写于 作者: J JesseyXujin 提交者: pkpk

fix args bug and remove test code (#3192)

* fix args bug and remove test code

* remove print code
上级 80719f59
......@@ -19,6 +19,15 @@ from __future__ import print_function
import argparse
def str2bool(v):
if v.lower() in ('yes', 'true', 'True'):
return True
elif v.lower() in ('no', 'false', 'False'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -65,9 +74,10 @@ def parse_args():
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument("--vocab_path", type=str, help="vocab file path")
# parser.add_argument(
# '--use_gpu', action='store_true',help='whether using gpu')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
parser.add_argument('--enable_ce', action='store_true')
"--use_gpu", type=str2bool, default='True', help="Activate nice mode.")
parser.add_argument('--test_nccl', action='store_true')
parser.add_argument('--optim', default='adagrad', help='optimizer type')
parser.add_argument('--sample_softmax', action='store_true')
......@@ -99,9 +109,9 @@ def parse_args():
parser.add_argument('--proj_clip', type=float, default=3.0)
parser.add_argument('--cell_clip', type=float, default=3.0)
parser.add_argument('--max_epoch', type=float, default=10)
parser.add_argument('--local', type=bool, default=False)
parser.add_argument('--shuffle', type=bool, default=False)
parser.add_argument('--use_custom_samples', type=bool, default=False)
parser.add_argument('--local', type=str2bool, default='False')
parser.add_argument('--shuffle', type=str2bool, default='False')
parser.add_argument('--use_custom_samples', type=str2bool, default='False')
parser.add_argument('--para_save_dir', type=str, default='checkpoints')
parser.add_argument('--train_path', type=str, default='')
parser.add_argument('--test_path', type=str, default='')
......
......@@ -245,13 +245,6 @@ def eval(vocab, infer_progs, dev_count, logger, args):
def train():
args = parse_args()
if args.random_seed == 0:
args.random_seed = None
print("random seed is None")
if args.enable_ce:
random.seed(args.random_seed)
np.random.seed(args.random_seed)
logger = logging.getLogger("lm")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
......@@ -259,7 +252,7 @@ def train():
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
args = parse_args()
logger.info('Running with args : {}'.format(args))
logger.info('Running paddle : {}'.format(paddle.version.commit))
......@@ -275,10 +268,6 @@ def train():
# build model
train_prog = fluid.Program()
train_startup_prog = fluid.Program()
if args.enable_ce:
train_prog.random_seed = args.random_seed
train_startup_prog.random_seed = args.random_seed
# build infer model
infer_prog = fluid.Program()
infer_startup_prog = fluid.Program()
......@@ -319,7 +308,6 @@ def train():
logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1)
optimizer.minimize(train_model.loss * args.num_steps)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
......@@ -507,9 +495,11 @@ def train_loop(args,
n_batches_total = args.max_epoch * n_batches_per_epoch
begin_time = time.time()
ce_info = []
final_batch_id = 0
for batch_id, batch_list in enumerate(train_reader(), 1):
if batch_id > n_batches_total:
break
final_batch_id = batch_id
feed_data = batch_reader(batch_list, args)
feed = list(feeder.feed_parallel(feed_data, dev_count))
for i in range(dev_count):
......@@ -569,22 +559,9 @@ def train_loop(args,
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_prog)
if args.enable_ce:
card_num = get_cards()
ce_loss = 0
ce_time = 0
try:
ce_loss = ce_info[-2][0]
ce_time = ce_info[-2][1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time))
print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss))
end_time = time.time()
total_time += end_time - start_time
epoch_id = int(batch_id / n_batches_per_epoch)
epoch_id = int(final_batch_id / n_batches_per_epoch)
model_path = os.path.join(args.para_save_dir, str(epoch_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册