提交 27e0706c 编写于 作者: u010070587's avatar u010070587 提交者: kolinwei

update ce (#4171)

上级 02a5b86e
...@@ -401,7 +401,7 @@ class SPADE(object): ...@@ -401,7 +401,7 @@ class SPADE(object):
0], batch_time)) 0], batch_time))
sys.stdout.flush() sys.stdout.flush()
batch_id += 1
if self.cfg.run_test: if self.cfg.run_test:
test_program = gen_trainer.infer_program test_program = gen_trainer.infer_program
image_name = fluid.data( image_name = fluid.data(
......
...@@ -97,7 +97,10 @@ def preprocess(img, bbox_labels, mode, settings, image_path): ...@@ -97,7 +97,10 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
# sampling # sampling
batch_sampler = [] batch_sampler = []
# used for continuous evaluation
if 'ce_mode' in os.environ:
random.seed(0)
np.random.seed(0)
prob = np.random.uniform(0., 1.) prob = np.random.uniform(0., 1.)
if prob > settings.data_anchor_sampling_prob: if prob > settings.data_anchor_sampling_prob:
scale_array = np.array([16, 32, 64, 128, 256, 512]) scale_array = np.array([16, 32, 64, 128, 256, 512])
...@@ -229,7 +232,7 @@ def expand_bboxes(bboxes, ...@@ -229,7 +232,7 @@ def expand_bboxes(bboxes,
def train_generator(settings, file_list, batch_size, shuffle=True): def train_generator(settings, file_list, batch_size, shuffle=True):
def reader(): def reader():
if shuffle: if shuffle and 'ce_mode' not in os.environ:
np.random.shuffle(file_list) np.random.shuffle(file_list)
batch_out = [] batch_out = []
for item in file_list: for item in file_list:
......
...@@ -150,6 +150,7 @@ def train(args, config, train_params, train_file_list): ...@@ -150,6 +150,7 @@ def train(args, config, train_params, train_file_list):
#only for ce #only for ce
if args.enable_ce: if args.enable_ce:
is_shuffle = False
SEED = 102 SEED = 102
startup_prog.random_seed = SEED startup_prog.random_seed = SEED
train_prog.random_seed = SEED train_prog.random_seed = SEED
......
...@@ -119,7 +119,7 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len): ...@@ -119,7 +119,7 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
Reader function Reader function
""" """
for idx in range(epoch): for idx in range(epoch):
if phrase == "train": if phrase == "train" and 'ce_mode' not in os.environ:
random.shuffle(all_data) random.shuffle(all_data)
for wids, label, seq_len in all_data: for wids, label, seq_len in all_data:
yield wids, label, seq_len yield wids, label, seq_len
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册