diff --git a/demo/distillation/distill.py b/demo/distillation/distill.py index 9ea119d1aa8722da127969a05c6b29e6a323e921..3bafa159ea95690198f34e62707caf53f64d0bf6 100644 --- a/demo/distillation/distill.py +++ b/demo/distillation/distill.py @@ -164,8 +164,12 @@ def compress(args): ), "teacher_pretrained_model should be set when teacher_model is not None." def if_exist(var): - return os.path.exists( + exist = os.path.exists( os.path.join(args.teacher_pretrained_model, var.name)) + if args.data == "cifar10" and (var.name == 'fc_0.w_0' or + var.name == 'fc_0.b_0'): + exist = False + return exist fluid.io.load_vars( exe,