diff --git a/tools/train.py b/tools/train.py index 828f548185ff078151585f13db6f28bf45b3d257..7f058c5281e1475927352286260f76b9ed0f584d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -53,7 +53,9 @@ logger = logging.getLogger(__name__) def main(): env = os.environ - FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env + FLAGS.dist = 'PADDLE_TRAINER_ID' in env \ + and 'PADDLE_TRAINERS_NUM' in env \ + and int(env['PADDLE_TRAINERS_NUM']) > 1 if FLAGS.dist: trainer_id = int(env['PADDLE_TRAINER_ID']) local_seed = (99 + trainer_id)