diff --git a/example/mobilenetv2_imagenet2012/dataset.py b/example/mobilenetv2_imagenet2012/dataset.py index 92067cd759f1b9056289009e7d048d5acebd2c37..908ce87aa12712d642af9a31780933cfc08a5f1d 100644 --- a/example/mobilenetv2_imagenet2012/dataset.py +++ b/example/mobilenetv2_imagenet2012/dataset.py @@ -63,7 +63,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): if do_train: trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op] else: - trans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, change_swap_op] + trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op] type_cast_op = C2.TypeCast(mstype.int32) diff --git a/example/mobilenetv2_imagenet2012/eval.py b/example/mobilenetv2_imagenet2012/eval.py index 397b3a37c332796acb6117bbbc44e30ca385f37a..0060862a4e5df3aa413bf4366381336402a582b4 100644 --- a/example/mobilenetv2_imagenet2012/eval.py +++ b/example/mobilenetv2_imagenet2012/eval.py @@ -24,6 +24,7 @@ from mindspore.model_zoo.mobilenet import mobilenet_v2 from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.common import dtype as mstype parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -39,7 +40,8 @@ context.set_context(enable_mem_reuse=True) if __name__ == '__main__': loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - net = mobilenet_v2() + net = mobilenet_v2(num_classes=config.num_classes) + net.to_float(mstype.float16) dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) step_size = dataset.get_dataset_size() diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py index c12f2ef9c02ce96a5631bb354b816525f5a1e85a..d36737821c6c7aeb1b8b92e9f86b8f2a6eedca2b 100644 --- a/example/mobilenetv2_imagenet2012/train.py +++ b/example/mobilenetv2_imagenet2012/train.py @@ -151,7 +151,7 @@ if __name__ == '__main__': epoch_size = config.epoch_size net = mobilenet_v2(num_classes=config.num_classes) - net.add_flags_recursive(fp16=True) + net.to_float(mstype.float16) for _, cell in net.cells_and_names(): if isinstance(cell, nn.Dense): cell.add_flags_recursive(fp32=True)