Commit 7713736e — authored by littletomatodonkey

fix batch size handling: rename `bs` to `batch_size` and stop converting the image tensor when it is immediately overwritten

上级 4338730e
......@@ -311,14 +311,13 @@ def compute(feeds, net, config, mode='train'):
def create_feeds(batch, use_mix):
    """Build the feed dict for one batch of dygraph training/eval data.

    Args:
        batch: sequence of framework tensors. batch[0] is the image tensor;
            the remaining entries are labels (and the mixup lambda when
            `use_mix` is true). Elements expose `.numpy()` — presumably
            paddle dygraph tensors; confirm against the dataloader.
        use_mix: bool; when true the batch carries mixup targets
            (y_a, y_b, lam) instead of a single label.

    Returns:
        dict mapping feed names to tensors:
            {"image", "y_a", "y_b", "lam"} when use_mix, else
            {"image", "label"}.
    """
    # The image is passed through untouched in both branches, so bind it
    # once up front instead of converting it needlessly.
    image = batch[0]
    if use_mix:
        # Mixup targets: both label sets plus the mixing coefficient,
        # each reshaped to a column vector as the loss expects.
        y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1))
        y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1))
        lam = to_variable(batch[3].numpy().astype("float32").reshape(-1, 1))
        feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam}
    else:
        label = to_variable(batch[1].numpy().astype('int64').reshape(-1, 1))
        feeds = {"image": image, "label": label}
    return feeds
......@@ -359,7 +358,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
tic = time.time()
for idx, batch in enumerate(dataloader()):
bs = len(batch[0])
batch_size = len(batch[0])
feeds = create_feeds(batch, use_mix)
fetchs = compute(feeds, net, config, mode)
if mode == 'train':
......@@ -370,10 +369,10 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
optimizer.minimize(avg_loss)
net.clear_gradients()
metric_list['lr'].update(
optimizer._global_learning_rate().numpy()[0], bs)
optimizer._global_learning_rate().numpy()[0], batch_size)
for name, fetch in fetchs.items():
metric_list[name].update(fetch.numpy()[0], bs)
metric_list[name].update(fetch.numpy()[0], batch_size)
metric_list['batch_time'].update(time.time() - tic)
tic = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册