未验证 提交 957b3d13 编写于 作者: G Guanghua Yu 提交者: GitHub

fix quant train TrainReader error (#1638)

上级 eb6d1196
...@@ -60,7 +60,10 @@ def main(): ...@@ -60,7 +60,10 @@ def main():
"Currently only supports `--eval==True` while training in `quantization`." "Currently only supports `--eval==True` while training in `quantization`."
) )
env = os.environ env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
and 'PADDLE_TRAINERS_NUM' in env \
and int(env['PADDLE_TRAINERS_NUM']) > 1
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
if FLAGS.dist: if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID']) trainer_id = int(env['PADDLE_TRAINER_ID'])
import random import random
...@@ -221,8 +224,11 @@ def main(): ...@@ -221,8 +224,11 @@ def main():
start_iter = 0 start_iter = 0
train_reader = create_reader(cfg.TrainReader, train_reader = create_reader(
(cfg.max_iters - start_iter) * devices_num) cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
cfg,
devices_num=devices_num,
num_trainers=num_trainers)
# When iterable mode, set set_sample_list_generator(train_reader, place) # When iterable mode, set set_sample_list_generator(train_reader, place)
train_loader.set_sample_list_generator(train_reader) train_loader.set_sample_list_generator(train_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册