提交 27e0706c 编写于 作者: u010070587's avatar u010070587 提交者: kolinwei

update ce (#4171)

上级 02a5b86e
......@@ -401,7 +401,7 @@ class SPADE(object):
0], batch_time))
sys.stdout.flush()
batch_id += 1
if self.cfg.run_test:
test_program = gen_trainer.infer_program
image_name = fluid.data(
......
......@@ -97,7 +97,10 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
# sampling
batch_sampler = []
# used for continuous evaluation
if 'ce_mode' in os.environ:
random.seed(0)
np.random.seed(0)
prob = np.random.uniform(0., 1.)
if prob > settings.data_anchor_sampling_prob:
scale_array = np.array([16, 32, 64, 128, 256, 512])
......@@ -229,7 +232,7 @@ def expand_bboxes(bboxes,
def train_generator(settings, file_list, batch_size, shuffle=True):
def reader():
if shuffle:
if shuffle and 'ce_mode' not in os.environ:
np.random.shuffle(file_list)
batch_out = []
for item in file_list:
......
......@@ -150,6 +150,7 @@ def train(args, config, train_params, train_file_list):
#only for ce
if args.enable_ce:
is_shuffle = False
SEED = 102
startup_prog.random_seed = SEED
train_prog.random_seed = SEED
......
......@@ -119,7 +119,7 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
Reader function
"""
for idx in range(epoch):
if phrase == "train":
if phrase == "train" and 'ce_mode' not in os.environ:
random.shuffle(all_data)
for wids, label, seq_len in all_data:
yield wids, label, seq_len
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册