Commit 38388a5e authored by hutuxian, committed by Yi Liu

use py_reader and support multi-card training (#2410)

* use py_reader and support multi-card training

* update README
Parent: 1ba79a58
@@ -76,11 +76,21 @@ GPU single-machine single-card training
 CUDA_VISIBLE_DEVICES=1 python -u train.py --use_cuda 1 > log.txt 2>&1 &
 ```
+GPU single-machine multi-card training
+``` bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --use_cuda 1 > log.txt 2>&1 &
+```
 CPU single-machine training
 ``` bash
 CPU_NUM=1 python -u train.py --use_cuda 0 > log.txt 2>&1 &
 ```
+CPU single-machine multi-CPU training
+``` bash
+CPU_NUM=5 python -u train.py --use_cuda 0 > log.txt 2>&1 &
+```
 Note that the single-card training above can also be accelerated with the Parallel Executor by adding the --use_parallel 1 flag.
......
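For context on the --use_parallel note above: when the flag is set, train.py runs the program through fluid.ParallelExecutor (see the train() hunks further down), which replicates the graph on every visible device (CUDA_VISIBLE_DEVICES for GPU, CPU_NUM for CPU) and aggregates gradients across them. Below is a minimal, self-contained sketch of that wiring using a toy fc model rather than this repository's network; all variable names in it are illustrative only.

``` python
import os
import numpy as np
import paddle.fluid as fluid

# Toy program standing in for network.network(); only the ParallelExecutor
# wiring mirrors what train.py does when --use_parallel 1 is passed.
x = fluid.layers.data(name="x", shape=[4], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_mean(
    fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

use_cuda = False                       # set True (and CUDA_VISIBLE_DEVICES) for GPUs
os.environ.setdefault("CPU_NUM", "2")  # number of CPU "devices" in CPU mode
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
fluid.Executor(place).run(fluid.default_startup_program())

# ParallelExecutor splits each fed mini-batch across the available devices
# and aggregates gradients; loss_name tells it which variable to average.
train_exe = fluid.ParallelExecutor(use_cuda=use_cuda, loss_name=loss.name)

feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
batch = [(np.random.rand(4).astype("float32"),
          np.random.rand(1).astype("float32")) for _ in range(8)]
res = train_exe.run(feed=feeder.feed(batch), fetch_list=[loss.name])
print(res[0].mean())                   # one loss value per device, averaged here
```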
@@ -59,7 +59,7 @@ def infer(epoch_num):
     loss_sum = 0.0
     acc_sum = 0.0
     count = 0
-    for data in test_data.reader(batch_size, batch_size, False):
+    for data in test_data.reader(batch_size, batch_size, False)():
         res = exe.run(infer_program,
                       feed=feeder.feed(data),
                       fetch_list=fetch_targets)
......
@@ -58,6 +58,12 @@ def network(batch_size, items_num, hidden_size, step):
         dtype="int64",
         append_batch_size=False)
+    datas = [items, seq_index, last_index, adj_in, adj_out, mask, label]
+    py_reader = fluid.layers.create_py_reader_by_data(
+        capacity=256, feed_list=datas, name='py_reader', use_double_buffer=True)
+    feed_datas = fluid.layers.read_file(py_reader)
+    items, seq_index, last_index, adj_in, adj_out, mask, label = feed_datas
+
     items_emb = layers.embedding(
         input=items,
         param_attr=fluid.ParamAttr(
@@ -171,7 +177,7 @@ def network(batch_size, items_num, hidden_size, step):
         [global_attention, last], axis=1)  #[batch_size, 2*h]
     final_attention_fc = layers.fc(
         input=final_attention,
-        name="fina_attention_fc",
+        name="final_attention_fc",
         size=hidden_size,
         bias_attr=False,
         act=None,
@@ -200,4 +206,4 @@ def network(batch_size, items_num, hidden_size, step):
         logits=logits, label=label)  #[batch_size, 1]
     loss = layers.reduce_mean(softmax)  # [1]
     acc = layers.accuracy(input=logits, label=label, k=20)
-    return loss, acc
+    return loss, acc, py_reader, feed_datas
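The lines added in the first network() hunk build the py_reader directly from the existing data layers: fluid.layers.create_py_reader_by_data sets up a queue whose slots mirror the shapes and dtypes of the variables in feed_list, and fluid.layers.read_file pops one batch of tensors from that queue. The downstream layers therefore consume the py_reader output instead of externally fed variables, which is why network() now also returns py_reader and feed_datas to the caller.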
@@ -76,7 +76,7 @@ class Data():
         seq_index = np.array(seq_index).astype("int32").reshape(
             (batch_size, -1))
         last_index = np.array(last_index).astype("int32").reshape(
-            (batch_size, 1))
+            (batch_size))
         adj_in = np.array(adj_in).astype("float32").reshape(
             (batch_size, max_uniq_len, max_uniq_len))
         adj_out = np.array(adj_out).astype("float32").reshape(
@@ -86,28 +86,30 @@ class Data():
         return zip(items, seq_index, last_index, adj_in, adj_out, mask, label)

     def reader(self, batch_size, batch_group_size, train=True):
-        if self.shuffle:
-            random.shuffle(self.input)
-        group_remain = self.length % batch_group_size
-        for bg_id in range(0, self.length - group_remain, batch_group_size):
-            cur_bg = copy.deepcopy(self.input[bg_id:bg_id + batch_group_size])
-            if train:
-                cur_bg = sorted(cur_bg, key=lambda x: len(x[0]), reverse=True)
-            for i in range(0, batch_group_size, batch_size):
-                cur_batch = cur_bg[i:i + batch_size]
-                yield self.make_data(cur_batch, batch_size)
-        #deal with the remaining, discard at most batch_size data
-        if group_remain < batch_size:
-            return
-        remain_data = copy.deepcopy(self.input[-group_remain:])
-        if train:
-            remain_data = sorted(
-                remain_data, key=lambda x: len(x[0]), reverse=True)
-        for i in range(0, batch_group_size, batch_size):
-            if i + batch_size <= len(remain_data):
-                cur_batch = remain_data[i:i + batch_size]
-                yield self.make_data(cur_batch, batch_size)
+        def _reader():
+            if self.shuffle:
+                random.shuffle(self.input)
+            group_remain = self.length % batch_group_size
+            for bg_id in range(0, self.length - group_remain, batch_group_size):
+                cur_bg = copy.deepcopy(self.input[bg_id:bg_id + batch_group_size])
+                if train:
+                    cur_bg = sorted(cur_bg, key=lambda x: len(x[0]), reverse=True)
+                for i in range(0, batch_group_size, batch_size):
+                    cur_batch = cur_bg[i:i + batch_size]
+                    yield self.make_data(cur_batch, batch_size)
+            #deal with the remaining, discard at most batch_size data
+            if group_remain < batch_size:
+                return
+            remain_data = copy.deepcopy(self.input[-group_remain:])
+            if train:
+                remain_data = sorted(
+                    remain_data, key=lambda x: len(x[0]), reverse=True)
+            for i in range(0, batch_group_size, batch_size):
+                if i + batch_size <= len(remain_data):
+                    cur_batch = remain_data[i:i + batch_size]
+                    yield self.make_data(cur_batch, batch_size)
+        return _reader

 def read_config(path):
......
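The reason Data.reader() now returns a nested _reader function instead of yielding directly is that py_reader.decorate_paddle_reader expects a reader factory, i.e. a no-argument callable that returns a fresh generator on each call; this is also why the infer() loop above now invokes test_data.reader(batch_size, batch_size, False)(). A minimal, self-contained sketch of the same pattern follows; the ToyData class and its fields are illustrative only, not part of this repository.

``` python
import numpy as np

class ToyData(object):
    """Illustrative stand-in for reader.Data: batches a list of samples."""

    def __init__(self, n):
        self.samples = [np.full((3,), i, dtype="float32") for i in range(n)]

    def reader(self, batch_size):
        # Return a factory instead of a generator, matching the pattern
        # expected by py_reader.decorate_paddle_reader.
        def _reader():
            for i in range(0, len(self.samples), batch_size):
                batch = self.samples[i:i + batch_size]
                if len(batch) == batch_size:
                    yield batch  # one yielded item == one full batch
        return _reader

data = ToyData(10)
# py_reader.decorate_paddle_reader(data.reader(4)) would register this callable;
# plain iteration, as in the infer() hunk, calls the factory explicitly:
for batch in data.reader(4)():
    print(len(batch))
```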
@@ -71,7 +71,7 @@ def train():
     batch_size = args.batch_size
     items_num = reader.read_config(args.config_path)
-    loss, acc = network.network(batch_size, items_num, args.hidden_size,
+    loss, acc, py_reader, feed_datas = network.network(batch_size, items_num, args.hidden_size,
                                 args.step)
     data_reader = reader.Data(args.train_path, True)
@@ -98,10 +98,7 @@ def train():
     all_vocab.set(
         np.arange(1, items_num).astype("int64").reshape((-1, 1)), place)

-    feed_list = [
-        "items", "seq_index", "last_index", "adj_in", "adj_out", "mask", "label"
-    ]
-    feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
+    feed_list = [e.name for e in feed_datas]

     if use_parallel:
         train_exe = fluid.ParallelExecutor(
@@ -118,23 +115,27 @@ def train():
     acc_sum = 0.0
     global_step = 0
     PRINT_STEP = 500
+    py_reader.decorate_paddle_reader(data_reader.reader(batch_size, batch_size * 20, True))
     for i in range(args.epoch_num):
         epoch_sum = []
-        for data in data_reader.reader(batch_size, batch_size * 20, True):
-            res = train_exe.run(feed=feeder.feed(data),
-                                fetch_list=[loss.name, acc.name])
-            loss_sum += res[0]
-            acc_sum += res[1]
-            epoch_sum.append(res[0])
-            global_step += 1
-            if global_step % PRINT_STEP == 0:
-                ce_info.append([loss_sum / PRINT_STEP, acc_sum / PRINT_STEP])
-                total_time.append(time.time() - start_time)
-                logger.info("global_step: %d, loss: %.4lf, train_acc: %.4lf" % (
-                    global_step, loss_sum / PRINT_STEP, acc_sum / PRINT_STEP))
-                loss_sum = 0.0
-                acc_sum = 0.0
-                start_time = time.time()
+        py_reader.start()
+        try:
+            while True:
+                res = train_exe.run(fetch_list=[loss.name, acc.name])
+                loss_sum += res[0].mean()
+                acc_sum += res[1].mean()
+                epoch_sum.append(res[0].mean())
+                global_step += 1
+                if global_step % PRINT_STEP == 0:
+                    ce_info.append([loss_sum / PRINT_STEP, acc_sum / PRINT_STEP])
+                    total_time.append(time.time() - start_time)
+                    logger.info("global_step: %d, loss: %.4lf, train_acc: %.4lf" % (
+                        global_step, loss_sum / PRINT_STEP, acc_sum / PRINT_STEP))
+                    loss_sum = 0.0
+                    acc_sum = 0.0
+                    start_time = time.time()
+        except fluid.core.EOFException:
+            py_reader.reset()
         logger.info("epoch loss: %.4lf" % (np.mean(epoch_sum)))
         save_dir = args.model_path + "/epoch_" + str(i)
         fetch_vars = [loss, acc]
......
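The rewritten training loop above follows the usual fluid py_reader protocol: decorate_paddle_reader registers the reader factory returned by data_reader.reader(...), start() launches asynchronous feeding at the top of each epoch, run() is then called without any feed argument until the reader is exhausted and raises fluid.core.EOFException, and reset() puts the py_reader back into a startable state for the next epoch. The switch from res[0] to res[0].mean() is needed because ParallelExecutor returns the fetched value from every device, so the per-device losses and accuracies are averaged before logging.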