提交 946a54e5 编写于 作者: D dengkaipeng

refine format

上级 9da4cc8b
......@@ -30,7 +30,7 @@ NUM_CLASSES = 10
def make_optimizer(num_samples, parameter_list=None):
step = int(num_samples / FLAGS.batch_size)
step = int(num_samples / FLAGS.batch_size / FLAGS.num_devices)
boundaries = [e * step for e in [40, 60]]
values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
......@@ -78,13 +78,13 @@ def main():
model.prepare(
optim,
CrossEntropy(),
Accuracy(topk=(1, 5)),
metrics=Accuracy(topk=(1, 5)),
inputs=inputs,
labels=labels,
device=FLAGS.device)
if FLAGS.eval_only:
if FLGAS.weights:
if FLGAS.weights is not None:
model.load(FLAGS.weights)
model.evaluate(
......@@ -96,8 +96,8 @@ def main():
if FLAGS.resume is not None:
model.load(FLAGS.resume)
model.fit(train_dataset,
val_dataset,
model.fit(train_data=train_dataset,
eval_data=val_dataset,
epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size,
save_dir='tsm_checkpoint',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册