未验证 提交 957b3d13 编写于 作者: G Guanghua Yu 提交者: GitHub

fix quant train TrainReader error (#1638)

上级 eb6d1196
......@@ -60,7 +60,10 @@ def main():
"Currently only supports `--eval==True` while training in `quantization`."
)
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
and 'PADDLE_TRAINERS_NUM' in env \
and int(env['PADDLE_TRAINERS_NUM']) > 1
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
......@@ -221,8 +224,11 @@ def main():
start_iter = 0
train_reader = create_reader(cfg.TrainReader,
(cfg.max_iters - start_iter) * devices_num)
train_reader = create_reader(
cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
cfg,
devices_num=devices_num,
num_trainers=num_trainers)
# When iterable mode, set set_sample_list_generator(train_reader, place)
train_loader.set_sample_list_generator(train_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册