提交 0d308920 编写于 作者: R root

Merge branch 'develop' of https://github.com/PaddlePaddle/models into...

Merge branch 'develop' of https://github.com/PaddlePaddle/models into joe_upgrade_pretrain_lm_and_lm
1)升级language_model模型内接口到paddle 1.8版本
2)升级pretrain下bert和xlnet两个模型内接口到paddle 1.8版本
3)升级训练模型(language: 1个模型,pretrain_language_model: 2个模型)
接口涉及:
a) save/load
b) DataLoader
c) set_gradient_clip
d) embedding
e) one_hot
f) 部分io接口;
......@@ -200,8 +200,7 @@ def pointnet_fp_module(unknown, known, unknown_feats, known_feats, mlp, bn=True,
dist.stop_gradient = True
idx.stop_gradient = True
dist = fluid.layers.sqrt(dist)
ones = fluid.layers.fill_constant_batch_size_like(dist, dist.shape, dist.dtype, 1)
dist_recip = ones / (dist + 1e-8); # 1.0 / dist
dist_recip = 1.0 / (dist + 1e-8); # 1.0 / dist
norm = fluid.layers.reduce_sum(dist_recip, dim=-1, keep_dim=True)
weight = dist_recip / norm
weight.stop_gradient = True
......
......@@ -93,8 +93,8 @@ def get_reg_loss(pred_reg, reg_label, fg_mask, point_num, loc_scope,
x_res_norm_label = x_res_label / loc_bin_size
z_res_norm_label = z_res_label / loc_bin_size
x_bin_onehot = fluid.layers.one_hot(x_bin_label, depth=per_loc_bin_num)
z_bin_onehot = fluid.layers.one_hot(z_bin_label, depth=per_loc_bin_num)
x_bin_onehot = fluid.one_hot(x_bin_label[:, 0], depth=per_loc_bin_num)
z_bin_onehot = fluid.one_hot(z_bin_label[:, 0], depth=per_loc_bin_num)
loss_x_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, x_res_l: x_res_r] * x_bin_onehot, dim=1, keep_dim=True), x_res_norm_label)
loss_x_res = fluid.layers.reduce_mean(loss_x_res * fg_mask) * fg_scale
......@@ -115,7 +115,7 @@ def get_reg_loss(pred_reg, reg_label, fg_mask, point_num, loc_scope,
y_res_label = y_shift - (fluid.layers.cast(y_bin_label, dtype=y_shift.dtype) * loc_y_bin_size + loc_y_bin_size / 2.)
y_res_norm_label = y_res_label / loc_y_bin_size
y_bin_onehot = fluid.layers.one_hot(y_bin_label, depth=per_loc_bin_num)
y_bin_onehot = fluid.one_hot(y_bin_label[:, 0], depth=per_loc_bin_num)
loss_y_bin = fluid.layers.cross_entropy(pred_reg[:, y_bin_l: y_bin_r], y_bin_label)
loss_y_bin = fluid.layers.reduce_mean(loss_y_bin * fg_mask) * fg_scale
......@@ -169,7 +169,7 @@ def get_reg_loss(pred_reg, reg_label, fg_mask, point_num, loc_scope,
ry_res_label = shift_angle - (fluid.layers.cast(ry_bin_label, dtype=shift_angle.dtype) * angle_per_class + angle_per_class / 2)
ry_res_norm_label = ry_res_label / (angle_per_class / 2)
ry_bin_onehot = fluid.layers.one_hot(ry_bin_label, depth=num_head_bin)
ry_bin_onehot = fluid.one_hot(ry_bin_label[:, 0], depth=num_head_bin)
loss_ry_bin = fluid.layers.softmax_with_cross_entropy(pred_reg[:, ry_bin_l:ry_bin_r], ry_bin_label)
loss_ry_bin = fluid.layers.reduce_mean(loss_ry_bin * fg_mask) * fg_scale
loss_ry_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, ry_res_l: ry_res_r] * ry_bin_onehot, dim=1, keep_dim=True), ry_res_norm_label)
......
......@@ -184,8 +184,7 @@ def pointnet_fp_module(unknown, known, unknown_feats, known_feats, mlp, bn=True,
dist.stop_gradient = True
idx.stop_gradient = True
dist = fluid.layers.sqrt(dist)
ones = fluid.layers.fill_constant_batch_size_like(dist, dist.shape, dist.dtype, 1)
dist_recip = ones / (dist + 1e-8); # 1.0 / dist
dist_recip = 1.0 / (dist + 1e-8); # 1.0 / dist
norm = fluid.layers.reduce_sum(dist_recip, dim=-1, keep_dim=True)
weight = dist_recip / norm
weight.stop_gradient = True
......
......@@ -54,21 +54,31 @@ def cosine_warmup_decay(learning_rate, betas, warmup_factor, decay_factor,
warmup_step_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warmup_step), force_cpu=True)
with control_flow.Switch() as switch:
with switch.case(global_step < warmup_step_var):
cur_lr = annealing_cos(warmup_start_lr, learning_rate,
global_step / warmup_step_var)
fluid.layers.assign(cur_lr, lr)
cur_beta1 = annealing_cos(betas[0], betas[1],
global_step / warmup_step_var)
fluid.layers.assign(cur_beta1, beta1)
with switch.case(global_step >= warmup_step_var):
cur_lr = annealing_cos(learning_rate, decay_end_lr,
(global_step - warmup_step_var) / (total_step - warmup_step))
fluid.layers.assign(cur_lr, lr)
cur_beta1 = annealing_cos(betas[1], betas[0],
(global_step - warmup_step_var) / (total_step - warmup_step))
fluid.layers.assign(cur_beta1, beta1)
warmup_pred = global_step < warmup_step_var
decay_pred = global_step >= warmup_step_var
# learning rate warmup and decay
def warmup_lr():
return annealing_cos(warmup_start_lr, learning_rate,
global_step / warmup_step_var)
def decay_lr():
return annealing_cos(learning_rate, decay_end_lr,
(global_step - warmup_step_var) / (total_step - warmup_step))
lr = fluid.layers.case(pred_fn_pairs=[(warmup_pred, warmup_lr),
(decay_pred, decay_lr)])
# Adam beta1 warmup and decay
def warmup_beta1():
return annealing_cos(betas[0], betas[1],
global_step / warmup_step_var)
def decay_beta1():
    # Decay phase of the Adam beta1 schedule: anneal back from betas[1] to
    # betas[0] over the post-warmup fraction of training. This mirrors the
    # fraction used by decay_lr() and restores the behaviour of the
    # Switch-based implementation this code replaced; reusing the warmup
    # expression here would keep beta1 on the warmup curve with a fraction
    # greater than 1 once global_step passes warmup_step_var.
    return annealing_cos(betas[1], betas[0],
                         (global_step - warmup_step_var) /
                         (total_step - warmup_step))
beta1 = fluid.layers.case(pred_fn_pairs=[(warmup_pred, warmup_beta1),
(decay_pred, decay_beta1)])
return lr, beta1
......@@ -96,11 +106,11 @@ def optimize(loss,
raise ValueError("Unkown learning rate scheduler, should be "
"'cosine_warmup_decay'")
grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm)
optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr,
beta1=scheduled_beta1,
beta2=beta2)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm))
beta2=beta2,
grad_clip=grad_clip)
param_list = dict()
......
......@@ -99,16 +99,24 @@ class reader_creator(object):
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
self.with_label = False
def reader():
batch_out = []
batch_out_label = []
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.lines)
for i, file in enumerate(self.lines):
file = file.strip('\n\r\t ')
for i, line in enumerate(self.lines):
line = line.strip('\n\r\t').split(' ')
if len(line) > 1:
self.with_label = True
batch_out_label.append(line[1])
file = line[0]
else:
file = line[0]
self.name2id[os.path.basename(file)] = i
self.id2name[i] = os.path.basename(file)
img = Image.open(os.path.join(self.image_dir, file)).convert(
......@@ -133,10 +141,18 @@ class reader_creator(object):
batch_out.append(img)
if len(batch_out) == self.batch_size:
if return_name:
yield batch_out, batch_out_name
if self.with_label:
yield [[batch_out, batch_out_label, batch_out_name]]
batch_out_label = []
else:
yield batch_out, batch_out_name
batch_out_name = []
else:
yield [batch_out]
if self.with_label:
yield [[batch_out, batch_out_label]]
batch_out_label = []
else:
yield [batch_out]
batch_out = []
return reader
......@@ -667,8 +683,9 @@ class data_reader(object):
image_dir=dataset_dir,
list_filename=test_list,
batch_size=self.cfg.n_samples)
reader_test = test_reader.get_test_reader(
reader_test = test_reader.make_reader(
self.cfg, shuffle=False, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len()
return train_reader, reader_test, batch_num, id2name
reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num, id2name
......@@ -52,8 +52,13 @@ def eval(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[None, 1], dtype='int64')
test_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=True)
# model definition
model = models.__dict__[model_name]()
......@@ -72,16 +77,21 @@ def eval(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
test_reader = paddle.batch(reader.test(args), batch_size=args.batch_size, drop_last=False)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_loader.set_sample_generator(
reader.test(args),
batch_size=args.batch_size,
drop_last=False,
places=place)
fetch_list = [out.name]
f, l = [], []
for batch_id, data in enumerate(test_reader()):
for batch_id, data in enumerate(test_loader()):
t1 = time.time()
[feas] = exe.run(test_program, fetch_list=fetch_list, feed=feeder.feed(data))
label = np.asarray([x[1] for x in data])
[feas] = exe.run(test_program, fetch_list=fetch_list, feed=data)
label = np.asarray(data[0]['label'])
label = np.squeeze(label)
f.append(feas)
l.append(label)
......
......@@ -51,7 +51,13 @@ def infer(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32')
infer_loader = fluid.io.DataLoader.from_generator(
feed_list=[image],
capacity=64,
use_double_buffer=True,
iterable=True)
# model definition
model = models.__dict__[model_name]()
......@@ -70,13 +76,16 @@ def infer(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
infer_reader = paddle.batch(reader.infer(args), batch_size=args.batch_size, drop_last=False)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
infer_loader.set_sample_generator(
reader.test(args),
batch_size=args.batch_size,
drop_last=False,
places=place)
fetch_list = [out.name]
for batch_id, data in enumerate(infer_reader()):
result = exe.run(test_program, fetch_list=fetch_list, feed=feeder.feed(data))
for batch_id, data in enumerate(infer_loader()):
result = exe.run(test_program, fetch_list=fetch_list, feed=data)
result = result[0][0].reshape(-1)
print("Test-{0}-feature: {1}".format(batch_id, result[:5]))
sys.stdout.flush()
......
......@@ -61,7 +61,8 @@ class ArcMarginLoss():
else:
phi = self.paddle_where_more_than(cosine, th, phi, cosine-mm)
one_hot = fluid.layers.one_hot(input=label, depth=out_dim)
one_hot = fluid.one_hot(input=label, depth=out_dim)
one_hot = fluid.layers.squeeze(input=one_hot, axes=[1])
output = fluid.layers.elementwise_mul(one_hot, phi) + fluid.layers.elementwise_mul((1.0 - one_hot), cosine)
output = output * s
return output
......
......@@ -22,6 +22,7 @@ import functools
import numpy as np
import paddle
from imgtool import process_image
import paddle.fluid as fluid
random.seed(0)
......@@ -187,7 +188,7 @@ def createreader(settings, mode):
keep_order = False if mode != 'train' or settings.loss_name in ['softmax', 'arcmargin'] else True
image_mapper = functools.partial(process_image,
mode=mode, color_jitter=False, rotate=False, crop_size=image_size)
reader = paddle.reader.xmap_readers(
reader = fluid.io.xmap_readers(
image_mapper, metric_reader, 8, 1000, order=keep_order)
return reader
......
......@@ -107,19 +107,16 @@ def build_program(is_train, main_prog, startup_prog, args):
image_shape = [int(m) for m in args.image_shape.split(",")]
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
if is_train:
queue_capacity = 64
py_reader = fluid.layers.py_reader(
queue_capacity = 64
image = fluid.layers.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=queue_capacity,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True)
image, label = fluid.layers.read_file(py_reader)
else:
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
use_double_buffer=True,
iterable=True)
with fluid.unique_name.guard():
avg_cost, acc_top1, acc_top5, out = net_config(image, label, model,
......@@ -137,9 +134,9 @@ def build_program(is_train, main_prog, startup_prog, args):
main_prog = main_prog.clone(for_test=True)
"""
if is_train:
return py_reader, avg_cost, acc_top1, acc_top5, global_lr
return loader, avg_cost, acc_top1, acc_top5, global_lr
else:
return out, image, label
return loader, out
def train_async(args):
......@@ -163,12 +160,12 @@ def train_async(args):
train_prog.random_seed = 1000
tmp_prog.random_seed = 1000
train_py_reader, train_cost, train_acc1, train_acc5, global_lr = build_program(
train_loader, train_cost, train_acc1, train_acc5, global_lr = build_program(
is_train=True,
main_prog=train_prog,
startup_prog=startup_prog,
args=args)
test_feas, image, label = build_program(
test_loader, test_feas = build_program(
is_train=False,
main_prog=tmp_prog,
startup_prog=startup_prog,
......@@ -182,6 +179,11 @@ def train_async(args):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers <= 1 and args.use_gpu:
places = fluid.framework.cuda_places()
else:
places = place
exe.run(startup_prog)
......@@ -206,12 +208,17 @@ def train_async(args):
train_batch_size = args.train_batch_size // devicenum
test_batch_size = args.test_batch_size
train_reader = paddle.batch(
reader.train(args), batch_size=train_batch_size, drop_last=True)
test_reader = paddle.batch(
reader.test(args), batch_size=test_batch_size, drop_last=False)
test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_py_reader.decorate_paddle_reader(train_reader)
train_loader.set_sample_generator(
reader.train(args),
batch_size=train_batch_size,
drop_last=True,
places=places)
test_loader.set_sample_generator(
reader.test(args),
batch_size=test_batch_size,
drop_last=False,
places=place)
train_exe = fluid.ParallelExecutor(
main_program=train_prog,
......@@ -219,72 +226,76 @@ def train_async(args):
loss_name=train_cost.name)
totalruntime = 0
train_py_reader.start()
iter_no = 0
train_info = [0, 0, 0, 0]
while iter_no <= args.total_iter_num:
t1 = time.time()
lr, loss, acc1, acc5 = train_exe.run(fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
lr = np.mean(np.array(lr))
train_info[0] += np.mean(np.array(loss))
train_info[1] += np.mean(np.array(acc1))
train_info[2] += np.mean(np.array(acc5))
train_info[3] += 1
if iter_no % args.display_iter_step == 0:
avgruntime = totalruntime / args.display_iter_step
avg_loss = train_info[0] / train_info[3]
avg_acc1 = train_info[1] / train_info[3]
avg_acc5 = train_info[2] / train_info[3]
print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\
for train_batch in train_loader():
t1 = time.time()
lr, loss, acc1, acc5 = train_exe.run(
feed=train_batch,
fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
lr = np.mean(np.array(lr))
train_info[0] += np.mean(np.array(loss))
train_info[1] += np.mean(np.array(acc1))
train_info[2] += np.mean(np.array(acc5))
train_info[3] += 1
if iter_no % args.display_iter_step == 0:
avgruntime = totalruntime / args.display_iter_step
avg_loss = train_info[0] / train_info[3]
avg_acc1 = train_info[1] / train_info[3]
avg_acc5 = train_info[2] / train_info[3]
print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\
"acc1 %.4f, acc5 %.4f, time %2.2f sec" % \
(fmt_time(), iter_no, lr, avg_loss, avg_acc1, avg_acc5, avgruntime))
sys.stdout.flush()
totalruntime = 0
if iter_no % 1000 == 0:
train_info = [0, 0, 0, 0]
totalruntime += period
if iter_no % args.test_iter_step == 0 and iter_no != 0:
f, l = [], []
for batch_id, data in enumerate(test_reader()):
t1 = time.time()
[feas] = exe.run(test_prog,
fetch_list=test_fetch_list,
feed=test_feeder.feed(data))
label = np.asarray([x[1] for x in data])
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 20 == 0:
print("[%s] testbatch %d, time %2.2f sec" % \
sys.stdout.flush()
totalruntime = 0
if iter_no % 1000 == 0:
train_info = [0, 0, 0, 0]
totalruntime += period
if iter_no % args.test_iter_step == 0 and iter_no != 0:
f, l = [], []
for batch_id, test_batch in enumerate(test_loader()):
t1 = time.time()
[feas] = exe.run(test_prog,
feed=test_batch,
fetch_list=test_fetch_list)
label = np.asarray(test_batch[0]['label'])
label = np.squeeze(label)
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 20 == 0:
print("[%s] testbatch %d, time %2.2f sec" % \
(fmt_time(), batch_id, period))
f = np.vstack(f)
l = np.hstack(l)
recall = recall_topk(f, l, k=1)
print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \
(fmt_time(), len(f), iter_no, recall))
sys.stdout.flush()
if iter_no % args.save_iter_step == 0 and iter_no != 0:
model_path = os.path.join(model_save_dir + '/' + model_name,
str(iter_no))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
iter_no += 1
# This is for continuous evaluation only
if args.enable_ce:
# Use the mean cost/acc for training
print("kpis\ttrain_cost\t{}".format(avg_loss))
print("kpis\ttest_recall\t{}".format(recall))
f = np.vstack(f)
l = np.hstack(l)
recall = recall_topk(f, l, k=1)
print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \
(fmt_time(), len(f), iter_no, recall))
sys.stdout.flush()
if iter_no % args.save_iter_step == 0 and iter_no != 0:
model_path = os.path.join(model_save_dir + '/' + model_name,
str(iter_no))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
iter_no += 1
# This is for continuous evaluation only
if args.enable_ce:
# Use the mean cost/acc for training
print("kpis\ttrain_cost\t{}".format(avg_loss))
print("kpis\ttest_recall\t{}".format(recall))
def initlogging():
......
......@@ -114,19 +114,16 @@ def build_program(is_train, main_prog, startup_prog, args):
image_shape = [int(m) for m in args.image_shape.split(",")]
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
if is_train:
queue_capacity = 64
py_reader = fluid.layers.py_reader(
queue_capacity = 64
image = fluid.layers.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=queue_capacity,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True)
image, label = fluid.layers.read_file(py_reader)
else:
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
use_double_buffer=True,
iterable=True)
with fluid.unique_name.guard():
avg_cost, out = net_config(image, label, model, args, is_train)
......@@ -143,9 +140,9 @@ def build_program(is_train, main_prog, startup_prog, args):
main_prog = main_prog.clone(for_test=True)
"""
if is_train:
return py_reader, avg_cost, global_lr, out, label
return loader, avg_cost, global_lr, out, label
else:
return out, image, label
return loader, out
def train_async(args):
......@@ -161,12 +158,12 @@ def train_async(args):
train_prog = fluid.Program()
tmp_prog = fluid.Program()
train_py_reader, train_cost, global_lr, train_feas, train_label = build_program(
train_loader, train_cost, global_lr, train_feas, train_label = build_program(
is_train=True,
main_prog=train_prog,
startup_prog=startup_prog,
args=args)
test_feas, image, label = build_program(
test_loader, test_feas = build_program(
is_train=False,
main_prog=tmp_prog,
startup_prog=startup_prog,
......@@ -180,6 +177,11 @@ def train_async(args):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers <= 1 and args.use_gpu:
places = fluid.framework.cuda_places()
else:
places = place
exe.run(startup_prog)
......@@ -204,12 +206,17 @@ def train_async(args):
train_batch_size = args.train_batch_size / devicenum
test_batch_size = args.test_batch_size
train_reader = paddle.batch(
reader.train(args), batch_size=train_batch_size, drop_last=True)
test_reader = paddle.batch(
reader.test(args), batch_size=test_batch_size, drop_last=False)
test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_py_reader.decorate_paddle_reader(train_reader)
train_loader.set_sample_generator(
reader.train(args),
batch_size=train_batch_size,
drop_last=True,
places=places)
test_loader.set_sample_generator(
reader.test(args),
batch_size=test_batch_size,
drop_last=False,
places=place)
train_exe = fluid.ParallelExecutor(
main_program=train_prog,
......@@ -217,64 +224,68 @@ def train_async(args):
loss_name=train_cost.name)
totalruntime = 0
train_py_reader.start()
iter_no = 0
train_info = [0, 0, 0]
while iter_no <= args.total_iter_num:
t1 = time.time()
lr, loss, feas, label = train_exe.run(fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
lr = np.mean(np.array(lr))
train_info[0] += np.mean(np.array(loss))
train_info[1] += recall_topk(feas, label, k=1)
train_info[2] += 1
if iter_no % args.display_iter_step == 0:
avgruntime = totalruntime / args.display_iter_step
avg_loss = train_info[0] / train_info[2]
avg_recall = train_info[1] / train_info[2]
print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\
for train_batch in train_loader():
t1 = time.time()
lr, loss, feas, label = train_exe.run(
feed=train_batch,
fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
lr = np.mean(np.array(lr))
train_info[0] += np.mean(np.array(loss))
train_info[1] += recall_topk(feas, label, k=1)
train_info[2] += 1
if iter_no % args.display_iter_step == 0:
avgruntime = totalruntime / args.display_iter_step
avg_loss = train_info[0] / train_info[2]
avg_recall = train_info[1] / train_info[2]
print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\
"recall %.4f, time %2.2f sec" % \
(fmt_time(), iter_no, lr, avg_loss, avg_recall, avgruntime))
sys.stdout.flush()
totalruntime = 0
if iter_no % 1000 == 0:
train_info = [0, 0, 0]
totalruntime += period
if iter_no % args.test_iter_step == 0 and iter_no != 0:
f, l = [], []
for batch_id, data in enumerate(test_reader()):
t1 = time.time()
[feas] = exe.run(test_prog,
fetch_list=test_fetch_list,
feed=test_feeder.feed(data))
label = np.asarray([x[1] for x in data])
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 20 == 0:
print("[%s] testbatch %d, time %2.2f sec" % \
sys.stdout.flush()
totalruntime = 0
if iter_no % 1000 == 0:
train_info = [0, 0, 0]
totalruntime += period
if iter_no % args.test_iter_step == 0 and iter_no != 0:
f, l = [], []
for batch_id, test_batch in enumerate(test_loader()):
t1 = time.time()
[feas] = exe.run(test_prog,
feed=test_batch,
fetch_list=test_fetch_list)
label = np.asarray(test_batch[0]['label'])
label = np.squeeze(label)
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 20 == 0:
print("[%s] testbatch %d, time %2.2f sec" % \
(fmt_time(), batch_id, period))
f = np.vstack(f)
l = np.hstack(l)
recall = recall_topk(f, l, k=1)
print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \
f = np.vstack(f)
l = np.hstack(l)
recall = recall_topk(f, l, k=1)
print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \
(fmt_time(), len(f), iter_no, recall))
sys.stdout.flush()
sys.stdout.flush()
if iter_no % args.save_iter_step == 0 and iter_no != 0:
model_path = os.path.join(model_save_dir + '/' + model_name,
if iter_no % args.save_iter_step == 0 and iter_no != 0:
model_path = os.path.join(model_save_dir + '/' + model_name,
str(iter_no))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
iter_no += 1
iter_no += 1
def initlogging():
......
......@@ -211,8 +211,8 @@ def dropout2d(input, prob, is_train=False):
return input
channels = input.shape[1]
keep_prob = 1.0 - prob
random_tensor = keep_prob + layers.uniform_random_batch_size_like(
input, [-1, channels, 1, 1], min=0., max=1.)
random_tensor = np.random.uniform(0, 1, [input.shape[0], channels, 1, 1]).astype(np.float32)
random_tensor = keep_prob + dygraph.to_variable(random_tensor)
binary_tensor = layers.floor(random_tensor)
output = input / keep_prob * binary_tensor
return output
......@@ -99,12 +99,19 @@ python train.py --model_name=STNET \
bash run.sh train STNET ./configs/stnet.yaml
```
多卡分布式训练 + GPU视频解码和预处理(仅限TSN模型)
``` bash
bash run_dist.sh train TSN ./configs/tsn_dist_and_dali.yaml
```
- 请根据`CUDA_VISIBLE_DEVICES`指定卡数修改`config`文件中的`num_gpus`和`batch_size`配置。
- 使用CPU训练时请在run.sh中设置use\_gpu=False,使用GPU训练时则设置use\_gpu=True
- 上述启动脚本run.sh运行时需要指定任务类型、模型名、配置文件。训练、评估和预测对应的任务类型分别是train,eval和predict。模型名称则是[AttentionCluster, AttentionLSTM, NEXTVLAD, NONLOCAL, STNET, TSN, TSM, CTCN]中的任何一个。配置文件全部在PaddleVideo/configs目录下,根据模型名称选择对应的配置文件即可。具体使用请参见各模型的说明文档。
- 目前针对TSN模型,做了GPU解码和数据预处理的优化,能明显提升训练速度,具体请参考[TSN](./models/tsn/README.md)
## 模型库结构
......
MODEL:
name: "TSN"
format: "pkl"
num_classes: 400
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 45
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 256
use_gpu: True
num_gpus: 8
filelist: "./data/dataset/kinetics/train_video_file.list"
learning_rate: 0.01
learning_rate_decay: 0.1
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 224684
num_trainers: 1 # this will be determined by fleet implicitly, no need to set
trainer_id: 0 # this will be determined by fleet implicitly, no need to set
use_dali: True
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 256
filelist: "./data/dataset/kinetics/val_video_file.list"
TEST:
seg_num: 7
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 16
filelist: "./data/dataset/kinetics/test_video_file.list"
use_dali: True
INFER:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 1
filelist: "./data/dataset/kinetics/infer_video_file.list"
video_path: ""
kinetics_labels: "./data/dataset/kinetics_labels.json"
......@@ -81,6 +81,7 @@ def test(args):
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
print_configs(test_config, "Test")
use_dali = test_config['TEST'].get('use_dali', False)
# build model
test_model = models.get_model(args.model_name, test_config, mode='test')
......@@ -127,6 +128,10 @@ def test(args):
feed=test_feeder.feed(feat_data),
return_numpy=True)
test_outs += [vinfo]
elif args.model_name == 'TSN' and use_dali:
test_outs = exe.run(fetch_list=test_fetch_list,
feed={'image': data[0],
'label': data[1]})
else:
test_outs = exe.run(fetch_list=test_fetch_list,
feed=test_feeder.feed(data))
......
......@@ -54,6 +54,18 @@ TSN的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。
* 权重衰减系数为1e-4
* 学习率在训练的总epoch数的1/3和2/3时分别做0.1的衰减
**训练速度优化:**
* 使用GPU解码优化视频源文件读取和预处理速度,需要预先安装NVIDIA/DALI
* DALI的安装方式请参考NVIDIA/DALI[官方文档](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/compilation.html#)。由于NVIDIA/DALI提供的VideoReader OP不支持TSN模型的采样方式,请使用[SunGaofeng/DALI](https://github.com/SunGaofeng/DALI)提供的源码,提供了时间维度的稀疏采样方式
* 使用分布式训练的方式提升多卡加速比
启动脚本为:
``` bash
bash run_dist.sh train TSN ./configs/tsn_dist_and_dali.yaml
```
## 模型评估
可通过如下两种方式进行模型评估:
......
......@@ -26,6 +26,16 @@ except ImportError:
from io import BytesIO
import numpy as np
import paddle
try:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import tempfile
from nvidia.dali.plugin.paddle import DALIGenericIterator
except:
print("DALI is not installed, you can improve performance if use DALI")
from PIL import Image, ImageEnhance
import logging
......@@ -76,6 +86,13 @@ class KineticsReader(DataReader):
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
# set num_trainers and trainer_id when distributed training is implemented
self.num_trainers = self.get_config_from_sec(mode, 'num_trainers', 1)
self.trainer_id = self.get_config_from_sec(mode, 'trainer_id', 0)
self.use_dali = self.get_config_from_sec(mode, 'use_dali', False)
self.dali_mean = cfg.MODEL.image_mean * (self.seg_num * self.seglen)
self.dali_std = cfg.MODEL.image_std * (self.seg_num * self.seglen)
if self.mode == 'infer':
self.video_path = cfg[mode.upper()]['video_path']
else:
......@@ -86,6 +103,10 @@ class KineticsReader(DataReader):
self.num_reader_threads = 1
def create_reader(self):
# if use_dali to improve performance
if self.use_dali:
return self.build_dali_reader()
# if set video_path for inference mode, just load this single video
if (self.mode == 'infer') and (self.video_path != ''):
# load video from file stored at video_path
......@@ -201,11 +222,35 @@ class KineticsReader(DataReader):
return imgs_transform(imgs, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std, name = self.name), ret_label
def reader():
def reader_():
with open(pickle_list) as flist:
lines = [line.strip() for line in flist]
if shuffle:
random.shuffle(lines)
full_lines = [line.strip() for line in flist]
if self.mode == 'train':
if (not hasattr(reader_, 'seed')):
reader_.seed = 0
random.Random(reader_.seed).shuffle(full_lines)
print("reader shuffle seed", reader_.seed)
if reader_.seed is not None:
reader_.seed += 1
per_node_lines = int(
math.ceil(len(full_lines) * 1.0 / self.num_trainers))
total_lines = per_node_lines * self.num_trainers
# aligned full_lines so that it can evenly divisible
full_lines += full_lines[:(total_lines - len(full_lines))]
assert len(full_lines) == total_lines
# trainer get own sample
lines = full_lines[self.trainer_id:total_lines:
self.num_trainers]
logger.info("trainerid %d, trainer_count %d" %
(self.trainer_id, self.num_trainers))
logger.info(
"read images from %d, length: %d, lines length: %d, total: %d"
% (self.trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
assert len(lines) == per_node_lines
for line in lines:
pickle_path = line.strip()
yield [pickle_path]
......@@ -227,7 +272,251 @@ class KineticsReader(DataReader):
img_mean=img_mean,
img_std=img_std)
return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
return paddle.reader.xmap_readers(mapper, reader_, num_threads,
buf_size)
def build_dali_reader(self):
    """
    Build a DALI-backed reader for the configured mode.

    The entries of ``self.filelist`` are (in train mode) shuffled, padded so
    that they divide evenly across ``self.num_trainers``, and sharded by
    ``self.trainer_id``. The resulting shard is written to a temporary file
    list which is handed to a DALI pipeline (``VideoPipe`` in train mode,
    ``VideoTestPipe`` otherwise) wrapped in a ``DALIGenericIterator``.

    Returns:
        A zero-argument generator function yielding ``(image, label)``
        taken from the first element of every batch the iterator produces.
    """

    def reader_():
        with open(self.filelist) as flist:
            # Lines keep their trailing newline: they are re-emitted
            # verbatim into the temporary file list below.
            full_lines = list(flist)

        if self.mode == 'train':
            # The shuffle seed is stored on the function object and bumped
            # after every call.
            if not hasattr(reader_, 'seed'):
                reader_.seed = 0
            random.Random(reader_.seed).shuffle(full_lines)
            print("reader shuffle seed", reader_.seed)
            if reader_.seed is not None:
                reader_.seed += 1

        per_node_lines = int(
            math.ceil(len(full_lines) * 1.0 / self.num_trainers))
        total_lines = per_node_lines * self.num_trainers

        # Pad with entries from the head of the list so that the total is
        # evenly divisible by the number of trainers.
        full_lines += full_lines[:(total_lines - len(full_lines))]
        assert len(full_lines) == total_lines

        # Each trainer takes every num_trainers-th entry, starting at its id.
        lines = full_lines[self.trainer_id:total_lines:self.num_trainers]
        assert len(lines) == per_node_lines

        logger.info("trainerid %d, trainer_count %d" %
                    (self.trainer_id, self.num_trainers))
        logger.info(
            "read images from %d, length: %d, lines length: %d, total: %d"
            % (self.trainer_id * per_node_lines, per_node_lines,
               len(lines), len(full_lines)))

        # Write this trainer's shard to a temporary file list for DALI.
        # NOTE(review): the temporary file is removed once `tf` is closed or
        # garbage-collected; this presumably relies on DALI reading the list
        # while the iterator below is constructed -- confirm.
        tf = tempfile.NamedTemporaryFile()
        tf.write(''.join(lines).encode())
        tf.flush()
        file_list_path = tf.name

        device_id = int(os.getenv('FLAGS_selected_gpus', 0))
        print('---------- device id -----------', device_id)

        # Both pipelines take exactly the same arguments; only the class
        # differs between train mode and the other modes.
        pipe_cls = VideoPipe if self.mode == 'train' else VideoTestPipe
        pipe = pipe_cls(
            batch_size=self.batch_size,
            num_threads=1,
            device_id=device_id,
            file_list=file_list_path,
            sequence_length=self.seg_num * self.seglen,
            seg_num=self.seg_num,
            seg_length=self.seglen,
            resize_shorter_scale=self.short_size,
            crop_target_size=self.target_size,
            is_training=(self.mode == 'train'),
            dali_mean=self.dali_mean,
            dali_std=self.dali_std)

        logger.info(
            'initializing dataset, it will take several minutes if it is too large .... '
        )
        video_loader = DALIGenericIterator(
            [pipe], ['image', 'label'],
            len(lines),
            dynamic_shape=True,
            auto_reset=True)

        return video_loader

    dali_reader = reader_()

    def ret_reader():
        for data in dali_reader:
            yield data[0]['image'], data[0]['label']

    return ret_reader
class VideoPipe(Pipeline):
    """DALI pipeline used for training: GPU video reading plus random crop/mirror.

    The graph reads a frame sequence and a label, resizes the shorter side,
    scales pixel values by 1/255, applies a randomly positioned crop with a
    random horizontal mirror, and reshapes the result to
    ``[seg_num, seg_length * 3, crop_target_size, crop_target_size]``.

    Args:
        batch_size, num_threads, device_id: forwarded to ``Pipeline.__init__``.
        file_list: path of the file list handed to ``ops.VideoReader``.
        sequence_length: number of frames per sample requested from the reader.
        seg_num, seg_length: segment layout, forwarded to the reader and used
            for the final reshape.
        resize_shorter_scale: target size of the shorter side for ``ops.Resize``.
        crop_target_size: side length of the square crop.
        is_training: forwarded to the reader; also enables ``random_shuffle``.
        initial_prefetch_size: forwarded to the reader as ``initial_fill``.
        num_shards, shard_id: sharding options forwarded to the reader.
        dali_mean, dali_std: mean/std forwarded to ``ops.CropMirrorNormalize``.
    """

    def __init__(self,
                 batch_size,
                 num_threads,
                 device_id,
                 file_list,
                 sequence_length,
                 seg_num,
                 seg_length,
                 resize_shorter_scale,
                 crop_target_size,
                 is_training=False,
                 initial_prefetch_size=10,
                 num_shards=1,
                 shard_id=0,
                 dali_mean=0.,
                 dali_std=1.0):
        super(VideoPipe, self).__init__(batch_size, num_threads, device_id)

        crop_shape = [crop_target_size, crop_target_size]
        clip_shape = [
            seg_num, seg_length * 3, crop_target_size, crop_target_size
        ]

        self.input = ops.VideoReader(
            device="gpu",
            file_list=file_list,
            sequence_length=sequence_length,
            seg_num=seg_num,
            seg_length=seg_length,
            is_training=is_training,
            num_shards=num_shards,
            shard_id=shard_id,
            random_shuffle=is_training,
            initial_fill=initial_prefetch_size)

        # The reader yields sequences shaped [F, H, W, C]. ops.Resize does not
        # accept sequence data, so frames are moved to [H, W, F, C], flattened
        # to [H, W, F*C] and then resized as if they were one 2-D image.
        self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
        self.reshape = ops.Reshape(
            device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
        self.resize = ops.Resize(
            device="gpu", resize_shorter=resize_shorter_scale)

        # Random sources for the crop anchor and the mirror decision; the
        # crop, mirror and mean/std arguments all go to CropMirrorNormalize.
        self.pos_rng_x = ops.Uniform(range=(0.0, 1.0))
        self.pos_rng_y = ops.Uniform(range=(0.0, 1.0))
        self.mirror_generator = ops.Uniform(range=(0.0, 1.0))
        self.cast_mirror = ops.Cast(dtype=types.DALIDataType.INT32)
        self.crop_mirror_norm = ops.CropMirrorNormalize(
            device="gpu", crop=crop_shape, mean=dali_mean, std=dali_std)

        self.reshape_back = ops.Reshape(
            device="gpu", shape=clip_shape, layout='FCHW')
        self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)

    def define_graph(self):
        """Wire the operators together and return ``(frames, label)``."""
        frames, label = self.input(name="Reader")
        frames = self.resize(self.reshape(self.transpose(frames)))
        frames = frames / 255.

        crop_x = self.pos_rng_x()
        crop_y = self.pos_rng_y()
        # Mirror when the uniform sample exceeds 0.5; the op takes an int flag.
        mirror = self.cast_mirror(self.mirror_generator() > 0.5)

        frames = self.crop_mirror_norm(
            frames, crop_pos_x=crop_x, crop_pos_y=crop_y, mirror=mirror)
        frames = self.reshape_back(frames)
        label = self.cast_label(label)
        return frames, label
class VideoTestPipe(Pipeline):
    """DALI pipeline used outside training: GPU video reading plus a fixed crop.

    Same graph as the training pipeline except that the crop anchor is fixed
    at (0.5, 0.5) and mirroring is disabled, so no random operators are built.

    Args:
        batch_size, num_threads, device_id: forwarded to ``Pipeline.__init__``.
        file_list: path of the file list handed to ``ops.VideoReader``.
        sequence_length: number of frames per sample requested from the reader.
        seg_num, seg_length: segment layout, forwarded to the reader and used
            for the final reshape.
        resize_shorter_scale: target size of the shorter side for ``ops.Resize``.
        crop_target_size: side length of the square crop.
        is_training: forwarded to the reader; also enables ``random_shuffle``.
        initial_prefetch_size: forwarded to the reader as ``initial_fill``.
        num_shards, shard_id: sharding options forwarded to the reader.
        dali_mean, dali_std: mean/std forwarded to ``ops.CropMirrorNormalize``.
    """

    def __init__(self,
                 batch_size,
                 num_threads,
                 device_id,
                 file_list,
                 sequence_length,
                 seg_num,
                 seg_length,
                 resize_shorter_scale,
                 crop_target_size,
                 is_training=False,
                 initial_prefetch_size=10,
                 num_shards=1,
                 shard_id=0,
                 dali_mean=0.,
                 dali_std=1.0):
        super(VideoTestPipe, self).__init__(batch_size, num_threads, device_id)

        crop_shape = [crop_target_size, crop_target_size]
        clip_shape = [
            seg_num, seg_length * 3, crop_target_size, crop_target_size
        ]

        self.input = ops.VideoReader(
            device="gpu",
            file_list=file_list,
            sequence_length=sequence_length,
            seg_num=seg_num,
            seg_length=seg_length,
            is_training=is_training,
            num_shards=num_shards,
            shard_id=shard_id,
            random_shuffle=is_training,
            initial_fill=initial_prefetch_size)

        # The reader yields sequences shaped [F, H, W, C]. ops.Resize does not
        # accept sequence data, so frames are moved to [H, W, F, C], flattened
        # to [H, W, F*C] and then resized as if they were one 2-D image.
        self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
        self.reshape = ops.Reshape(
            device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
        self.resize = ops.Resize(
            device="gpu", resize_shorter=resize_shorter_scale)

        # Fixed crop anchor and no mirroring; mean/std are applied by the
        # same CropMirrorNormalize operator.
        self.crop_mirror_norm = ops.CropMirrorNormalize(
            device="gpu",
            crop=crop_shape,
            crop_pos_x=0.5,
            crop_pos_y=0.5,
            mirror=0,
            mean=dali_mean,
            std=dali_std)

        self.reshape_back = ops.Reshape(
            device="gpu", shape=clip_shape, layout='FCHW')
        self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)

    def define_graph(self):
        """Wire the operators together and return ``(frames, label)``."""
        frames, label = self.input(name="Reader")
        frames = self.resize(self.reshape(self.transpose(frames)))
        frames = frames / 255.
        frames = self.reshape_back(self.crop_mirror_norm(frames))
        label = self.cast_label(label)
        return frames, label
def imgs_transform(imgs,
......
# Launch distributed training of a video model through
# `python -m paddle.distributed.launch train_dist.py`.
#
# usage:
#   bash <this script> <mode> <name> <configs>
# examples:
#   bash <this script> train TSN ./configs/tsn_dist_and_dali.yaml
#   bash <this script> train CTCN ./configs/ctcn.yaml
#
# mode:    only "train" is implemented here; any other value is rejected.
# name:    one of [AttentionCluster, AttentionLSTM, NEXTVLAD, NONLOCAL, TSN, TSM, STNET, CTCN]
# configs: path to the model config, e.g. ./configs/xxx.yaml

mode=$1
name=$2
configs=$3

pretrain=""  # set the pretrain model path if needed
resume=""    # set the checkpoint path to resume training from if needed; takes precedence over pretrain
save_dir="./data/checkpoints"
save_inference_dir="./data/inference_model"  # not used by train mode
use_gpu=True
fix_random_seed=False
log_interval=1
valid_interval=1
weights=""   # path of weights for eval and predict; not used by train mode

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#export CUDA_VISIBLE_DEVICES=0,1,2,3
#export CUDA_VISIBLE_DEVICES=0

export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98

if [ "$mode" = "train" ]; then
    echo $mode $name $configs $resume $pretrain

    # Optional initialisation flag: resume wins over pretrain, and neither is
    # passed when both are empty.
    init_args=()
    if [ -n "$resume" ]; then
        init_args+=("--resume=$resume")
    elif [ -n "$pretrain" ]; then
        init_args+=("--pretrain=$pretrain")
    fi

    python -m paddle.distributed.launch --log_dir=log \
        train_dist.py --model_name="$name" \
        --config="$configs" \
        "${init_args[@]}" \
        --log_interval="$log_interval" \
        --valid_interval="$valid_interval" \
        --use_gpu="$use_gpu" \
        --save_dir="$save_dir" \
        --fix_random_seed="$fix_random_seed"
else
    echo "Not implemented mode " $mode
    # Report failure so that callers do not mistake an unsupported mode for success.
    exit 1
fi
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import argparse
import ast
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from utils.train_utils import train_with_dataloader
import models
from utils.config_utils import *
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version
# Clear handlers already attached to the root logger; logging.basicConfig()
# is a no-op when the root logger has handlers, so this makes the format
# below take effect even if an imported module configured logging first.
logging.root.handlers = []
# Log line layout: level, source file name, line number, then the message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)
def parse_args():
    """Build the command-line parser of the training script and parse ``sys.argv``.

    Options left at ``None`` (batch size, learning rate, epoch, pretrain,
    resume) are documented by their help text as falling back to the config
    file or to defaults. ``--use_gpu`` and ``--fix_random_seed`` are parsed
    with ``ast.literal_eval``, so they expect Python literals such as
    ``True`` / ``False``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    arg_parser = argparse.ArgumentParser("Paddle Video train script")
    add = arg_parser.add_argument

    # model / config selection
    add('--model_name', type=str, default='AttentionCluster',
        help='name of model to train.')
    add('--config', type=str, default='configs/tsn_dist_and_dali.yaml',
        help='path to config file of model')

    # hyper-parameter overrides
    add('--batch_size', type=int, default=None,
        help='training batch size. None to use config file setting.')
    add('--learning_rate', type=float, default=None,
        help='learning rate use for training. None to use config file setting.')

    # weight initialisation
    add('--pretrain', type=str, default=None,
        help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.')
    add('--resume', type=str, default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')

    # runtime switches
    add('--use_gpu', type=ast.literal_eval, default=True,
        help='default use gpu.')
    add('--no_memory_optimize', action='store_true', default=False,
        help='whether to use memory optimize in train')
    add('--epoch', type=int, default=None,
        help='epoch number, 0 for read from config file')
    add('--valid_interval', type=int, default=1,
        help='validation epoch interval, 0 for no validation.')
    add('--save_dir', type=str, default=os.path.join('data', 'checkpoints'),
        help='directory name to save train snapshoot')
    add('--log_interval', type=int, default=10,
        help='mini-batch interval to log.')
    add('--fix_random_seed', type=ast.literal_eval, default=False,
        help='If set True, enable continuous evaluation job.')

    # NOTE: args for profiler, used for benchmark
    add('--profiler_path', type=str, default='./',
        help='the path to store profiler output file. used for benchmark.')
    add('--is_profiler', type=int, default=0,
        help='the switch profiler. used for benchmark.')

    # distributed-training bookkeeping
    add('--num_trainers', type=int, default=1,
        help='the number of trainers when used in distributed training. No need to set this, it will be set automatically')
    add('--trainer_id', type=int, default=0,
        help='trainer id when used in distributed training. No need to set this, it will be set automatically')

    return arg_parser.parse_args()
def train(args):
    """Build the training program and run distributed training through fleet.

    Steps, in order: initialise fleet and read the trainer count/id from the
    environment, build the model and its (fleet-wrapped) optimizer, create
    the executor on this process's device, load resume or pretrain weights,
    create the reader and metrics, and hand everything to
    ``train_with_dataloader``.

    Args:
        args: namespace returned by ``parse_args``. ``args.num_trainers`` and
            ``args.trainer_id`` are overwritten here from the
            ``PADDLE_TRAINERS_NUM`` / ``PADDLE_TRAINER_ID`` environment
            variables.

    Raises:
        AssertionError: if the resume/pretrain path does not exist, or if the
            number of devices listed in ``CUDA_VISIBLE_DEVICES`` differs from
            ``TRAIN.num_gpus`` in the config.
    """
    # implement distributed training by fleet
    use_fleet = True
    if use_fleet:
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        # presumably exported per process by paddle.distributed.launch -
        # TODO confirm; the defaults cover a single-process run.
        args.num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        print('-------------', args.num_trainers, args.trainer_id)

    # Only trainer 0 creates the checkpoint directory.
    if args.trainer_id == 0:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    if args.fix_random_seed:
        startup.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_dataloader=True)
            train_model.build_model()
            # for the input, has the form [data1, data2,..., label], so train_feeds[-1] is label
            train_feeds = train_model.feeds()
            train_fetch_list = train_model.fetches()
            train_loss = train_fetch_list[0]
            optimizer = train_model.optimizer()
            if use_fleet:
                optimizer = fleet.distributed_optimizer(optimizer)
            optimizer.minimize(train_loss)
            train_dataloader = train_model.dataloader()

    # Each process runs on the single device named by FLAGS_selected_gpus.
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # if resume weights is given, load resume weights directly
        assert os.path.exists(args.resume + '.pdparams'), \
            "Given resume weight dir {}.pdparams not exist.".format(args.resume)
        fluid.load(train_prog, model_path=args.resume, executor=exe)
    else:
        # if not in resume mode, load pretrain weights
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True

    exec_strategy = fluid.ExecutionStrategy()

    if use_fleet:
        # Under fleet the program to execute is the one fleet exposes; the
        # build/exec strategies above are only used by the fallback branch.
        compiled_train_prog = fleet.main_program
    else:
        compiled_train_prog = fluid.compiler.CompiledProgram(
            train_prog).with_data_parallel(
                loss_name=train_loss.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

    # get reader
    bs_denominator = 1
    if args.use_gpu:
        # check number of GPUs
        gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
        if gpus == "":
            pass
        else:
            gpus = gpus.split(",")
            num_gpus = len(gpus)
            assert num_gpus == train_config.TRAIN.num_gpus, \
                "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
                "shoud be the same as that " \
                "set in {}({})".format(
                    num_gpus, args.config, train_config.TRAIN.num_gpus)
        bs_denominator = train_config.TRAIN.num_gpus

    # The configured batch size is divided by the GPU count before the
    # reader is created, so the reader sees the per-device value.
    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)

    epochs = args.epoch or train_model.epoch_num()

    # The dataloader feeds only this process's device.
    train_dataloader.set_batch_generator(train_reader, places=place)

    train_with_dataloader(
        exe,
        train_prog,
        compiled_train_prog,
        train_dataloader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        save_dir=args.save_dir,
        num_trainers=args.num_trainers,
        trainer_id=args.trainer_id,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        is_profiler=args.is_profiler,
        profiler_path=args.profiler_path)
# Script entry point: parse options, run environment checks, then train.
if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    # presumably verifies the installed PaddlePaddle version - confirm in utils.utility
    check_version()
    # Record the effective command-line options in the log before training.
    logger.info(args)
    train(args)
......@@ -75,6 +75,7 @@ def test_with_dataloader(exe,
def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader, \
train_fetch_list, train_metrics, epochs = 10, \
log_interval = 0, valid_interval = 0, save_dir = './', \
num_trainers = 1, trainer_id = 0, \
save_model_name = 'model', fix_random_seed = False, \
compiled_test_prog = None, test_dataloader = None, \
test_fetch_list = None, test_metrics = None, \
......@@ -89,17 +90,21 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
train_iter = 0
epoch_periods = []
cur_time = time.time()
for data in train_dataloader():
cur_time = time.time()
train_outs = exe.run(compiled_train_prog,
fetch_list=train_fetch_list,
feed=data)
period = time.time() - cur_time
epoch_periods.append(period)
timeStamp = time.time()
localTime = time.localtime(timeStamp)
strTime = time.strftime("%Y-%m-%d %H:%M:%S", localTime)
if log_interval > 0 and (train_iter % log_interval == 0):
train_metrics.calculate_and_log_out(train_outs, \
info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
info = '[TRAIN {}] Epoch {}, iter {}, time {}, '.format(strTime, epoch, train_iter, period))
train_iter += 1
cur_time = time.time()
# NOTE: profiler tools, used for benchmark
if is_profiler and epoch == 0 and train_iter == log_interval:
......@@ -115,15 +120,18 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
logger.info('[TRAIN] Epoch {} training finished, average time: {}'.
format(epoch, np.mean(epoch_periods[1:])))
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if trainer_id == 0:
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if compiled_test_prog and valid_interval > 0 and (
epoch + 1) % valid_interval == 0:
test_with_dataloader(exe, compiled_test_prog, test_dataloader,
test_fetch_list, test_metrics, log_interval,
save_model_name)
save_model(exe, train_prog, save_dir, save_model_name)
if trainer_id == 0:
save_model(exe, train_prog, save_dir, save_model_name)
#when fix_random seed for debug
if fix_random_seed:
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
......
......@@ -10,11 +10,11 @@
## 快速开始
**目前模型要求使用PaddlePaddle 1.6及以上版本或适当的develop版本运行。**
**目前模型要求使用PaddlePaddle 1.8及以上版本或适当的develop版本运行。**
### 1. Paddle版本安装
本项目训练模块兼容Python2.7.x以及Python3.7.x, 依赖PaddlePaddle 1.6版本以及CentOS系统环境, 安装请参考官网 [快速安装](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html)
本项目训练模块兼容Python2.7.x以及Python3.7.x, 依赖PaddlePaddle 1.8版本以及CentOS系统环境, 安装请参考官网 [快速安装](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html)
注意:该模型同时支持cpu和gpu训练和预测,用户可以根据自身需求,选择安装对应的paddlepaddle-gpu或paddlepaddle版本。
......
文件模式从 100755 更改为 100644
......@@ -452,7 +452,7 @@ def main(args):
if args.use_cuda:
test_place = fluid.cuda_places(0)
place = fluid.cuda_places()
DEV_COUNT = fluid.core.get_cuda_device_count()
DEV_COUNT = len(place)
else:
test_place = fluid.cpu_places(1)
os.environ['CPU_NUM'] = str(args.cpu_num)
......
......@@ -130,12 +130,11 @@ class DataReader(object):
assert os.path.exists(data_path), "The given data file does not exist."
if mode == "train":
train_reader = fluid.io.batch(
paddle.reader.shuffle(
fluid.io.shuffle(
self.data_reader(
data_path, self.max_len, shuffle=True),
buf_size=batch_size * 100),
batch_size)
# train_reader = fluid.io.batch(self.data_reader(data_path), batch_size)
return train_reader
else:
test_reader = fluid.io.batch(
......
......@@ -30,7 +30,7 @@
- cuda >= 9.0
- cudnn >= 7.0
- pandas >= 0.20.1
- PaddlePaddle >= 1.7.0,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装, 本模块使用bert作为pretrain model进行模型的finetuning训练,训练速度较慢,建议安装GPU版本的PaddlePaddle
- PaddlePaddle >= 1.8.0,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装, 本模块使用bert作为pretrain model进行模型的finetuning训练,训练速度较慢,建议安装GPU版本的PaddlePaddle
#### &ensp;&ensp;b、下载代码
......@@ -119,13 +119,10 @@ emb_size: embedding层大小
vocab_size: 词表大小
sample_pro: 采样比率
output_prediction_file: 输出的预测文件
init_from_checkpoint: 加载断点模型
init_from_params: 训练好的模型参数文件,一般用于预测
init_from_pretrain_model: 预训练模型路径,如bert的模型参数
inference_model_dir: inference model的保存路径
save_model_path: 训练产出模型的输出路径
save_checkpoint: 调用paddle的io接口save_persistables(把传入的层中所有参数以及优化器进行保存)来保存模型参数
save_param: 调用paddle的io接口save_params(从main_program中取出所有参数然后保存到文件中)来保存模型参数
evaluation_file: 参与评估的inference 文件
vocab_path: 词表路径
max_seq_len: 输入最大序列长度
......@@ -199,7 +196,6 @@ python -u main.py \
--loss_type="CLS" \
--max_seq_len=50 \
--save_model_path="data/saved_models/matching_pretrained" \
--save_param="params" \
--training_file="data/input/data/unlabel_data/train.ids" \
--epoch=20 \
--print_step=1 \
......@@ -217,7 +213,7 @@ python -u main.py \
#### windows环境下:
训练:
```
python -u main.py --do_train=true --use_cuda=false --loss_type=CLS --max_seq_len=50 --save_model_path=data\saved_models\matching_pretrained --save_param=params --training_file=data\input\data\unlabel_data\train.ids --epoch=20 --print_step=1 --save_step=400 --batch_size=256 --hidden_size=256 --emb_size=256 --vocab_size=484016 --learning_rate=0.001 --sample_pro=0.1
python -u main.py --do_train=true --use_cuda=false --loss_type=CLS --max_seq_len=50 --save_model_path=data\saved_models\matching_pretrained --training_file=data\input\data\unlabel_data\train.ids --epoch=20 --print_step=1 --save_step=400 --batch_size=256 --hidden_size=256 --emb_size=256 --vocab_size=484016 --learning_rate=0.001 --sample_pro=0.1
```
#### 2、第二阶段finetuning模型的训练:
......@@ -271,9 +267,8 @@ python -u main.py \
--use_cuda=${use_cuda} \
--loss_type="L2" \
--max_seq_len=50 \
--init_from_pretrain_model="data/saved_models/trained_models/matching_pretrained/params" \
--init_from_pretrain_model="data/saved_models/trained_models/matching_pretrained/params/params" \
--save_model_path="data/saved_models/human_finetuned" \
--save_param="params" \
--training_file="data/input/data/label_data/human/train.ids" \
--epoch=50 \
--print_step=1 \
......@@ -288,7 +283,7 @@ python -u main.py \
#### windows环境下:
```
python -u main.py --do_train=true --use_cuda=false --loss_type=L2 --max_seq_len=50 --save_model_path=data\saved_models\human_finetuned --save_param=params --training_file=data\input\data\label_data\human\train.ids --epoch=50 --print_step=1 --save_step=400 --batch_size=256 --hidden_size=256 --emb_size=256 --vocab_size=484016 --learning_rate=0.001 --sample_pro=0.1
python -u main.py --do_train=true --use_cuda=false --loss_type=L2 --max_seq_len=50 --save_model_path=data\saved_models\human_finetuned --training_file=data\input\data\label_data\human\train.ids --epoch=50 --print_step=1 --save_step=400 --batch_size=256 --hidden_size=256 --emb_size=256 --vocab_size=484016 --learning_rate=0.001 --sample_pro=0.1
```
### 模型预测
......
......@@ -29,7 +29,7 @@ DATA_MODEL_PATH = {
"DATA_PATH":
"https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_dataset-1.0.0.tar.gz",
"TRAINED_MODEL":
"https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.2.0.0.tar.gz"
"https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.3.0.0.tar.gz"
}
PATH_MAP = {'DATA_PATH': "./data/input", 'TRAINED_MODEL': './data/saved_models'}
......
......@@ -34,7 +34,7 @@ def create_net(is_training,
label = model_input.labels
#emb
context_emb = fluid.input.embedding(
context_emb = fluid.embedding(
input=context_wordseq,
size=[args.vocab_size, args.emb_size],
is_sparse=True,
......@@ -42,7 +42,7 @@ def create_net(is_training,
name=word_emb_name,
initializer=fluid.initializer.Normal(scale=0.1)))
response_emb = fluid.input.embedding(
response_emb = fluid.embedding(
input=response_wordseq,
size=[args.vocab_size, args.emb_size],
is_sparse=True,
......
......@@ -14,13 +14,10 @@ emb_size: 256
vocab_size: 484016
sample_pro: 1.0
output_prediction_file: ""
init_from_checkpoint: ""
init_from_params: ""
init_from_pretrain_model: ""
inference_model_dir: ""
save_model_path: ""
save_checkpoint: ""
save_param: ""
evaluation_file: ""
vocab_path: ""
max_seq_len: 128
......
......@@ -27,7 +27,6 @@ from ade_net import create_net
from ade.utils.configure import PDConfig
from ade.utils.input_field import InputField
from ade.utils.model_check import check_cuda
import ade.utils.save_load_io as save_load_io
def do_save_inference_model(args):
......@@ -55,7 +54,7 @@ def do_save_inference_model(args):
input_inst = [context_wordseq, response_wordseq, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(
data_reader = fluid.io.DataLoader.from_generator(
feed_list=input_inst, capacity=4, iterable=False)
logits = create_net(
......@@ -72,9 +71,9 @@ def do_save_inference_model(args):
assert (args.init_from_params) or (args.init_from_pretrain_model)
if args.init_from_params:
save_load_io.init_from_params(args, exe, test_prog)
fluid.load(test_prog, args.init_from_params)
elif args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, test_prog)
fluid.load(test_prog, args.init_from_pretrain_model)
# saving inference model
fluid.io.save_inference_model(
......
......@@ -29,7 +29,6 @@ from ade_net import create_net
from ade.utils.configure import PDConfig
from ade.utils.input_field import InputField
from ade.utils.model_check import check_cuda
import ade.utils.save_load_io as save_load_io
def do_predict(args):
......@@ -59,12 +58,11 @@ def do_predict(args):
input_inst = [context_wordseq, response_wordseq, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(
data_reader = fluid.io.DataLoader.from_generator(
feed_list=input_inst, capacity=4, iterable=False)
logits = create_net(
is_training=False, model_input=input_field, args=args)
logits.persistable = True
fetch_list = [logits.name]
#for_test is True if change the is_test attribute of operators to True
......@@ -79,9 +77,9 @@ def do_predict(args):
assert (args.init_from_params) or (args.init_from_pretrain_model)
if args.init_from_params:
save_load_io.init_from_params(args, exe, test_prog)
fluid.load(test_prog, args.init_from_params, executor=exe)
if args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, test_prog)
fluid.load(test_prog, args.init_from_pretrain_model, executor=exe)
compiled_test_prog = fluid.CompiledProgram(test_prog)
......@@ -94,7 +92,7 @@ def do_predict(args):
place=place, phase="test", shuffle=False, sample_pro=1)
num_test_examples = processor.get_num_examples(phase='test')
data_reader.decorate_batch_generator(batch_generator)
data_reader.set_batch_generator(batch_generator, places=place)
data_reader.start()
scores = []
......@@ -110,7 +108,7 @@ def do_predict(args):
print("Write the predicted results into the output_prediction_file")
fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
for index, score in enumerate(scores):
fw.write("%s\t%s\n" % (index, score))
fw.write(u"%s\t%s\n" % (index, score[0]))
print("finish........................................")
......
......@@ -67,7 +67,6 @@ function pretrain_train()
--loss_type="CLS" \
--max_seq_len=50 \
--save_model_path=${pretrain_model_path} \
--save_param="params" \
--training_file="${INPUT_PATH}/unlabel_data/train.ids" \
--epoch=20 \
--print_step=1 \
......@@ -99,9 +98,8 @@ function finetuning_train()
--use_cuda=${1} \
--loss_type="L2" \
--max_seq_len=50 \
--init_from_pretrain_model="${SAVED_MODELS}/matching_pretrained/params/step_final" \
--init_from_pretrain_model="${SAVED_MODELS}/matching_pretrained/step_final" \
--save_model_path=${save_model_path} \
--save_param="params" \
--training_file="${INPUT_PATH}/label_data/${2}/train.ids" \
--epoch=50 \
--print_step=1 \
......@@ -121,7 +119,7 @@ function pretrain_predict()
--do_predict=true \
--use_cuda=${1} \
--predict_file="${INPUT_PATH}/unlabel_data/test.ids" \
--init_from_params="${SAVED_MODELS}/trained_models/matching_pretrained/params" \
--init_from_params="${SAVED_MODELS}/trained_models/matching_pretrained/params/params" \
--loss_type="CLS" \
--output_prediction_file="${OUTPUT_PATH}/pretrain_matching_predict" \
--max_seq_len=50 \
......@@ -137,7 +135,7 @@ function finetuning_predict()
--do_predict=true \
--use_cuda=${1} \
--predict_file="${INPUT_PATH}/label_data/${2}/test.ids" \
--init_from_params=${SAVED_MODELS}/trained_models/${2}_finetuned/params \
--init_from_params="${SAVED_MODELS}/trained_models/${2}_finetuned/params/params" \
--loss_type="L2" \
--output_prediction_file="${OUTPUT_PATH}/finetuning_${2}_predict" \
--max_seq_len=50 \
......
......@@ -29,7 +29,6 @@ from ade_net import create_net, set_word_embedding
from ade.utils.configure import PDConfig
from ade.utils.input_field import InputField
from ade.utils.model_check import check_cuda
import ade.utils.save_load_io as save_load_io
try:
import cPickle as pickle #python 2
......@@ -62,24 +61,27 @@ def do_train(args):
input_inst = [context_wordseq, response_wordseq, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(
data_reader = fluid.io.DataLoader.from_generator(
feed_list=input_inst, capacity=4, iterable=False)
loss = create_net(
is_training=True, model_input=input_field, args=args)
loss.persistable = True
# gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=args.learning_rate,
grad_clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
optimizer.minimize(loss)
if args.use_cuda:
dev_count = fluid.core.get_cuda_device_count()
places = fluid.cuda_places()
dev_count = len(places)
place = fluid.CUDAPlace(
int(os.getenv('FLAGS_selected_gpus', '0')))
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
places = fluid.cpu_places()
dev_count = len(places)
place = fluid.CPUPlace()
processor = reader.DataProcessor(
......@@ -99,20 +101,20 @@ def do_train(args):
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
data_reader.decorate_batch_generator(batch_generator)
data_reader.set_batch_generator(batch_generator, places=place)
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_checkpoint == "") or (
assert (args.init_from_params == "") or (
args.init_from_pretrain_model == "")
#init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
save_load_io.init_from_checkpoint(args, exe, train_prog)
if args.init_from_params:
fluid.load(train_prog, args.init_from_params, exe)
#init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, train_prog)
fluid.load(train_prog, args.init_from_pretrain_model, exe)
if args.word_emb_init:
print("start loading word embedding init ...")
......@@ -163,21 +165,17 @@ def do_train(args):
time_begin = time.time()
if steps % args.save_steps == 0:
if args.save_checkpoint:
save_load_io.save_checkpoint(args, exe, train_prog,
"step_" + str(steps))
if args.save_param:
save_load_io.save_param(args, exe, train_prog,
"step_" + str(steps))
model_path = os.path.join(args.save_model_path,
"step_" + str(steps))
fluid.save(train_prog, model_path)
steps += 1
except fluid.core.EOFException:
data_reader.reset()
break
if args.save_checkpoint:
save_load_io.save_checkpoint(args, exe, train_prog, "step_final")
if args.save_param:
save_load_io.save_param(args, exe, train_prog, "step_final")
model_path = os.path.join(args.save_model_path, "step_final")
fluid.save(train_prog, model_path)
def get_cards():
num = 0
......
......@@ -23,7 +23,7 @@
- Python >= 2.7
- cuda >= 9.0
- cudnn >= 7.0
- PaddlePaddle >= 1.7.0,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装, 由于模块内模型基于bert做finetuning, 训练速度较慢, 建议用户安装GPU版本PaddlePaddle进行训练。
- PaddlePaddle >= 1.8.0,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装, 由于模块内模型基于bert做finetuning, 训练速度较慢, 建议用户安装GPU版本PaddlePaddle进行训练。
#### &ensp;&ensp;b、下载代码
......@@ -123,13 +123,10 @@ format:conversation_content \t question \1 answer \t state1 state2 state3.....
task_name: 任务名称,可选udc、swda、mrda、atis_intent、atis_slot、dstc2
data_dir: 数据路径,如./data/input/data/udc
bert_config_path: 预训练模型bert的网络配置./data/pretrain_model/uncased_L-12_H-768_A-12/bert_config.json
init_from_checkpoint: 加载断点模型
init_from_params: 训练好的模型参数文件,一般用于预测
init_from_pretrain_model: 预训练模型路径,如bert的模型参数
inference_model_dir: inference model的保存路径
save_model_path: 训练产出模型的输出路径
save_checkpoint: 调用paddle的io接口save_persistables(把传入的层中所有参数以及优化器进行保存)来保存模型参数
save_param: 调用paddle的io接口save_params(从main_program中取出所有参数然后保存到文件中)来保存模型参数
lr_scheduler: learning rate scheduler
weight_decay: learning rate 权重衰减因子
warmup_proportion: warmup比率
......@@ -221,7 +218,6 @@ python -u main.py \
--vocab_path="${BERT_BASE_PATH}/vocab.txt" \
--init_from_pretrain_model="${BERT_BASE_PATH}/params" \
--save_model_path="./data/saved_models/${TASK_NAME}" \
--save_param="params" \
--save_steps=100 \
--learning_rate=2e-5 \
--weight_decay=0.01 \
......@@ -235,7 +231,7 @@ python -u main.py \
#### windows环境下
```
python -u main.py --task_name=atis_intent --use_cuda=false --do_train=true --epoch=20 --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --init_from_pretrain_model=data\pretrain_model\uncased_L-12_H-768_A-12\params --save_model_path=data\saved_models\atis_intent --save_param=params --save_steps=100 --learning_rate=2e-5 --weight_decay=0.01 --max_seq_len=128 --print_steps=10
python -u main.py --task_name=atis_intent --use_cuda=false --do_train=true --epoch=20 --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --init_from_pretrain_model=data\pretrain_model\uncased_L-12_H-768_A-12\params --save_model_path=data\saved_models\atis_intent --save_steps=100 --learning_rate=2e-5 --weight_decay=0.01 --max_seq_len=128 --print_steps=10
```
### 模型预测
......@@ -294,7 +290,7 @@ python -u main.py \
--batch_size=32 \
--do_lower_case=true \
--data_dir="./data/input/data/atis/${TASK_NAME}" \
--init_from_params="./data/saved_models/trained_models/${TASK_NAME}/params" \
--init_from_params="./data/saved_models/trained_models/${TASK_NAME}/params/params" \
--bert_config_path="${BERT_BASE_PATH}/bert_config.json" \
--vocab_path="${BERT_BASE_PATH}/vocab.txt" \
--output_prediction_file="./data/output/pred_${TASK_NAME}" \
......@@ -305,7 +301,7 @@ python -u main.py \
#### windows环境下
```
python -u main.py --task_name=atis_intent --use_cuda=false --do_predict=true --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --init_from_params=data\saved_models\trained_models\atis_intent\params --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --output_prediction_file=data\output\pred_atis_intent --max_seq_len=128
python -u main.py --task_name=atis_intent --use_cuda=false --do_predict=true --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --init_from_params=data\saved_models\trained_models\atis_intent\params\params --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --output_prediction_file=data\output\pred_atis_intent --max_seq_len=128
```
### 模型评估
......
task_name: ""
data_dir: ""
bert_config_path: ""
init_from_checkpoint: ""
init_from_params: ""
init_from_pretrain_model: ""
inference_model_dir: ""
save_model_path: ""
save_checkpoint: ""
save_param: ""
lr_scheduler: "linear_warmup_decay"
weight_decay: 0.01
warmup_proportion: 0.1
......
......@@ -87,21 +87,21 @@ class BertModel(object):
def _build_model(self, src_ids, position_ids, sentence_ids, input_mask):
# padding id in vocabulary must be set to 0
emb_out = fluid.input.embedding(
emb_out = fluid.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
position_emb_out = fluid.input.embedding(
position_emb_out = fluid.embedding(
input=position_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.input.embedding(
sent_emb_out = fluid.embedding(
sentence_ids,
size=[self._sent_types, self._emb_size],
dtype=self._dtype,
......
......@@ -48,8 +48,8 @@ class Paradigm(object):
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
if not params['is_training']:
if not params['is_training']:
probs = fluid.layers.softmax(logits)
results = {"probs": probs}
return results
......
......@@ -17,7 +17,6 @@ import re
import sys
import numpy as np
import paddle
import paddle.fluid as fluid
......
......@@ -59,7 +59,13 @@ def optimization(loss,
weight_decay,
scheduler='linear_warmup_decay',
use_fp16=False,
loss_scaling=1.0):
loss_scaling=1.0,
clip_norm_thres=1.0):
# When using mixed precision training, scale the gradient clip threshold
# by loss_scaling
if use_fp16 and loss_scaling > 1.0:
clip_norm_thres *= loss_scaling
if warmup_steps > 0:
if scheduler == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler\
......@@ -71,19 +77,17 @@ def optimization(loss,
else:
raise ValueError("Unkown learning rate scheduler, should be "
"'noam_decay' or 'linear_warmup_decay'")
optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=scheduled_lr,
grad_clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=clip_norm_thres))
else:
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=learning_rate,
grad_clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=clip_norm_thres))
scheduled_lr = learning_rate
clip_norm_thres = 1.0
# When using mixed precision training, scale the gradient clip threshold
# by loss_scaling
if use_fp16 and loss_scaling > 1.0:
clip_norm_thres *= loss_scaling
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
def exclude_from_weight_decay(name):
if name.find("layer_norm") > -1:
return True
......
......@@ -29,7 +29,7 @@ DATA_MODEL_PATH = {
"DATA_PATH": "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz",
"PRETRAIN_MODEL":
"https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz",
"TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/dgu_models_2.0.0.tar.gz"
"TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/dgu_models_3.0.0.tar.gz"
}
PATH_MAP = {
......
......@@ -25,7 +25,6 @@ import paddle.fluid as fluid
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
import dgu.reader as reader
from dgu_net import create_net
......@@ -97,12 +96,10 @@ def do_save_inference_model(args):
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_params) or (args.init_from_pretrain_model)
assert (args.init_from_params)
if args.init_from_params:
save_load_io.init_from_params(args, exe, test_prog)
elif args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, test_prog)
fluid.load(test_prog, args.init_from_params)
# saving inference model
fluid.io.save_inference_model(
......
......@@ -16,7 +16,6 @@ import os
import sys
import numpy as np
import paddle
import paddle.fluid as fluid
from eval import do_eval
......
......@@ -19,7 +19,6 @@ import sys
import numpy as np
import argparse
import collections
import paddle
import paddle.fluid as fluid
import dgu.reader as reader
......@@ -30,7 +29,6 @@ import dgu.define_predict_pack as define_predict_pack
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
from dgu.utils.py23 import tab_tok, rt_tok
......@@ -84,7 +82,7 @@ def do_predict(args):
input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(
data_reader = fluid.io.DataLoader.from_generator(
feed_list=input_inst, capacity=4, iterable=False)
results = create_net(
......@@ -95,9 +93,6 @@ def do_predict(args):
args=args)
probs = results.get("probs", None)
probs.persistable = True
fetch_list = [probs.name]
# set for_test=True to change the is_test attribute of operators to True
......@@ -111,12 +106,10 @@ def do_predict(args):
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_params) or (args.init_from_pretrain_model)
assert (args.init_from_params)
if args.init_from_params:
save_load_io.init_from_params(args, exe, test_prog)
if args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, test_prog)
fluid.load(test_prog, args.init_from_params)
compiled_test_prog = fluid.CompiledProgram(test_prog)
......@@ -130,7 +123,7 @@ def do_predict(args):
batch_generator = processor.data_generator(
batch_size=args.batch_size, phase='test', shuffle=False)
data_reader.decorate_batch_generator(batch_generator)
data_reader.set_batch_generator(batch_generator, places=place)
data_reader.start()
all_results = []
......
......@@ -3,7 +3,7 @@
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=
if [ ! "$CUDA_VISIBLE_DEVICES" ]
then
export CPU_NUM=1
......@@ -21,7 +21,7 @@ SAVE_MODEL_PATH="./data/saved_models/${TASK_NAME}"
TRAIN_MODEL_PATH="./data/saved_models/trained_models"
OUTPUT_PATH="./data/output"
INFERENCE_MODEL="data/inference_models"
PYTHON_PATH="python3"
PYTHON_PATH="python"
if [ -f ${SAVE_MODEL_PATH} ]; then
rm ${SAVE_MODEL_PATH}
......@@ -94,7 +94,6 @@ else
exit 255
fi
#training
function train()
{
......@@ -110,7 +109,6 @@ function train()
--vocab_path=${BERT_BASE_PATH}/vocab.txt \
--init_from_pretrain_model=${BERT_BASE_PATH}/params \
--save_model_path=${SAVE_MODEL_PATH} \
--save_param="params" \
--save_steps=${save_steps} \
--learning_rate=${learning_rate} \
--weight_decay=0.01 \
......@@ -128,7 +126,7 @@ function predict()
--batch_size=${batch_size} \
--data_dir=${INPUT_PATH} \
--do_lower_case=true \
--init_from_params=${TRAIN_MODEL_PATH}/${TASK_NAME}/params \
--init_from_params=${TRAIN_MODEL_PATH}/${TASK_NAME}/params/params \
--bert_config_path=${BERT_BASE_PATH}/bert_config.json \
--vocab_path=${BERT_BASE_PATH}/vocab.txt \
--output_prediction_file=${OUTPUT_PATH}/pred_${TASK_NAME} \
......
......@@ -22,7 +22,6 @@ import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from dgu_net import create_net
......@@ -32,7 +31,6 @@ import dgu.define_paradigm as define_paradigm
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
def do_train(args):
......@@ -80,8 +78,9 @@ def do_train(args):
input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.DataLoader.from_generator(feed_list=input_inst, capacity=4, iterable=False)
data_reader = fluid.io.DataLoader.from_generator(
feed_list=input_inst, capacity=4, iterable=False)
processor = processors[task_name](data_dir=args.data_dir,
vocab_path=args.vocab_path,
......@@ -103,13 +102,8 @@ def do_train(args):
accuracy = results.get("accuracy", None)
num_seqs = results.get("num_seqs", None)
loss.persistable = True
probs.persistable = True
if accuracy:
accuracy.persistable = True
num_seqs.persistable = True
places = fluid.cuda_places() if args.use_cuda else fluid.cpu_places()
places = fluid.cuda_places() if args.use_cuda else fluid.cpu_places(
)
dev_count = len(places)
batch_generator = processor.data_generator(
......@@ -149,16 +143,13 @@ def do_train(args):
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_checkpoint == "") or (
args.init_from_pretrain_model == "")
assert args.init_from_params or args.init_from_pretrain_model
# init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
save_load_io.init_from_checkpoint(args, exe, train_prog)
# init from some pretrain models, to better solve the current task
if args.init_from_params:
fluid.load(train_prog, args.init_from_params, exe)
if args.init_from_pretrain_model:
save_load_io.init_from_pretrain_model(args, exe, train_prog)
fluid.load(train_prog, args.init_from_pretrain_model, exe)
build_strategy = fluid.compiler.BuildStrategy()
build_strategy.enable_inplace = True
......@@ -234,21 +225,16 @@ def do_train(args):
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = "step_" + str(steps)
if args.save_checkpoint:
save_load_io.save_checkpoint(args, exe, train_prog,
save_path)
if args.save_param:
save_load_io.save_param(args, exe, train_prog,
save_path)
model_path = os.path.join(args.save_model_path,
"step_" + str(steps))
fluid.save(train_prog, model_path)
except fluid.core.EOFException:
data_reader.reset()
break
if args.save_checkpoint:
save_load_io.save_checkpoint(args, exe, train_prog, "step_final")
if args.save_param:
save_load_io.save_param(args, exe, train_prog, "step_final")
model_path = os.path.join(args.save_model_path, "step_final")
fluid.save(train_prog, model_path)
def get_cards():
num = 0
......
......@@ -105,15 +105,15 @@ def create_pyreader(args,
# create lac pyreader
if mode == 'train':
pyreader.set_sample_list_generator(
fluid.io.batch(
fluid.io.shuffle(
paddle.batch(
paddle.reader.shuffle(
reader.file_reader(file_name),
buf_size=args.traindata_shuffle_buffer),
batch_size=args.batch_size / device_count),
places=place)
else:
pyreader.set_sample_list_generator(
fluid.io.batch(
paddle.batch(
reader.file_reader(
file_name, mode=mode),
batch_size=args.batch_size / device_count),
......
......@@ -116,11 +116,11 @@ def do_train(args):
feed_list=train_ret['feed_list'],
model="ernie",
place=place)
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
optimizer = fluid.optimizer.Adam(
learning_rate=args.base_learning_rate)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
learning_rate=args.base_learning_rate,
grad_clip=clip)
optimizer.minimize(train_ret["avg_cost"])
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
......
......@@ -97,14 +97,14 @@ def main():
dropout=dropout)
loss = model.build_graph()
inference_program = train_program.clone(for_test=True)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm))
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm)
lr = args.learning_rate
opt_type = args.optimizer
if opt_type == "sgd":
optimizer = fluid.optimizer.SGD(lr)
optimizer = fluid.optimizer.SGD(lr, grad_clip=clip)
elif opt_type == "adam":
optimizer = fluid.optimizer.Adam(lr)
optimizer = fluid.optimizer.Adam(lr, grad_clip=clip)
else:
print("only support [sgd|adam]")
raise Exception("opt type not support")
......
......@@ -229,7 +229,7 @@ class VAE(object):
# `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
# it will be used within BeamSearchDecoder
sample_output_layer = lambda x: layers.unsqueeze(layers.one_hot(
sample_output_layer = lambda x: layers.unsqueeze(fluid.one_hot(
layers.unsqueeze(
layers.sampling_id(
layers.softmax(
......
......@@ -89,9 +89,8 @@ def main():
inference_program = fluid.default_main_program().clone(
for_test=True)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm))
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm)
learning_rate = fluid.layers.create_global_var(
name="learning_rate",
......@@ -102,9 +101,9 @@ def main():
opt_type = args.optimizer
if opt_type == "sgd":
optimizer = fluid.optimizer.SGD(learning_rate)
optimizer = fluid.optimizer.SGD(learning_rate, grad_clip=clip)
elif opt_type == "adam":
optimizer = fluid.optimizer.Adam(learning_rate)
optimizer = fluid.optimizer.Adam(learning_rate, grad_clip=clip)
else:
print("only support [sgd|adam]")
raise Exception("opt type not support")
......@@ -272,7 +271,7 @@ def main():
(old_lr, new_lr))
dir_name = args.model_path + "/epoch_" + str(best_epoch_id)
fluid.io.load_params(exe, dir_name)
fluid.load(main_program, dir_name, exe)
decay_cnt += 1
if decay_cnt == max_decay:
......
......@@ -50,7 +50,7 @@ def infer():
out_path = args.output + "/single"
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
cycle_gan = Cycle_Gan(3)
save_dir = args.init_model
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
......
......@@ -50,7 +50,7 @@ def test():
out_path = args.output + "/eval" + "/" + str(epoch)
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
cycle_gan = Cycle_Gan(3)
save_dir = args.init_model + str(epoch)
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
......
......@@ -44,7 +44,7 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.")
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
tep_per_epoch = 2974
step_per_epoch = 2974
def optimizer_setting(parameters):
......@@ -90,7 +90,8 @@ def train(args):
losses = [[], []]
t_time = 0
vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters()
vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters(
) + cycle_gan.build_generator_resnet_9blocks_b.parameters()
vars_da = cycle_gan.build_gen_discriminator_a.parameters()
vars_db = cycle_gan.build_gen_discriminator_b.parameters()
......
......@@ -128,6 +128,7 @@ def train_mobilenet():
test_data_loader.set_sample_list_generator(test_reader, place)
# 4. train loop
total_batch_num = 0 #this is for benchmark
for eop in range(args.num_epochs):
if num_trainers > 1:
imagenet_reader.set_shuffle_seed(eop + (
......@@ -142,6 +143,8 @@ def train_mobilenet():
# 4.1 for each batch, call net() , backward(), and minimize()
for img, label in train_data_loader():
t1 = time.time()
if args.max_iter and total_batch_num == args.max_iter:
return
label = to_variable(label.numpy().astype('int64').reshape(
int(args.batch_size // place_num), 1))
t_start = time.time()
......@@ -185,6 +188,10 @@ def train_mobilenet():
total_sample += 1
batch_id += 1
t_last = time.time()
# NOTE: used for benchmark
total_batch_num = total_batch_num + 1
if args.ce:
print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
......
......@@ -117,6 +117,10 @@ def parse_args():
add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# NOTE: used for benchmark
add_arg('max_iter', int, 0, "The number of total train max_iters.")
# READER AND PREPROCESS
add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop")
add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册