提交 4780f758 编写于 作者: J JiabinYang

merge reyoung reader opt

上级 9eaab43a
...@@ -125,14 +125,44 @@ def parse_args(): ...@@ -125,14 +125,44 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def convert_python_to_tensor(batch_size, sample_reader):
    """Wrap a per-sample reader into a reader that yields batches of tensors.

    Args:
        batch_size: Number of samples gathered into each yielded batch.
        sample_reader: Zero-argument callable returning an iterable of
            samples. Each sample is an iterable of features; the i-th
            feature of every sample is collected into the i-th column.

    Returns:
        A zero-argument generator function. Every item it yields is a list
        holding one ``fluid.Tensor`` per feature column, filled with
        ``int64`` data and placed on ``fluid.CPUPlace()``.

    Samples left over after the last complete batch are not yielded.
    """

    def __reader__():
        # One buffer per feature column. The buffers are grown on demand so
        # the number of columns follows the samples instead of being fixed
        # at four (which broke readers emitting a different field count).
        columns = []
        for sample in sample_reader():
            for i, fea in enumerate(sample):
                if i == len(columns):
                    columns.append([])
                columns[i].append(fea)

            # The first column decides when a batch is complete.
            if columns and len(columns[0]) == batch_size:
                tensor_result = []
                for column in columns:
                    dat = np.array(column, dtype='int64')
                    if dat.ndim > 2:
                        # Drop the middle axis.
                        # NOTE(review): presumably that axis has length 1
                        # (shape (batch, 1, width)) -- confirm with the reader.
                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
                    elif dat.ndim == 1:
                        # Scalars per sample become a (batch, 1) column.
                        dat = dat.reshape((-1, 1))
                    t = fluid.Tensor()
                    t.set(dat, fluid.CPUPlace())
                    tensor_result.append(t)
                yield tensor_result
                # Start the next batch with fresh buffers.
                columns = []

    return __reader__
def train_loop(args, train_program, reader, py_reader, loss, trainer_id): def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
train_reader = paddle.batch( # train_reader = paddle.batch(
paddle.reader.shuffle( # paddle.reader.shuffle(
reader.train((args.with_hs or (not args.with_nce))), # reader.train((args.with_hs or (not args.with_nce))),
buf_size=args.batch_size * 100), # buf_size=args.batch_size * 100),
batch_size=args.batch_size) # batch_size=args.batch_size)
# py_reader.decorate_paddle_reader(train_reader)
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_tensor_provider(
convert_python_to_tensor(args.batch_size,
reader.train((args.with_hs or (
not args.with_nce)))))
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -140,6 +170,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -140,6 +170,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
print("CPU_NUM:" + str(os.getenv("CPU_NUM"))) print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
exec_strategy.num_threads = int(os.getenv("CPU_NUM")) exec_strategy.num_threads = int(os.getenv("CPU_NUM"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册