提交 7713736e 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix bs

上级 4338730e
...@@ -311,14 +311,13 @@ def compute(feeds, net, config, mode='train'): ...@@ -311,14 +311,13 @@ def compute(feeds, net, config, mode='train'):
def create_feeds(batch, use_mix): def create_feeds(batch, use_mix):
image = to_variable(batch[0].numpy().astype("float32"))
if use_mix: if use_mix:
image = batch[0]
y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1)) y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1))
y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1)) y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1))
lam = to_variable(batch[3].numpy().astype("float32").reshape(-1, 1)) lam = to_variable(batch[3].numpy().astype("float32").reshape(-1, 1))
feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam} feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam}
else: else:
image = batch[0]
label = to_variable(batch[1].numpy().astype('int64').reshape(-1, 1)) label = to_variable(batch[1].numpy().astype('int64').reshape(-1, 1))
feeds = {"image": image, "label": label} feeds = {"image": image, "label": label}
return feeds return feeds
...@@ -359,7 +358,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'): ...@@ -359,7 +358,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader()): for idx, batch in enumerate(dataloader()):
bs = len(batch[0]) batch_size = len(batch[0])
feeds = create_feeds(batch, use_mix) feeds = create_feeds(batch, use_mix)
fetchs = compute(feeds, net, config, mode) fetchs = compute(feeds, net, config, mode)
if mode == 'train': if mode == 'train':
...@@ -370,10 +369,10 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'): ...@@ -370,10 +369,10 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
net.clear_gradients() net.clear_gradients()
metric_list['lr'].update( metric_list['lr'].update(
optimizer._global_learning_rate().numpy()[0], bs) optimizer._global_learning_rate().numpy()[0], batch_size)
for name, fetch in fetchs.items(): for name, fetch in fetchs.items():
metric_list[name].update(fetch.numpy()[0], bs) metric_list[name].update(fetch.numpy()[0], batch_size)
metric_list['batch_time'].update(time.time() - tic) metric_list['batch_time'].update(time.time() - tic)
tic = time.time() tic = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册