提交 8eb28880 编写于 作者: X xiegegege 提交者: wuzewu

add ce for PaddleSeg (#137)

上级 10fa9405
...@@ -24,6 +24,7 @@ os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" ...@@ -24,6 +24,7 @@ os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys import sys
import argparse import argparse
import pprint import pprint
import random
import shutil import shutil
import functools import functools
...@@ -95,6 +96,12 @@ def parse_args(): ...@@ -95,6 +96,12 @@ def parse_args():
help='See utils/config.py for all options', help='See utils/config.py for all options',
default=None, default=None,
nargs=argparse.REMAINDER) nargs=argparse.REMAINDER)
parser.add_argument(
'--enable_ce',
dest='enable_ce',
help='If set True, enable continuous evaluation job.'
'This flag is only used for internal test.',
action='store_true')
return parser.parse_args() return parser.parse_args()
...@@ -194,6 +201,9 @@ def print_info(*msg): ...@@ -194,6 +201,9 @@ def print_info(*msg):
def train(cfg): def train(cfg):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
if args.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
drop_last = True drop_last = True
dataset = SegDataset( dataset = SegDataset(
...@@ -483,6 +493,9 @@ def main(args): ...@@ -483,6 +493,9 @@ def main(args):
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
if args.opts: if args.opts:
cfg.update_from_list(args.opts) cfg.update_from_list(args.opts)
if args.enable_ce:
random.seed(0)
np.random.seed(0)
cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0)) cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册