提交 946a54e5 编写于 作者: D dengkaipeng

refine format

上级 9da4cc8b
...@@ -30,7 +30,7 @@ NUM_CLASSES = 10 ...@@ -30,7 +30,7 @@ NUM_CLASSES = 10
def make_optimizer(num_samples, parameter_list=None): def make_optimizer(num_samples, parameter_list=None):
step = int(num_samples / FLAGS.batch_size) step = int(num_samples / FLAGS.batch_size / FLAGS.num_devices)
boundaries = [e * step for e in [40, 60]] boundaries = [e * step for e in [40, 60]]
values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)] values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
...@@ -78,13 +78,13 @@ def main(): ...@@ -78,13 +78,13 @@ def main():
model.prepare( model.prepare(
optim, optim,
CrossEntropy(), CrossEntropy(),
Accuracy(topk=(1, 5)), metrics=Accuracy(topk=(1, 5)),
inputs=inputs, inputs=inputs,
labels=labels, labels=labels,
device=FLAGS.device) device=FLAGS.device)
if FLAGS.eval_only: if FLAGS.eval_only:
if FLGAS.weights: if FLGAS.weights is not None:
model.load(FLAGS.weights) model.load(FLAGS.weights)
model.evaluate( model.evaluate(
...@@ -96,8 +96,8 @@ def main(): ...@@ -96,8 +96,8 @@ def main():
if FLAGS.resume is not None: if FLAGS.resume is not None:
model.load(FLAGS.resume) model.load(FLAGS.resume)
model.fit(train_dataset, model.fit(train_data=train_dataset,
val_dataset, eval_data=val_dataset,
epochs=FLAGS.epoch, epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
save_dir='tsm_checkpoint', save_dir='tsm_checkpoint',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册